"""Extends visualization features."""
from __future__ import annotations
from functools import wraps
from itertools import repeat
from math import ceil
from typing import TYPE_CHECKING, Any
import igraph as ig
import numpy as np
from igraph.drawing.colors import (
PrecalculatedPalette,
color_name_to_rgb,
color_name_to_rgba,
darken,
lighten,
)
from matplotlib.pyplot import subplots
from pandas import Series
from wasabi import msg
import textnets as tn
if TYPE_CHECKING:
from collections.abc import Callable, Iterator
from matplotlib.artist import Artist
from matplotlib.figure import Figure
#: Base colors for textnets color palette; :class:`TextnetPalette` extends
#: them with lighter and darker variants when more colors are needed.
BASE_COLORS = [
    "tomato", "darkseagreen", "slateblue", "gold",
    "orchid", "springgreen", "dodgerblue",
]

#: Figure created by the most recent decorated plot call; read by ``savefig``.
_LAST_FIG: Figure | None = None
[docs]
class TextnetPalette(PrecalculatedPalette):
    """
    Color palette for textnets.

    The first ``len(BASE_COLORS)`` entries are the base colors themselves.
    Further colors are appended in blocks, each block holding a lightened or
    darkened variant of every base color. Blocks alternate lighter/darker,
    and each lighter/darker pair moves further away from the base colors.

    Parameters
    ----------
    n : int
        Number of colors in the palette.
    """

    def __init__(self, n: int) -> None:
        base_colors = [color_name_to_rgba(c) for c in BASE_COLORS]
        num_base_colors = len(base_colors)
        colors = base_colors[:]
        # Number of extra blocks needed beyond the base colors (never negative,
        # so the divisor below can't reach zero for n <= 0).
        blocks_to_add = max(0, ceil((n - num_base_colors) / num_base_colors))
        # Blocks come in lighter/darker pairs sharing one mixing ratio, so
        # there are ceil(blocks_to_add / 2) distinct ratios. Dividing by one
        # more than that keeps every ratio strictly below 1.0 (per igraph's
        # docs, a ratio of 1.0 makes darken/lighten return pure black/white).
        ratio_increment = 1.0 / (ceil(blocks_to_add / 2.0) + 1)
        adding_darker = False
        ratio = ratio_increment
        while len(colors) < n:
            if adding_darker:
                new_block = [darken(color, ratio) for color in base_colors]
                # Step further out only once both variants of a pair exist;
                # stepping after the lighter block instead pushed the final
                # darker block to ratio 1.0 (all black) for some n.
                ratio += ratio_increment
            else:
                new_block = [lighten(color, ratio) for color in base_colors]
            colors.extend(new_block)
            adding_darker = not adding_darker
        super().__init__(colors[:n])
[docs]
def decorate_plot(plot_func: Callable) -> Callable:
    """
    Style the plot produced by igraph's plot function.

    The returned wrapper translates textnets' plotting keyword arguments into
    arguments for ``plot_func``: ``node_*`` aliases, cluster marking and
    coloring, layout flags, node and edge scaling, opacity, labels and label
    filters. It also fills in textnets' default appearance, draws onto a new
    matplotlib figure, and records that figure so :func:`savefig` can save it.

    Parameters
    ----------
    plot_func : callable
        Plotting function to wrap; called as ``plot_func(net, **kwargs)``.

    Returns
    -------
    callable
        The wrapped plotting function.
    """
    # Set CJK font
    try:
        import mpl_font.noto  # noqa: F401
    except ModuleNotFoundError:
        if tn.params["lang"].startswith(("zh", "ja", "ko")):
            msg.warn("Could not set CJK font. Set the matplotlib font manually.")

    # Produce SVG if running inside a Jupyter notebook. ``get_ipython`` is only
    # defined under IPython, so a NameError simply means "not in a notebook".
    try:
        cfg = get_ipython().config  # type:ignore
        cfg.InlineBackend.figure_formats = ["svg"]
    except NameError:
        pass

    @wraps(plot_func)
    def wrapper(net: tn.network.TextnetBase, **kwargs) -> Artist:
        graph = net.graph

        # Rewrite node_* arguments as vertex_* arguments (prefix only).
        for opt in [k for k in kwargs if k.startswith("node_")]:
            kwargs["vertex_" + opt[len("node_") :]] = kwargs.pop(opt)

        # Marking and coloring clusters: a VertexClustering selects that
        # clustering; any other truthy value selects ``net.clusters``.
        show_clusters = kwargs.pop("show_clusters", False)
        color_clusters = kwargs.pop("color_clusters", False)
        if show_clusters:
            clustering = (
                show_clusters
                if isinstance(show_clusters, ig.VertexClustering)
                else net.clusters
            )
            markers = zip(
                _cluster_node_indices(clustering),
                repeat(_add_opacity("limegreen", 0.4)),
            )
            kwargs.setdefault("mark_groups", markers)
        if color_clusters:
            clustering = (
                color_clusters
                if isinstance(color_clusters, ig.VertexClustering)
                else net.clusters
            )
            # Build the palette once instead of once per node.
            palette = TextnetPalette(clustering._len)
            kwargs["vertex_color"] = [palette[c] for c in clustering.membership]

        # Default appearance
        kwargs.setdefault("autocurve", True)
        kwargs.setdefault("edge_color", "lightgray")
        kwargs.setdefault("edge_label_size", 6)
        kwargs.setdefault("edge_width", 2)
        kwargs.setdefault("vertex_frame_width", 0.25)
        kwargs.setdefault("vertex_label_size", 9)
        kwargs.setdefault("vertex_size", 20)
        kwargs.setdefault("wrap_labels", True)
        kwargs.setdefault(
            "vertex_color",
            ["orangered" if v else "dodgerblue" for v in net.node_types],
        )
        kwargs.setdefault(
            "vertex_shape", ["circle" if t else "square" for t in net.node_types]
        )
        kwargs.setdefault(
            "vertex_frame_color",
            ["black" if t else "white" for t in net.node_types],
        )

        # Layouts. A layout flag overrides any explicit "layout" argument; the
        # Fruchterman-Reingold default is computed only when nothing else
        # supplies a layout, rather than eagerly and then discarded.
        bipartite_layout = kwargs.pop("bipartite_layout", False)
        sugiyama_layout = kwargs.pop("sugiyama_layout", False)
        circular_layout = kwargs.pop("circular_layout", False)
        kamada_kawai_layout = kwargs.pop("kamada_kawai_layout", False)
        drl_layout = kwargs.pop("drl_layout", False)
        if bipartite_layout:
            layout = graph.layout_bipartite(types=net.node_types)
            layout.rotate(90)
            kwargs["wrap_labels"] = False
            kwargs["layout"] = layout
        elif sugiyama_layout:
            layout = graph.layout_sugiyama(weights="weight", hgap=50, maxiter=100000)
            layout.rotate(270)
            kwargs["wrap_labels"] = False
            kwargs["layout"] = layout
        elif circular_layout:
            kwargs["layout"] = graph.layout_reingold_tilford_circular()
        elif kamada_kawai_layout:
            kwargs["layout"] = graph.layout_kamada_kawai()
        elif drl_layout:
            kwargs["layout"] = graph.layout_drl(weights="weight")
        elif "layout" not in kwargs:
            kwargs["layout"] = graph.layout_fruchterman_reingold(
                weights="weight", grid=False
            )

        # Node and edge scaling
        PHI = 1.618
        scale_nodes_by = kwargs.pop("scale_nodes_by", None)
        if scale_nodes_by is not None:
            # A string names a network attribute or a node attribute; any
            # other value (getattr raises TypeError) is taken as the data.
            try:
                dist = getattr(net, scale_nodes_by)
            except AttributeError:
                dist = Series(net.nodes[scale_nodes_by])
            except TypeError:
                dist = Series(scale_nodes_by)
            basesize = np.array(kwargs.pop("vertex_size"))
            spread = _relative_spread(dist)
            kwargs["vertex_size"] = (spread * basesize / PHI + basesize).fillna(0)
        scale_edges_by = kwargs.pop("scale_edges_by", None)
        if scale_edges_by is not None:
            # Only a string can name an edge attribute; testing a Series or
            # array with ``in`` would raise on its ambiguous truth value.
            if (
                isinstance(scale_edges_by, str)
                and scale_edges_by in graph.edge_attributes()
            ):
                dist = Series(net.edges[scale_edges_by])
            else:
                dist = Series(scale_edges_by)
            basewidth = np.array(kwargs.pop("edge_width"))
            spread = _relative_spread(dist)
            kwargs["edge_width"] = (PHI / 2 * basewidth * (spread + 1)).fillna(0)

        # Node and edge opacity
        node_opacity = kwargs.pop("vertex_opacity", None)
        edge_opacity = kwargs.pop("edge_opacity", None)
        if node_opacity is not None:
            kwargs["vertex_color"] = _with_opacity(kwargs["vertex_color"], node_opacity)
        if edge_opacity is not None:
            kwargs["edge_color"] = _with_opacity(kwargs["edge_color"], edge_opacity)

        # Node and edge labels
        label_doc_nodes = kwargs.pop("label_doc_nodes", False)
        label_term_nodes = kwargs.pop("label_term_nodes", False)
        label_nodes = kwargs.pop("label_nodes", False)
        label_edges = kwargs.pop("label_edges", False)
        kwargs.setdefault(
            "vertex_label",
            [
                (
                    node["id"]
                    if (node["type"] == "doc" and label_doc_nodes)
                    or (node["type"] == "term" and label_term_nodes)
                    or label_nodes
                    else None
                )
                for node in net.nodes
            ],
        )
        kwargs.setdefault(
            "edge_label",
            [f"{edge['weight']:.2f}" if label_edges else None for edge in net.edges],
        )

        # Node and edge label filters: keep a label only where the filter
        # returns a truthy value for the corresponding node/edge.
        node_label_filter = kwargs.pop("vertex_label_filter", False)
        edge_label_filter = kwargs.pop("edge_label_filter", False)
        if node_label_filter and "vertex_label" in kwargs:
            node_labels = kwargs.pop("vertex_label")
            filtered_node_labels = map(node_label_filter, net.nodes)
            kwargs["vertex_label"] = [
                lbl if keep else None
                for lbl, keep in zip(node_labels, filtered_node_labels)
            ]
        if edge_label_filter and "edge_label" in kwargs:
            edge_labels = kwargs.pop("edge_label")
            filtered_edge_labels = map(edge_label_filter, net.edges)
            kwargs["edge_label"] = [
                lbl if keep else None
                for lbl, keep in zip(edge_labels, filtered_edge_labels)
            ]

        # Use matplotlib: always draw onto a fresh figure and remember it for
        # ``savefig``; a caller-supplied target is replaced.
        if "target" in kwargs:
            msg.warn("Please use plt.savefig to save the network plot.")
        fig, ax = subplots(figsize=tn.params["figsize"])
        global _LAST_FIG  # noqa:PLW0603
        _LAST_FIG = fig
        kwargs["target"] = ax
        return plot_func(net, **kwargs)

    return wrapper


def _relative_spread(dist: Series) -> Series:
    """
    Standardize ``dist`` and rescale it so its largest magnitude is 1.

    Distributions with ``abs(skew) < 2`` are squared first to exaggerate
    differences. ``dist`` itself is never modified, because it may be a
    caller's Series or an attribute of the network. Missing values stay NaN.
    If the values have no spread (constant or a single value), every
    non-missing entry maps to 0. Those elements then get the same size as an
    average element, instead of NaN that the caller's ``fillna(0)`` would
    turn into 0.
    """
    if abs(dist.skew()) < 2:
        dist = dist**2  # new object; ``**=`` would write through to the source
    norm = (dist - dist.mean()) / dist.std()
    peak = abs(norm).max()
    if not peak > 0:  # std was 0 or NaN, so norm is all NaN
        return norm.where(dist.isna(), 0.0)
    return norm / peak


def _with_opacity(colors: Any, alpha: float) -> list[tuple[Any, ...]]:
    """
    Apply opacity ``alpha`` to a color specification.

    A single color (a color name, or a tuple of numeric components) yields a
    one-element list; any other iterable yields one RGBA tuple per element.
    """
    # NOTE(review): the one-element form assumes igraph recycles a short
    # color list across all vertices/edges -- confirm against igraph docs.
    single = isinstance(colors, str) or (
        isinstance(colors, tuple)
        and all(isinstance(x, (int, float, np.number)) for x in colors)
    )
    if single:
        return [_add_opacity(colors, alpha)]
    return [_add_opacity(c, alpha) for c in colors]
def _add_opacity(color: str, alpha: float) -> tuple[Any, ...]:
    """Return the RGB components of ``color`` followed by ``alpha``."""
    rgb = tuple(color_name_to_rgb(color))
    return rgb + (alpha,)
def _cluster_node_indices(vc: ig.VertexClustering) -> Iterator[list[int]]:
"""Return node indices for nodes in each cluster."""
for n in range(vc._len):
yield [i for i, x in enumerate(vc.membership) if x == n]
[docs]
def savefig(*args, **kwargs) -> None:
    """
    Save the figure drawn by the most recent network plot.

    All arguments are forwarded unchanged to that figure's ``savefig`` method.
    If no plot has been drawn yet, a warning is emitted via ``msg.warn``
    instead.

    Parameters
    ----------
    filename : str or path-like
        File or path that the figure should be saved to.
    format : str, optional
        The file format.
    dpi : float, optional
        The image resolution.
    metadata : dict, optional
        Image metadata to store in the file.
    """
    figure = _LAST_FIG
    if figure is not None:
        figure.savefig(*args, **kwargs)
        return
    msg.warn("No figure to save.")