aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py')
-rw-r--r--.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py1979
1 files changed, 1979 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py b/.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py
new file mode 100644
index 00000000..c4d24cc0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/networkx/drawing/nx_pylab.py
@@ -0,0 +1,1979 @@
+"""
+**********
+Matplotlib
+**********
+
+Draw networks with matplotlib.
+
+Examples
+--------
+>>> G = nx.complete_graph(5)
+>>> nx.draw(G)
+
+See Also
+--------
+ - :doc:`matplotlib <matplotlib:index>`
+ - :func:`matplotlib.pyplot.scatter`
+ - :obj:`matplotlib.patches.FancyArrowPatch`
+"""
+
+import collections
+import itertools
+from numbers import Number
+
+import networkx as nx
+from networkx.drawing.layout import (
+ circular_layout,
+ forceatlas2_layout,
+ kamada_kawai_layout,
+ planar_layout,
+ random_layout,
+ shell_layout,
+ spectral_layout,
+ spring_layout,
+)
+
+__all__ = [
+ "draw",
+ "draw_networkx",
+ "draw_networkx_nodes",
+ "draw_networkx_edges",
+ "draw_networkx_labels",
+ "draw_networkx_edge_labels",
+ "draw_circular",
+ "draw_kamada_kawai",
+ "draw_random",
+ "draw_spectral",
+ "draw_spring",
+ "draw_planar",
+ "draw_shell",
+ "draw_forceatlas2",
+]
+
+
+def draw(G, pos=None, ax=None, **kwds):
+ """Draw the graph G with Matplotlib.
+
+ Draw the graph as a simple representation with no node
+ labels or edge labels and using the full Matplotlib figure area
+ and no axis labels by default. See draw_networkx() for more
+ full-featured drawing that allows title, axis labels etc.
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ pos : dictionary, optional
+ A dictionary with nodes as keys and positions as values.
+ If not specified a spring layout positioning will be computed.
+ See :py:mod:`networkx.drawing.layout` for functions that
+ compute node positions.
+
+ ax : Matplotlib Axes object, optional
+ Draw the graph in specified Matplotlib axes.
+
+ kwds : optional keywords
+ See networkx.draw_networkx() for a description of optional keywords.
+
+ Examples
+ --------
+ >>> G = nx.dodecahedral_graph()
+ >>> nx.draw(G)
+ >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout
+
+ See Also
+ --------
+ draw_networkx
+ draw_networkx_nodes
+ draw_networkx_edges
+ draw_networkx_labels
+ draw_networkx_edge_labels
+
+ Notes
+ -----
+ This function has the same name as pylab.draw and pyplot.draw
+ so beware when using `from networkx import *`
+
+ since you might overwrite the pylab.draw function.
+
+ With pyplot use
+
+ >>> import matplotlib.pyplot as plt
+ >>> G = nx.dodecahedral_graph()
+ >>> nx.draw(G) # networkx draw()
+ >>> plt.draw() # pyplot draw()
+
+ Also see the NetworkX drawing examples at
+ https://networkx.org/documentation/latest/auto_examples/index.html
+ """
+ import matplotlib.pyplot as plt
+
+ if ax is None:
+ cf = plt.gcf()
+ else:
+ cf = ax.get_figure()
+ cf.set_facecolor("w")
+ if ax is None:
+ if cf.axes:
+ ax = cf.gca()
+ else:
+ ax = cf.add_axes((0, 0, 1, 1))
+
+ if "with_labels" not in kwds:
+ kwds["with_labels"] = "labels" in kwds
+
+ draw_networkx(G, pos=pos, ax=ax, **kwds)
+ ax.set_axis_off()
+ plt.draw_if_interactive()
+ return
+
+
+def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
+ r"""Draw the graph G using Matplotlib.
+
+ Draw the graph with Matplotlib with options for node positions,
+ labeling, titles, and many other drawing features.
+ See draw() for simple drawing without labels or axes.
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ pos : dictionary, optional
+ A dictionary with nodes as keys and positions as values.
+ If not specified a spring layout positioning will be computed.
+ See :py:mod:`networkx.drawing.layout` for functions that
+ compute node positions.
+
+ arrows : bool or None, optional (default=None)
+ If `None`, directed graphs draw arrowheads with
+ `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
+ via `~matplotlib.collections.LineCollection` for speed.
+ If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
+ If `False`, draw edges using LineCollection (linear and fast).
+ For directed graphs, if True draw arrowheads.
+ Note: Arrows will be the same color as edges.
+
+ arrowstyle : str (default='-\|>' for directed graphs)
+ For directed graphs, choose the style of the arrowsheads.
+ For undirected graphs default to '-'
+
+ See `matplotlib.patches.ArrowStyle` for more options.
+
+ arrowsize : int or list (default=10)
+ For directed graphs, choose the size of the arrow head's length and
+ width. A list of values can be passed in to assign a different size for arrow head's length and width.
+ See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
+ for more info.
+
+ with_labels : bool (default=True)
+ Set to True to draw labels on the nodes.
+
+ ax : Matplotlib Axes object, optional
+ Draw the graph in the specified Matplotlib axes.
+
+ nodelist : list (default=list(G))
+ Draw only specified nodes
+
+ edgelist : list (default=list(G.edges()))
+ Draw only specified edges
+
+ node_size : scalar or array (default=300)
+ Size of nodes. If an array is specified it must be the
+ same length as nodelist.
+
+ node_color : color or array of colors (default='#1f78b4')
+ Node color. Can be a single color or a sequence of colors with the same
+ length as nodelist. Color can be string or rgb (or rgba) tuple of
+ floats from 0-1. If numeric values are specified they will be
+ mapped to colors using the cmap and vmin,vmax parameters. See
+ matplotlib.scatter for more details.
+
+ node_shape : string (default='o')
+ The shape of the node. Specification is as matplotlib.scatter
+ marker, one of 'so^>v<dph8'.
+
+ alpha : float or None (default=None)
+ The node and edge transparency
+
+ cmap : Matplotlib colormap, optional
+ Colormap for mapping intensities of nodes
+
+ vmin,vmax : float, optional
+ Minimum and maximum for node colormap scaling
+
+ linewidths : scalar or sequence (default=1.0)
+ Line width of symbol border
+
+ width : float or array of floats (default=1.0)
+ Line width of edges
+
+ edge_color : color or array of colors (default='k')
+ Edge color. Can be a single color or a sequence of colors with the same
+ length as edgelist. Color can be string or rgb (or rgba) tuple of
+ floats from 0-1. If numeric values are specified they will be
+ mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
+
+ edge_cmap : Matplotlib colormap, optional
+ Colormap for mapping intensities of edges
+
+ edge_vmin,edge_vmax : floats, optional
+ Minimum and maximum for edge colormap scaling
+
+ style : string (default=solid line)
+ Edge line style e.g.: '-', '--', '-.', ':'
+ or words like 'solid' or 'dashed'.
+ (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)
+
+ labels : dictionary (default=None)
+ Node labels in a dictionary of text labels keyed by node
+
+ font_size : int (default=12 for nodes, 10 for edges)
+ Font size for text labels
+
+ font_color : color (default='k' black)
+ Font color string. Color can be string or rgb (or rgba) tuple of
+ floats from 0-1.
+
+ font_weight : string (default='normal')
+ Font weight
+
+ font_family : string (default='sans-serif')
+ Font family
+
+ label : string, optional
+ Label for graph legend
+
+ hide_ticks : bool, optional
+ Hide ticks of axes. When `True` (the default), ticks and ticklabels
+ are removed from the axes. To set ticks and tick labels to the pyplot default,
+ use ``hide_ticks=False``.
+
+ kwds : optional keywords
+ See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
+ networkx.draw_networkx_labels() for a description of optional keywords.
+
+ Notes
+ -----
+ For directed graphs, arrows are drawn at the head end. Arrows can be
+ turned off with keyword arrows=False.
+
+ Examples
+ --------
+ >>> G = nx.dodecahedral_graph()
+ >>> nx.draw(G)
+ >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout
+
+ >>> import matplotlib.pyplot as plt
+ >>> limits = plt.axis("off") # turn off axis
+
+ Also see the NetworkX drawing examples at
+ https://networkx.org/documentation/latest/auto_examples/index.html
+
+ See Also
+ --------
+ draw
+ draw_networkx_nodes
+ draw_networkx_edges
+ draw_networkx_labels
+ draw_networkx_edge_labels
+ """
+ from inspect import signature
+
+ import matplotlib.pyplot as plt
+
+ # Get all valid keywords by inspecting the signatures of draw_networkx_nodes,
+ # draw_networkx_edges, draw_networkx_labels
+
+ valid_node_kwds = signature(draw_networkx_nodes).parameters.keys()
+ valid_edge_kwds = signature(draw_networkx_edges).parameters.keys()
+ valid_label_kwds = signature(draw_networkx_labels).parameters.keys()
+
+ # Create a set with all valid keywords across the three functions and
+ # remove the arguments of this function (draw_networkx)
+ valid_kwds = (valid_node_kwds | valid_edge_kwds | valid_label_kwds) - {
+ "G",
+ "pos",
+ "arrows",
+ "with_labels",
+ }
+
+ if any(k not in valid_kwds for k in kwds):
+ invalid_args = ", ".join([k for k in kwds if k not in valid_kwds])
+ raise ValueError(f"Received invalid argument(s): {invalid_args}")
+
+ node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds}
+ edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds}
+ label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds}
+
+ if pos is None:
+ pos = nx.drawing.spring_layout(G) # default to spring layout
+
+ draw_networkx_nodes(G, pos, **node_kwds)
+ draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds)
+ if with_labels:
+ draw_networkx_labels(G, pos, **label_kwds)
+ plt.draw_if_interactive()
+
+
+def draw_networkx_nodes(
+ G,
+ pos,
+ nodelist=None,
+ node_size=300,
+ node_color="#1f78b4",
+ node_shape="o",
+ alpha=None,
+ cmap=None,
+ vmin=None,
+ vmax=None,
+ ax=None,
+ linewidths=None,
+ edgecolors=None,
+ label=None,
+ margins=None,
+ hide_ticks=True,
+):
+ """Draw the nodes of the graph G.
+
+ This draws only the nodes of the graph G.
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ pos : dictionary
+ A dictionary with nodes as keys and positions as values.
+ Positions should be sequences of length 2.
+
+ ax : Matplotlib Axes object, optional
+ Draw the graph in the specified Matplotlib axes.
+
+ nodelist : list (default list(G))
+ Draw only specified nodes
+
+ node_size : scalar or array (default=300)
+ Size of nodes. If an array it must be the same length as nodelist.
+
+ node_color : color or array of colors (default='#1f78b4')
+ Node color. Can be a single color or a sequence of colors with the same
+ length as nodelist. Color can be string or rgb (or rgba) tuple of
+ floats from 0-1. If numeric values are specified they will be
+ mapped to colors using the cmap and vmin,vmax parameters. See
+ matplotlib.scatter for more details.
+
+ node_shape : string (default='o')
+ The shape of the node. Specification is as matplotlib.scatter
+ marker, one of 'so^>v<dph8'.
+
+ alpha : float or array of floats (default=None)
+ The node transparency. This can be a single alpha value,
+ in which case it will be applied to all the nodes of color. Otherwise,
+ if it is an array, the elements of alpha will be applied to the colors
+ in order (cycling through alpha multiple times if necessary).
+
+ cmap : Matplotlib colormap (default=None)
+ Colormap for mapping intensities of nodes
+
+ vmin,vmax : floats or None (default=None)
+ Minimum and maximum for node colormap scaling
+
+ linewidths : [None | scalar | sequence] (default=1.0)
+ Line width of symbol border
+
+ edgecolors : [None | scalar | sequence] (default = node_color)
+ Colors of node borders. Can be a single color or a sequence of colors with the
+ same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
+ from 0-1. If numeric values are specified they will be mapped to colors
+ using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.
+
+ label : [None | string]
+ Label for legend
+
+ margins : float or 2-tuple, optional
+ Sets the padding for axis autoscaling. Increase margin to prevent
+ clipping for nodes that are near the edges of an image. Values should
+ be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
+ for details. The default is `None`, which uses the Matplotlib default.
+
+ hide_ticks : bool, optional
+ Hide ticks of axes. When `True` (the default), ticks and ticklabels
+ are removed from the axes. To set ticks and tick labels to the pyplot default,
+ use ``hide_ticks=False``.
+
+ Returns
+ -------
+ matplotlib.collections.PathCollection
+ `PathCollection` of the nodes.
+
+ Examples
+ --------
+ >>> G = nx.dodecahedral_graph()
+ >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))
+
+ Also see the NetworkX drawing examples at
+ https://networkx.org/documentation/latest/auto_examples/index.html
+
+ See Also
+ --------
+ draw
+ draw_networkx
+ draw_networkx_edges
+ draw_networkx_labels
+ draw_networkx_edge_labels
+ """
+ from collections.abc import Iterable
+
+ import matplotlib as mpl
+ import matplotlib.collections # call as mpl.collections
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ if ax is None:
+ ax = plt.gca()
+
+ if nodelist is None:
+ nodelist = list(G)
+
+ if len(nodelist) == 0: # empty nodelist, no drawing
+ return mpl.collections.PathCollection(None)
+
+ try:
+ xy = np.asarray([pos[v] for v in nodelist])
+ except KeyError as err:
+ raise nx.NetworkXError(f"Node {err} has no position.") from err
+
+ if isinstance(alpha, Iterable):
+ node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
+ alpha = None
+
+ if not isinstance(node_shape, np.ndarray) and not isinstance(node_shape, list):
+ node_shape = np.array([node_shape for _ in range(len(nodelist))])
+
+ for shape in np.unique(node_shape):
+ node_collection = ax.scatter(
+ xy[node_shape == shape, 0],
+ xy[node_shape == shape, 1],
+ s=node_size,
+ c=node_color,
+ marker=shape,
+ cmap=cmap,
+ vmin=vmin,
+ vmax=vmax,
+ alpha=alpha,
+ linewidths=linewidths,
+ edgecolors=edgecolors,
+ label=label,
+ )
+ if hide_ticks:
+ ax.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ left=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+
+ if margins is not None:
+ if isinstance(margins, Iterable):
+ ax.margins(*margins)
+ else:
+ ax.margins(margins)
+
+ node_collection.set_zorder(2)
+ return node_collection
+
+
+class FancyArrowFactory:
+ """Draw arrows with `matplotlib.patches.FancyarrowPatch`"""
+
+ class ConnectionStyleFactory:
+ def __init__(self, connectionstyles, selfloop_height, ax=None):
+ import matplotlib as mpl
+ import matplotlib.path # call as mpl.path
+ import numpy as np
+
+ self.ax = ax
+ self.mpl = mpl
+ self.np = np
+ self.base_connection_styles = [
+ mpl.patches.ConnectionStyle(cs) for cs in connectionstyles
+ ]
+ self.n = len(self.base_connection_styles)
+ self.selfloop_height = selfloop_height
+
+ def curved(self, edge_index):
+ return self.base_connection_styles[edge_index % self.n]
+
+ def self_loop(self, edge_index):
+ def self_loop_connection(posA, posB, *args, **kwargs):
+ if not self.np.all(posA == posB):
+ raise nx.NetworkXError(
+ "`self_loop` connection style method"
+ "is only to be used for self-loops"
+ )
+ # this is called with _screen space_ values
+ # so convert back to data space
+ data_loc = self.ax.transData.inverted().transform(posA)
+ v_shift = 0.1 * self.selfloop_height
+ h_shift = v_shift * 0.5
+ # put the top of the loop first so arrow is not hidden by node
+ path = self.np.asarray(
+ [
+ # 1
+ [0, v_shift],
+ # 4 4 4
+ [h_shift, v_shift],
+ [h_shift, 0],
+ [0, 0],
+ # 4 4 4
+ [-h_shift, 0],
+ [-h_shift, v_shift],
+ [0, v_shift],
+ ]
+ )
+ # Rotate self loop 90 deg. if more than 1
+ # This will allow for maximum of 4 visible self loops
+ if edge_index % 4:
+ x, y = path.T
+ for _ in range(edge_index % 4):
+ x, y = y, -x
+ path = self.np.array([x, y]).T
+ return self.mpl.path.Path(
+ self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
+ )
+
+ return self_loop_connection
+
+ def __init__(
+ self,
+ edge_pos,
+ edgelist,
+ nodelist,
+ edge_indices,
+ node_size,
+ selfloop_height,
+ connectionstyle="arc3",
+ node_shape="o",
+ arrowstyle="-",
+ arrowsize=10,
+ edge_color="k",
+ alpha=None,
+ linewidth=1.0,
+ style="solid",
+ min_source_margin=0,
+ min_target_margin=0,
+ ax=None,
+ ):
+ import matplotlib as mpl
+ import matplotlib.patches # call as mpl.patches
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ if isinstance(connectionstyle, str):
+ connectionstyle = [connectionstyle]
+ elif np.iterable(connectionstyle):
+ connectionstyle = list(connectionstyle)
+ else:
+ msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable"
+ raise nx.NetworkXError(msg)
+ self.ax = ax
+ self.mpl = mpl
+ self.np = np
+ self.edge_pos = edge_pos
+ self.edgelist = edgelist
+ self.nodelist = nodelist
+ self.node_shape = node_shape
+ self.min_source_margin = min_source_margin
+ self.min_target_margin = min_target_margin
+ self.edge_indices = edge_indices
+ self.node_size = node_size
+ self.connectionstyle_factory = self.ConnectionStyleFactory(
+ connectionstyle, selfloop_height, ax
+ )
+ self.arrowstyle = arrowstyle
+ self.arrowsize = arrowsize
+ self.arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
+ self.linewidth = linewidth
+ self.style = style
+ if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos):
+ raise ValueError("arrowsize should have the same length as edgelist")
+
+ def __call__(self, i):
+ (x1, y1), (x2, y2) = self.edge_pos[i]
+ shrink_source = 0 # space from source to tail
+ shrink_target = 0 # space from head to target
+ if (
+ self.np.iterable(self.min_source_margin)
+ and not isinstance(self.min_source_margin, str)
+ and not isinstance(self.min_source_margin, tuple)
+ ):
+ min_source_margin = self.min_source_margin[i]
+ else:
+ min_source_margin = self.min_source_margin
+
+ if (
+ self.np.iterable(self.min_target_margin)
+ and not isinstance(self.min_target_margin, str)
+ and not isinstance(self.min_target_margin, tuple)
+ ):
+ min_target_margin = self.min_target_margin[i]
+ else:
+ min_target_margin = self.min_target_margin
+
+ if self.np.iterable(self.node_size): # many node sizes
+ source, target = self.edgelist[i][:2]
+ source_node_size = self.node_size[self.nodelist.index(source)]
+ target_node_size = self.node_size[self.nodelist.index(target)]
+ shrink_source = self.to_marker_edge(source_node_size, self.node_shape)
+ shrink_target = self.to_marker_edge(target_node_size, self.node_shape)
+ else:
+ shrink_source = self.to_marker_edge(self.node_size, self.node_shape)
+ shrink_target = shrink_source
+ shrink_source = max(shrink_source, min_source_margin)
+ shrink_target = max(shrink_target, min_target_margin)
+
+ # scale factor of arrow head
+ if isinstance(self.arrowsize, list):
+ mutation_scale = self.arrowsize[i]
+ else:
+ mutation_scale = self.arrowsize
+
+ if len(self.arrow_colors) > i:
+ arrow_color = self.arrow_colors[i]
+ elif len(self.arrow_colors) == 1:
+ arrow_color = self.arrow_colors[0]
+ else: # Cycle through colors
+ arrow_color = self.arrow_colors[i % len(self.arrow_colors)]
+
+ if self.np.iterable(self.linewidth):
+ if len(self.linewidth) > i:
+ linewidth = self.linewidth[i]
+ else:
+ linewidth = self.linewidth[i % len(self.linewidth)]
+ else:
+ linewidth = self.linewidth
+
+ if (
+ self.np.iterable(self.style)
+ and not isinstance(self.style, str)
+ and not isinstance(self.style, tuple)
+ ):
+ if len(self.style) > i:
+ linestyle = self.style[i]
+ else: # Cycle through styles
+ linestyle = self.style[i % len(self.style)]
+ else:
+ linestyle = self.style
+
+ if x1 == x2 and y1 == y2:
+ connectionstyle = self.connectionstyle_factory.self_loop(
+ self.edge_indices[i]
+ )
+ else:
+ connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i])
+
+ if (
+ self.np.iterable(self.arrowstyle)
+ and not isinstance(self.arrowstyle, str)
+ and not isinstance(self.arrowstyle, tuple)
+ ):
+ arrowstyle = self.arrowstyle[i]
+ else:
+ arrowstyle = self.arrowstyle
+
+ return self.mpl.patches.FancyArrowPatch(
+ (x1, y1),
+ (x2, y2),
+ arrowstyle=arrowstyle,
+ shrinkA=shrink_source,
+ shrinkB=shrink_target,
+ mutation_scale=mutation_scale,
+ color=arrow_color,
+ linewidth=linewidth,
+ connectionstyle=connectionstyle,
+ linestyle=linestyle,
+ zorder=1, # arrows go behind nodes
+ )
+
+ def to_marker_edge(self, marker_size, marker):
+ if marker in "s^>v<d": # `large` markers need extra space
+ return self.np.sqrt(2 * marker_size) / 2
+ else:
+ return self.np.sqrt(marker_size) / 2
+
+
+def draw_networkx_edges(
+ G,
+ pos,
+ edgelist=None,
+ width=1.0,
+ edge_color="k",
+ style="solid",
+ alpha=None,
+ arrowstyle=None,
+ arrowsize=10,
+ edge_cmap=None,
+ edge_vmin=None,
+ edge_vmax=None,
+ ax=None,
+ arrows=None,
+ label=None,
+ node_size=300,
+ nodelist=None,
+ node_shape="o",
+ connectionstyle="arc3",
+ min_source_margin=0,
+ min_target_margin=0,
+ hide_ticks=True,
+):
+ r"""Draw the edges of the graph G.
+
+ This draws only the edges of the graph G.
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ pos : dictionary
+ A dictionary with nodes as keys and positions as values.
+ Positions should be sequences of length 2.
+
+ edgelist : collection of edge tuples (default=G.edges())
+ Draw only specified edges
+
+ width : float or array of floats (default=1.0)
+ Line width of edges
+
+ edge_color : color or array of colors (default='k')
+ Edge color. Can be a single color or a sequence of colors with the same
+ length as edgelist. Color can be string or rgb (or rgba) tuple of
+ floats from 0-1. If numeric values are specified they will be
+ mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
+
+ style : string or array of strings (default='solid')
+ Edge line style e.g.: '-', '--', '-.', ':'
+ or words like 'solid' or 'dashed'.
+ Can be a single style or a sequence of styles with the same
+ length as the edge list.
+ If less styles than edges are given the styles will cycle.
+ If more styles than edges are given the styles will be used sequentially
+ and not be exhausted.
+ Also, `(offset, onoffseq)` tuples can be used as style instead of a strings.
+ (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)
+
+ alpha : float or array of floats (default=None)
+ The edge transparency. This can be a single alpha value,
+ in which case it will be applied to all specified edges. Otherwise,
+ if it is an array, the elements of alpha will be applied to the colors
+ in order (cycling through alpha multiple times if necessary).
+
+ edge_cmap : Matplotlib colormap, optional
+ Colormap for mapping intensities of edges
+
+ edge_vmin,edge_vmax : floats, optional
+ Minimum and maximum for edge colormap scaling
+
+ ax : Matplotlib Axes object, optional
+ Draw the graph in the specified Matplotlib axes.
+
+ arrows : bool or None, optional (default=None)
+ If `None`, directed graphs draw arrowheads with
+ `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
+ via `~matplotlib.collections.LineCollection` for speed.
+ If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
+ If `False`, draw edges using LineCollection (linear and fast).
+
+ Note: Arrowheads will be the same color as edges.
+
+ arrowstyle : str or list of strs (default='-\|>' for directed graphs)
+ For directed graphs and `arrows==True` defaults to '-\|>',
+ For undirected graphs default to '-'.
+
+ See `matplotlib.patches.ArrowStyle` for more options.
+
+ arrowsize : int or list of ints(default=10)
+ For directed graphs, choose the size of the arrow head's length and
+ width. See `matplotlib.patches.FancyArrowPatch` for attribute
+ `mutation_scale` for more info.
+
+ connectionstyle : string or iterable of strings (default="arc3")
+ Pass the connectionstyle parameter to create curved arc of rounding
+ radius rad. For example, connectionstyle='arc3,rad=0.2'.
+ See `matplotlib.patches.ConnectionStyle` and
+ `matplotlib.patches.FancyArrowPatch` for more info.
+ If Iterable, index indicates i'th edge key of MultiGraph
+
+ node_size : scalar or array (default=300)
+ Size of nodes. Though the nodes are not drawn with this function, the
+ node size is used in determining edge positioning.
+
+ nodelist : list, optional (default=G.nodes())
+ This provides the node order for the `node_size` array (if it is an array).
+
+ node_shape : string (default='o')
+ The marker used for nodes, used in determining edge positioning.
+ Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'.
+
+ label : None or string
+ Label for legend
+
+ min_source_margin : int or list of ints (default=0)
+ The minimum margin (gap) at the beginning of the edge at the source.
+
+ min_target_margin : int or list of ints (default=0)
+ The minimum margin (gap) at the end of the edge at the target.
+
+ hide_ticks : bool, optional
+ Hide ticks of axes. When `True` (the default), ticks and ticklabels
+ are removed from the axes. To set ticks and tick labels to the pyplot default,
+ use ``hide_ticks=False``.
+
+ Returns
+ -------
+ matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch
+ If ``arrows=True``, a list of FancyArrowPatches is returned.
+ If ``arrows=False``, a LineCollection is returned.
+ If ``arrows=None`` (the default), then a LineCollection is returned if
+ `G` is undirected, otherwise returns a list of FancyArrowPatches.
+
+ Notes
+ -----
+ For directed graphs, arrows are drawn at the head end. Arrows can be
+ turned off with keyword arrows=False or by passing an arrowstyle without
+ an arrow on the end.
+
+ Be sure to include `node_size` as a keyword argument; arrows are
+ drawn considering the size of nodes.
+
+ Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch`
+ regardless of the value of `arrows` or whether `G` is directed.
+ When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the
+ FancyArrowPatches corresponding to the self-loops are not explicitly
+ returned. They should instead be accessed via the ``Axes.patches``
+ attribute (see examples).
+
+ Examples
+ --------
+ >>> G = nx.dodecahedral_graph()
+ >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
+
+ >>> G = nx.DiGraph()
+ >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
+ >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
+ >>> alphas = [0.3, 0.4, 0.5]
+ >>> for i, arc in enumerate(arcs): # change alpha values of arcs
+ ... arc.set_alpha(alphas[i])
+
+ The FancyArrowPatches corresponding to self-loops are not always
+ returned, but can always be accessed via the ``patches`` attribute of the
+ `matplotlib.Axes` object.
+
+ >>> import matplotlib.pyplot as plt
+ >>> fig, ax = plt.subplots()
+ >>> G = nx.Graph([(0, 1), (0, 0)]) # Self-loop at node 0
+ >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax)
+ >>> self_loop_fap = ax.patches[0]
+
+ Also see the NetworkX drawing examples at
+ https://networkx.org/documentation/latest/auto_examples/index.html
+
+ See Also
+ --------
+ draw
+ draw_networkx
+ draw_networkx_nodes
+ draw_networkx_labels
+ draw_networkx_edge_labels
+
+ """
+ import warnings
+
+ import matplotlib as mpl
+ import matplotlib.collections # call as mpl.collections
+ import matplotlib.colors # call as mpl.colors
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ # The default behavior is to use LineCollection to draw edges for
+ # undirected graphs (for performance reasons) and use FancyArrowPatches
+ # for directed graphs.
+ # The `arrows` keyword can be used to override the default behavior
+ if arrows is None:
+ use_linecollection = not (G.is_directed() or G.is_multigraph())
+ else:
+ if not isinstance(arrows, bool):
+ raise TypeError("Argument `arrows` must be of type bool or None")
+ use_linecollection = not arrows
+
+ if isinstance(connectionstyle, str):
+ connectionstyle = [connectionstyle]
+ elif np.iterable(connectionstyle):
+ connectionstyle = list(connectionstyle)
+ else:
+ msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable"
+ raise nx.NetworkXError(msg)
+
+ # Some kwargs only apply to FancyArrowPatches. Warn users when they use
+ # non-default values for these kwargs when LineCollection is being used
+ # instead of silently ignoring the specified option
+ if use_linecollection:
+ msg = (
+ "\n\nThe {0} keyword argument is not applicable when drawing edges\n"
+ "with LineCollection.\n\n"
+ "To make this warning go away, either specify `arrows=True` to\n"
+ "force FancyArrowPatches or use the default values.\n"
+ "Note that using FancyArrowPatches may be slow for large graphs.\n"
+ )
+ if arrowstyle is not None:
+ warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2)
+ if arrowsize != 10:
+ warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2)
+ if min_source_margin != 0:
+ warnings.warn(
+ msg.format("min_source_margin"), category=UserWarning, stacklevel=2
+ )
+ if min_target_margin != 0:
+ warnings.warn(
+ msg.format("min_target_margin"), category=UserWarning, stacklevel=2
+ )
+ if any(cs != "arc3" for cs in connectionstyle):
+ warnings.warn(
+ msg.format("connectionstyle"), category=UserWarning, stacklevel=2
+ )
+
+ # NOTE: Arrowstyle modification must occur after the warnings section
+ if arrowstyle is None:
+ arrowstyle = "-|>" if G.is_directed() else "-"
+
+ if ax is None:
+ ax = plt.gca()
+
+ if edgelist is None:
+ edgelist = list(G.edges) # (u, v, k) for multigraph (u, v) otherwise
+
+ if len(edgelist):
+ if G.is_multigraph():
+ key_count = collections.defaultdict(lambda: itertools.count(0))
+ edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
+ else:
+ edge_indices = [0] * len(edgelist)
+ else: # no edges!
+ return []
+
+ if nodelist is None:
+ nodelist = list(G.nodes())
+
+ # FancyArrowPatch handles color=None different from LineCollection
+ if edge_color is None:
+ edge_color = "k"
+
+ # set edge positions
+ edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
+
+ # Check if edge_color is an array of floats and map to edge_cmap.
+ # This is the only case handled differently from matplotlib
+ if (
+ np.iterable(edge_color)
+ and (len(edge_color) == len(edge_pos))
+ and np.all([isinstance(c, Number) for c in edge_color])
+ ):
+ if edge_cmap is not None:
+ assert isinstance(edge_cmap, mpl.colors.Colormap)
+ else:
+ edge_cmap = plt.get_cmap()
+ if edge_vmin is None:
+ edge_vmin = min(edge_color)
+ if edge_vmax is None:
+ edge_vmax = max(edge_color)
+ color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
+ edge_color = [edge_cmap(color_normal(e)) for e in edge_color]
+
+ # compute initial view
+ minx = np.amin(np.ravel(edge_pos[:, :, 0]))
+ maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
+ miny = np.amin(np.ravel(edge_pos[:, :, 1]))
+ maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
+ w = maxx - minx
+ h = maxy - miny
+
+ # Self-loops are scaled by view extent, except in cases the extent
+ # is 0, e.g. for a single node. In this case, fall back to scaling
+ # by the maximum node size
+ selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
+ fancy_arrow_factory = FancyArrowFactory(
+ edge_pos,
+ edgelist,
+ nodelist,
+ edge_indices,
+ node_size,
+ selfloop_height,
+ connectionstyle,
+ node_shape,
+ arrowstyle,
+ arrowsize,
+ edge_color,
+ alpha,
+ width,
+ style,
+ min_source_margin,
+ min_target_margin,
+ ax=ax,
+ )
+
+ # Draw the edges
+ if use_linecollection:
+ edge_collection = mpl.collections.LineCollection(
+ edge_pos,
+ colors=edge_color,
+ linewidths=width,
+ antialiaseds=(1,),
+ linestyle=style,
+ alpha=alpha,
+ )
+ edge_collection.set_cmap(edge_cmap)
+ edge_collection.set_clim(edge_vmin, edge_vmax)
+ edge_collection.set_zorder(1) # edges go behind nodes
+ edge_collection.set_label(label)
+ ax.add_collection(edge_collection)
+ edge_viz_obj = edge_collection
+
+ # Make sure selfloop edges are also drawn
+ # ---------------------------------------
+ selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist]
+ if selfloops_to_draw:
+ edgelist_tuple = list(map(tuple, edgelist))
+ arrow_collection = []
+ for loop in selfloops_to_draw:
+ i = edgelist_tuple.index(loop)
+ arrow = fancy_arrow_factory(i)
+ arrow_collection.append(arrow)
+ ax.add_patch(arrow)
+ else:
+ edge_viz_obj = []
+ for i in range(len(edgelist)):
+ arrow = fancy_arrow_factory(i)
+ ax.add_patch(arrow)
+ edge_viz_obj.append(arrow)
+
+ # update view after drawing
+ padx, pady = 0.05 * w, 0.05 * h
+ corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
+ ax.update_datalim(corners)
+ ax.autoscale_view()
+
+ if hide_ticks:
+ ax.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ left=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+
+ return edge_viz_obj
+
+
+def draw_networkx_labels(
+ G,
+ pos,
+ labels=None,
+ font_size=12,
+ font_color="k",
+ font_family="sans-serif",
+ font_weight="normal",
+ alpha=None,
+ bbox=None,
+ horizontalalignment="center",
+ verticalalignment="center",
+ ax=None,
+ clip_on=True,
+ hide_ticks=True,
+):
+ """Draw node labels on the graph G.
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ pos : dictionary
+ A dictionary with nodes as keys and positions as values.
+ Positions should be sequences of length 2.
+
+ labels : dictionary (default={n: n for n in G})
+ Node labels in a dictionary of text labels keyed by node.
+ Node-keys in labels should appear as keys in `pos`.
+ If needed use: `{n:lab for n,lab in labels.items() if n in pos}`
+
+ font_size : int or dictionary of nodes to ints (default=12)
+ Font size for text labels.
+
+ font_color : color or dictionary of nodes to colors (default='k' black)
+ Font color string. Color can be string or rgb (or rgba) tuple of
+ floats from 0-1.
+
+ font_weight : string or dictionary of nodes to strings (default='normal')
+ Font weight.
+
+ font_family : string or dictionary of nodes to strings (default='sans-serif')
+ Font family.
+
+ alpha : float or None or dictionary of nodes to floats (default=None)
+ The text transparency.
+
+ bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
+ Specify text box properties (e.g. shape, color etc.) for node labels.
+
+ horizontalalignment : string or array of strings (default='center')
+ Horizontal alignment {'center', 'right', 'left'}. If an array is
+ specified it must be the same length as `nodelist`.
+
+ verticalalignment : string (default='center')
+ Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.
+ If an array is specified it must be the same length as `nodelist`.
+
+ ax : Matplotlib Axes object, optional
+ Draw the graph in the specified Matplotlib axes.
+
+ clip_on : bool (default=True)
+ Turn on clipping of node labels at axis boundaries
+
+ hide_ticks : bool, optional
+ Hide ticks of axes. When `True` (the default), ticks and ticklabels
+ are removed from the axes. To set ticks and tick labels to the pyplot default,
+ use ``hide_ticks=False``.
+
+ Returns
+ -------
+ dict
+ `dict` of labels keyed on the nodes
+
+ Examples
+ --------
+ >>> G = nx.dodecahedral_graph()
+ >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))
+
+ Also see the NetworkX drawing examples at
+ https://networkx.org/documentation/latest/auto_examples/index.html
+
+ See Also
+ --------
+ draw
+ draw_networkx
+ draw_networkx_nodes
+ draw_networkx_edges
+ draw_networkx_edge_labels
+ """
+ import matplotlib.pyplot as plt
+
+ if ax is None:
+ ax = plt.gca()
+
+ if labels is None:
+ labels = {n: n for n in G.nodes()}
+
+ individual_params = set()
+
+ def check_individual_params(p_value, p_name):
+ if isinstance(p_value, dict):
+ if len(p_value) != len(labels):
+ raise ValueError(f"{p_name} must have the same length as labels.")
+ individual_params.add(p_name)
+
+ def get_param_value(node, p_value, p_name):
+ if p_name in individual_params:
+ return p_value[node]
+ return p_value
+
+ check_individual_params(font_size, "font_size")
+ check_individual_params(font_color, "font_color")
+ check_individual_params(font_weight, "font_weight")
+ check_individual_params(font_family, "font_family")
+ check_individual_params(alpha, "alpha")
+
+ text_items = {} # there is no text collection so we'll fake one
+ for n, label in labels.items():
+ (x, y) = pos[n]
+ if not isinstance(label, str):
+ label = str(label) # this makes "1" and 1 labeled the same
+ t = ax.text(
+ x,
+ y,
+ label,
+ size=get_param_value(n, font_size, "font_size"),
+ color=get_param_value(n, font_color, "font_color"),
+ family=get_param_value(n, font_family, "font_family"),
+ weight=get_param_value(n, font_weight, "font_weight"),
+ alpha=get_param_value(n, alpha, "alpha"),
+ horizontalalignment=horizontalalignment,
+ verticalalignment=verticalalignment,
+ transform=ax.transData,
+ bbox=bbox,
+ clip_on=clip_on,
+ )
+ text_items[n] = t
+
+ if hide_ticks:
+ ax.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ left=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+
+ return text_items
+
+
+def draw_networkx_edge_labels(
+ G,
+ pos,
+ edge_labels=None,
+ label_pos=0.5,
+ font_size=10,
+ font_color="k",
+ font_family="sans-serif",
+ font_weight="normal",
+ alpha=None,
+ bbox=None,
+ horizontalalignment="center",
+ verticalalignment="center",
+ ax=None,
+ rotate=True,
+ clip_on=True,
+ node_size=300,
+ nodelist=None,
+ connectionstyle="arc3",
+ hide_ticks=True,
+):
+ """Draw edge labels.
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ pos : dictionary
+ A dictionary with nodes as keys and positions as values.
+ Positions should be sequences of length 2.
+
+ edge_labels : dictionary (default=None)
+ Edge labels in a dictionary of labels keyed by edge two-tuple.
+ Only labels for the keys in the dictionary are drawn.
+
+ label_pos : float (default=0.5)
+ Position of edge label along edge (0=head, 0.5=center, 1=tail)
+
+ font_size : int (default=10)
+ Font size for text labels
+
+ font_color : color (default='k' black)
+ Font color string. Color can be string or rgb (or rgba) tuple of
+ floats from 0-1.
+
+ font_weight : string (default='normal')
+ Font weight
+
+ font_family : string (default='sans-serif')
+ Font family
+
+ alpha : float or None (default=None)
+ The text transparency
+
+ bbox : Matplotlib bbox, optional
+ Specify text box properties (e.g. shape, color etc.) for edge labels.
+ Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.
+
+ horizontalalignment : string (default='center')
+ Horizontal alignment {'center', 'right', 'left'}
+
+ verticalalignment : string (default='center')
+ Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}
+
+ ax : Matplotlib Axes object, optional
+ Draw the graph in the specified Matplotlib axes.
+
+ rotate : bool (default=True)
+ Rotate edge labels to lie parallel to edges
+
+ clip_on : bool (default=True)
+ Turn on clipping of edge labels at axis boundaries
+
+ node_size : scalar or array (default=300)
+ Size of nodes. If an array it must be the same length as nodelist.
+
+ nodelist : list, optional (default=G.nodes())
+ This provides the node order for the `node_size` array (if it is an array).
+
+ connectionstyle : string or iterable of strings (default="arc3")
+ Pass the connectionstyle parameter to create curved arc of rounding
+ radius rad. For example, connectionstyle='arc3,rad=0.2'.
+ See `matplotlib.patches.ConnectionStyle` and
+ `matplotlib.patches.FancyArrowPatch` for more info.
+ If Iterable, index indicates i'th edge key of MultiGraph
+
+ hide_ticks : bool, optional
+ Hide ticks of axes. When `True` (the default), ticks and ticklabels
+ are removed from the axes. To set ticks and tick labels to the pyplot default,
+ use ``hide_ticks=False``.
+
+ Returns
+ -------
+ dict
+ `dict` of labels keyed by edge
+
+ Examples
+ --------
+ >>> G = nx.dodecahedral_graph()
+ >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))
+
+ Also see the NetworkX drawing examples at
+ https://networkx.org/documentation/latest/auto_examples/index.html
+
+ See Also
+ --------
+ draw
+ draw_networkx
+ draw_networkx_nodes
+ draw_networkx_edges
+ draw_networkx_labels
+ """
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ class CurvedArrowText(mpl.text.Text):
+ def __init__(
+ self,
+ arrow,
+ *args,
+ label_pos=0.5,
+ labels_horizontal=False,
+ ax=None,
+ **kwargs,
+ ):
+ # Bind to FancyArrowPatch
+ self.arrow = arrow
+ # how far along the text should be on the curve,
+ # 0 is at start, 1 is at end etc.
+ self.label_pos = label_pos
+ self.labels_horizontal = labels_horizontal
+ if ax is None:
+ ax = plt.gca()
+ self.ax = ax
+ self.x, self.y, self.angle = self._update_text_pos_angle(arrow)
+
+ # Create text object
+ super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs)
+ # Bind to axis
+ self.ax.add_artist(self)
+
+ def _get_arrow_path_disp(self, arrow):
+ """
+ This is part of FancyArrowPatch._get_path_in_displaycoord
+ It omits the second part of the method where path is converted
+ to polygon based on width
+ The transform is taken from ax, not the object, as the object
+ has not been added yet, and doesn't have transform
+ """
+ dpi_cor = arrow._dpi_cor
+ # trans_data = arrow.get_transform()
+ trans_data = self.ax.transData
+ if arrow._posA_posB is not None:
+ posA = arrow._convert_xy_units(arrow._posA_posB[0])
+ posB = arrow._convert_xy_units(arrow._posA_posB[1])
+ (posA, posB) = trans_data.transform((posA, posB))
+ _path = arrow.get_connectionstyle()(
+ posA,
+ posB,
+ patchA=arrow.patchA,
+ patchB=arrow.patchB,
+ shrinkA=arrow.shrinkA * dpi_cor,
+ shrinkB=arrow.shrinkB * dpi_cor,
+ )
+ else:
+ _path = trans_data.transform_path(arrow._path_original)
+ # Return is in display coordinates
+ return _path
+
+ def _update_text_pos_angle(self, arrow):
+ # Fractional label position
+ path_disp = self._get_arrow_path_disp(arrow)
+ (x1, y1), (cx, cy), (x2, y2) = path_disp.vertices
+ # Text position at a proportion t along the line in display coords
+ # default is 0.5 so text appears at the halfway point
+ t = self.label_pos
+ tt = 1 - t
+ x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2
+ y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2
+ if self.labels_horizontal:
+ # Horizontal text labels
+ angle = 0
+ else:
+ # Labels parallel to curve
+ change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx)
+ change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy)
+ angle = (np.arctan2(change_y, change_x) / (2 * np.pi)) * 360
+ # Text is "right way up"
+ if angle > 90:
+ angle -= 180
+ if angle < -90:
+ angle += 180
+ (x, y) = self.ax.transData.inverted().transform((x, y))
+ return x, y, angle
+
+ def draw(self, renderer):
+ # recalculate the text position and angle
+ self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow)
+ self.set_position((self.x, self.y))
+ self.set_rotation(self.angle)
+ # redraw text
+ super().draw(renderer)
+
+ # use default box of white with white border
+ if bbox is None:
+ bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}
+
+ if isinstance(connectionstyle, str):
+ connectionstyle = [connectionstyle]
+ elif np.iterable(connectionstyle):
+ connectionstyle = list(connectionstyle)
+ else:
+ raise nx.NetworkXError(
+ "draw_networkx_edges arg `connectionstyle` must be"
+ "string or iterable of strings"
+ )
+
+ if ax is None:
+ ax = plt.gca()
+
+ if edge_labels is None:
+ kwds = {"keys": True} if G.is_multigraph() else {}
+ edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)}
+ # NOTHING TO PLOT
+ if not edge_labels:
+ return {}
+ edgelist, labels = zip(*edge_labels.items())
+
+ if nodelist is None:
+ nodelist = list(G.nodes())
+
+ # set edge positions
+ edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
+
+ if G.is_multigraph():
+ key_count = collections.defaultdict(lambda: itertools.count(0))
+ edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
+ else:
+ edge_indices = [0] * len(edgelist)
+
+ # Used to determine self loop mid-point
+ # Note, that this will not be accurate,
+ # if not drawing edge_labels for all edges drawn
+ h = 0
+ if edge_labels:
+ miny = np.amin(np.ravel(edge_pos[:, :, 1]))
+ maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
+ h = maxy - miny
+ selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
+ fancy_arrow_factory = FancyArrowFactory(
+ edge_pos,
+ edgelist,
+ nodelist,
+ edge_indices,
+ node_size,
+ selfloop_height,
+ connectionstyle,
+ ax=ax,
+ )
+
+ individual_params = {}
+
+ def check_individual_params(p_value, p_name):
+ # TODO should this be list or array (as in a numpy array)?
+ if isinstance(p_value, list):
+ if len(p_value) != len(edgelist):
+ raise ValueError(f"{p_name} must have the same length as edgelist.")
+ individual_params[p_name] = p_value.iter()
+
+ # Don't need to pass in an edge because these are lists, not dicts
+ def get_param_value(p_value, p_name):
+ if p_name in individual_params:
+ return next(individual_params[p_name])
+ return p_value
+
+ check_individual_params(font_size, "font_size")
+ check_individual_params(font_color, "font_color")
+ check_individual_params(font_weight, "font_weight")
+ check_individual_params(alpha, "alpha")
+ check_individual_params(horizontalalignment, "horizontalalignment")
+ check_individual_params(verticalalignment, "verticalalignment")
+ check_individual_params(rotate, "rotate")
+ check_individual_params(label_pos, "label_pos")
+
+ text_items = {}
+ for i, (edge, label) in enumerate(zip(edgelist, labels)):
+ if not isinstance(label, str):
+ label = str(label) # this makes "1" and 1 labeled the same
+
+ n1, n2 = edge[:2]
+ arrow = fancy_arrow_factory(i)
+ if n1 == n2:
+ connectionstyle_obj = arrow.get_connectionstyle()
+ posA = ax.transData.transform(pos[n1])
+ path_disp = connectionstyle_obj(posA, posA)
+ path_data = ax.transData.inverted().transform_path(path_disp)
+ x, y = path_data.vertices[0]
+ text_items[edge] = ax.text(
+ x,
+ y,
+ label,
+ size=get_param_value(font_size, "font_size"),
+ color=get_param_value(font_color, "font_color"),
+ family=get_param_value(font_family, "font_family"),
+ weight=get_param_value(font_weight, "font_weight"),
+ alpha=get_param_value(alpha, "alpha"),
+ horizontalalignment=get_param_value(
+ horizontalalignment, "horizontalalignment"
+ ),
+ verticalalignment=get_param_value(
+ verticalalignment, "verticalalignment"
+ ),
+ rotation=0,
+ transform=ax.transData,
+ bbox=bbox,
+ zorder=1,
+ clip_on=clip_on,
+ )
+ else:
+ text_items[edge] = CurvedArrowText(
+ arrow,
+ label,
+ size=get_param_value(font_size, "font_size"),
+ color=get_param_value(font_color, "font_color"),
+ family=get_param_value(font_family, "font_family"),
+ weight=get_param_value(font_weight, "font_weight"),
+ alpha=get_param_value(alpha, "alpha"),
+ horizontalalignment=get_param_value(
+ horizontalalignment, "horizontalalignment"
+ ),
+ verticalalignment=get_param_value(
+ verticalalignment, "verticalalignment"
+ ),
+ transform=ax.transData,
+ bbox=bbox,
+ zorder=1,
+ clip_on=clip_on,
+ label_pos=get_param_value(label_pos, "label_pos"),
+ labels_horizontal=not get_param_value(rotate, "rotate"),
+ ax=ax,
+ )
+
+ if hide_ticks:
+ ax.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ left=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+
+ return text_items
+
+
+def draw_circular(G, **kwargs):
+ """Draw the graph `G` with a circular layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.circular_layout(G), **kwargs)
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ kwargs : optional keywords
+ See `draw_networkx` for a description of optional keywords.
+
+ Notes
+ -----
+ The layout is computed each time this function is called. For
+ repeated drawing it is much more efficient to call
+ `~networkx.drawing.layout.circular_layout` directly and reuse the result::
+
+ >>> G = nx.complete_graph(5)
+ >>> pos = nx.circular_layout(G)
+ >>> nx.draw(G, pos=pos) # Draw the original graph
+ >>> # Draw a subgraph, reusing the same node positions
+ >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
+
+ Examples
+ --------
+ >>> G = nx.path_graph(5)
+ >>> nx.draw_circular(G)
+
+ See Also
+ --------
+ :func:`~networkx.drawing.layout.circular_layout`
+ """
+ draw(G, circular_layout(G), **kwargs)
+
+
+def draw_kamada_kawai(G, **kwargs):
+ """Draw the graph `G` with a Kamada-Kawai force-directed layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ kwargs : optional keywords
+ See `draw_networkx` for a description of optional keywords.
+
+ Notes
+ -----
+ The layout is computed each time this function is called.
+ For repeated drawing it is much more efficient to call
+ `~networkx.drawing.layout.kamada_kawai_layout` directly and reuse the
+ result::
+
+ >>> G = nx.complete_graph(5)
+ >>> pos = nx.kamada_kawai_layout(G)
+ >>> nx.draw(G, pos=pos) # Draw the original graph
+ >>> # Draw a subgraph, reusing the same node positions
+ >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
+
+ Examples
+ --------
+ >>> G = nx.path_graph(5)
+ >>> nx.draw_kamada_kawai(G)
+
+ See Also
+ --------
+ :func:`~networkx.drawing.layout.kamada_kawai_layout`
+ """
+ draw(G, kamada_kawai_layout(G), **kwargs)
+
+
+def draw_random(G, **kwargs):
+ """Draw the graph `G` with a random layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.random_layout(G), **kwargs)
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ kwargs : optional keywords
+ See `draw_networkx` for a description of optional keywords.
+
+ Notes
+ -----
+ The layout is computed each time this function is called.
+ For repeated drawing it is much more efficient to call
+ `~networkx.drawing.layout.random_layout` directly and reuse the result::
+
+ >>> G = nx.complete_graph(5)
+ >>> pos = nx.random_layout(G)
+ >>> nx.draw(G, pos=pos) # Draw the original graph
+ >>> # Draw a subgraph, reusing the same node positions
+ >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
+
+ Examples
+ --------
+ >>> G = nx.lollipop_graph(4, 3)
+ >>> nx.draw_random(G)
+
+ See Also
+ --------
+ :func:`~networkx.drawing.layout.random_layout`
+ """
+ draw(G, random_layout(G), **kwargs)
+
+
+def draw_spectral(G, **kwargs):
+ """Draw the graph `G` with a spectral 2D layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.spectral_layout(G), **kwargs)
+
+ For more information about how node positions are determined, see
+ `~networkx.drawing.layout.spectral_layout`.
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ kwargs : optional keywords
+ See `draw_networkx` for a description of optional keywords.
+
+ Notes
+ -----
+ The layout is computed each time this function is called.
+ For repeated drawing it is much more efficient to call
+ `~networkx.drawing.layout.spectral_layout` directly and reuse the result::
+
+ >>> G = nx.complete_graph(5)
+ >>> pos = nx.spectral_layout(G)
+ >>> nx.draw(G, pos=pos) # Draw the original graph
+ >>> # Draw a subgraph, reusing the same node positions
+ >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
+
+ Examples
+ --------
+ >>> G = nx.path_graph(5)
+ >>> nx.draw_spectral(G)
+
+ See Also
+ --------
+ :func:`~networkx.drawing.layout.spectral_layout`
+ """
+ draw(G, spectral_layout(G), **kwargs)
+
+
+def draw_spring(G, **kwargs):
+ """Draw the graph `G` with a spring layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.spring_layout(G), **kwargs)
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ kwargs : optional keywords
+ See `draw_networkx` for a description of optional keywords.
+
+ Notes
+ -----
+ `~networkx.drawing.layout.spring_layout` is also the default layout for
+ `draw`, so this function is equivalent to `draw`.
+
+ The layout is computed each time this function is called.
+ For repeated drawing it is much more efficient to call
+ `~networkx.drawing.layout.spring_layout` directly and reuse the result::
+
+ >>> G = nx.complete_graph(5)
+ >>> pos = nx.spring_layout(G)
+ >>> nx.draw(G, pos=pos) # Draw the original graph
+ >>> # Draw a subgraph, reusing the same node positions
+ >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
+
+ Examples
+ --------
+ >>> G = nx.path_graph(20)
+ >>> nx.draw_spring(G)
+
+ See Also
+ --------
+ draw
+ :func:`~networkx.drawing.layout.spring_layout`
+ """
+ draw(G, spring_layout(G), **kwargs)
+
+
+def draw_shell(G, nlist=None, **kwargs):
+ """Draw networkx graph `G` with shell layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ nlist : list of list of nodes, optional
+ A list containing lists of nodes representing the shells.
+ Default is `None`, meaning all nodes are in a single shell.
+ See `~networkx.drawing.layout.shell_layout` for details.
+
+ kwargs : optional keywords
+ See `draw_networkx` for a description of optional keywords.
+
+ Notes
+ -----
+ The layout is computed each time this function is called.
+ For repeated drawing it is much more efficient to call
+ `~networkx.drawing.layout.shell_layout` directly and reuse the result::
+
+ >>> G = nx.complete_graph(5)
+ >>> pos = nx.shell_layout(G)
+ >>> nx.draw(G, pos=pos) # Draw the original graph
+ >>> # Draw a subgraph, reusing the same node positions
+ >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
+
+ Examples
+ --------
+ >>> G = nx.path_graph(4)
+ >>> shells = [[0], [1, 2, 3]]
+ >>> nx.draw_shell(G, nlist=shells)
+
+ See Also
+ --------
+ :func:`~networkx.drawing.layout.shell_layout`
+ """
+ draw(G, shell_layout(G, nlist=nlist), **kwargs)
+
+
+def draw_planar(G, **kwargs):
+ """Draw a planar networkx graph `G` with planar layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.planar_layout(G), **kwargs)
+
+ Parameters
+ ----------
+ G : graph
+ A planar networkx graph
+
+ kwargs : optional keywords
+ See `draw_networkx` for a description of optional keywords.
+
+ Raises
+ ------
+ NetworkXException
+ When `G` is not planar
+
+ Notes
+ -----
+ The layout is computed each time this function is called.
+ For repeated drawing it is much more efficient to call
+ `~networkx.drawing.layout.planar_layout` directly and reuse the result::
+
+ >>> G = nx.path_graph(5)
+ >>> pos = nx.planar_layout(G)
+ >>> nx.draw(G, pos=pos) # Draw the original graph
+ >>> # Draw a subgraph, reusing the same node positions
+ >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
+
+ Examples
+ --------
+ >>> G = nx.path_graph(4)
+ >>> nx.draw_planar(G)
+
+ See Also
+ --------
+ :func:`~networkx.drawing.layout.planar_layout`
+ """
+ draw(G, planar_layout(G), **kwargs)
+
+
+def draw_forceatlas2(G, **kwargs):
+ """Draw a networkx graph with forceatlas2 layout.
+
+ This is a convenience function equivalent to::
+
+ nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs)
+
+ Parameters
+ ----------
+ G : graph
+ A networkx graph
+
+ kwargs : optional keywords
+ See networkx.draw_networkx() for a description of optional keywords,
+ with the exception of the pos parameter which is not used by this
+ function.
+ """
+ draw(G, forceatlas2_layout(G), **kwargs)
+
+
+def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
+ """Apply an alpha (or list of alphas) to the colors provided.
+
+ Parameters
+ ----------
+
+ colors : color string or array of floats (default='r')
+ Color of element. Can be a single color format string,
+ or a sequence of colors with the same length as nodelist.
+ If numeric values are specified they will be mapped to
+ colors using the cmap and vmin,vmax parameters. See
+ matplotlib.scatter for more details.
+
+ alpha : float or array of floats
+ Alpha values for elements. This can be a single alpha value, in
+ which case it will be applied to all the elements of color. Otherwise,
+ if it is an array, the elements of alpha will be applied to the colors
+ in order (cycling through alpha multiple times if necessary).
+
+ elem_list : array of networkx objects
+ The list of elements which are being colored. These could be nodes,
+ edges or labels.
+
+ cmap : matplotlib colormap
+ Color map for use if colors is a list of floats corresponding to points
+ on a color mapping.
+
+ vmin, vmax : float
+ Minimum and maximum values for normalizing colors if a colormap is used
+
+ Returns
+ -------
+
+ rgba_colors : numpy ndarray
+ Array containing RGBA format values for each of the node colours.
+
+ """
+ from itertools import cycle, islice
+
+ import matplotlib as mpl
+ import matplotlib.cm # call as mpl.cm
+ import matplotlib.colors # call as mpl.colors
+ import numpy as np
+
+ # If we have been provided with a list of numbers as long as elem_list,
+ # apply the color mapping.
+ if len(colors) == len(elem_list) and isinstance(colors[0], Number):
+ mapper = mpl.cm.ScalarMappable(cmap=cmap)
+ mapper.set_clim(vmin, vmax)
+ rgba_colors = mapper.to_rgba(colors)
+ # Otherwise, convert colors to matplotlib's RGB using the colorConverter
+ # object. These are converted to numpy ndarrays to be consistent with the
+ # to_rgba method of ScalarMappable.
+ else:
+ try:
+ rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)])
+ except ValueError:
+ rgba_colors = np.array(
+ [mpl.colors.colorConverter.to_rgba(color) for color in colors]
+ )
+ # Set the final column of the rgba_colors to have the relevant alpha values
+ try:
+ # If alpha is longer than the number of colors, resize to the number of
+ # elements. Also, if rgba_colors.size (the number of elements of
+ # rgba_colors) is the same as the number of elements, resize the array,
+ # to avoid it being interpreted as a colormap by scatter()
+ if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list):
+ rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
+ rgba_colors[1:, 0] = rgba_colors[0, 0]
+ rgba_colors[1:, 1] = rgba_colors[0, 1]
+ rgba_colors[1:, 2] = rgba_colors[0, 2]
+ rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
+ except TypeError:
+ rgba_colors[:, -1] = alpha
+ return rgba_colors