import numpy as np
import param
from bokeh.models import Patches
from ...core.data import Dataset
from ...core.util import dimension_sanitizer, max_range
from ...util.transform import dim
from .graphs import GraphPlot
[docs]class SankeyPlot(GraphPlot):
labels = param.ClassSelector(class_=(str, dim), doc="""
The dimension or dimension value transform used to draw labels from.""")
label_position = param.ObjectSelector(default='right',
objects=['left', 'right', 'outer', 'inner'],
doc="""
Whether node labels should be placed to the left, right, outer or inner.""")
show_values = param.Boolean(default=True, doc="""
Whether to show the values.""")
show_legend = param.Boolean(default=False, doc="""
Whether to show the values.""")
node_width = param.Number(default=15, doc="""
Width of the nodes.""")
node_padding = param.Integer(default=None, doc="""
Number of pixels of padding relative to the bounds.""")
iterations = param.Integer(default=32, doc="""
Number of iterations to run the layout algorithm.""")
node_sort = param.Boolean(default=True, doc="""
Sort nodes in ascending breadth.""")
width = param.Integer(default=1000, allow_None=True, bounds=(0, None), doc="""
The width of the component (in pixels). This can be either
fixed or preferred width, depending on width sizing policy.""")
height = param.Integer(default=600, allow_None=True, bounds=(0, None), doc="""
The height of the component (in pixels). This can be either
fixed or preferred height, depending on height sizing policy.""")
# Deprecated options
color_index = param.ClassSelector(default=2, class_=(str, int),
allow_None=True, doc="""
Index of the dimension from which the node labels will be drawn""")
label_index = param.ClassSelector(default=2, class_=(str, int),
allow_None=True, doc="""
Index of the dimension from which the node labels will be drawn""")
_style_groups = dict(GraphPlot._style_groups, quad='node', text='label')
_draw_order = ['graph', 'quad_1', 'text_1', 'text_2']
style_opts = GraphPlot.style_opts + ['edge_fill_alpha', 'nodes_line_color',
'label_text_font_size']
filled = True
def _init_glyphs(self, plot, element, ranges, source):
super()._init_glyphs(plot, element, ranges, source)
renderer = plot.renderers.pop(plot.renderers.index(self.handles['glyph_renderer']))
plot.renderers = [renderer] + plot.renderers
arc_renderer = self.handles['quad_1_glyph_renderer']
scatter_renderer = self.handles['scatter_1_glyph_renderer']
arc_renderer.view = scatter_renderer.view
arc_renderer.data_source = scatter_renderer.data_source
self.handles['quad_1_source'] = scatter_renderer.data_source
self._sync_nodes()
[docs] def get_data(self, element, ranges, style):
data, mapping, style = super().get_data(element, ranges, style)
self._compute_quads(element, data, mapping)
style['nodes_line_color'] = 'black'
self._compute_labels(element, data, mapping)
self._patch_hover(element, data)
return data, mapping, style
def _init_glyph(self, plot, mapping, properties, key):
if key == 'quad_1':
properties.pop('size', None)
mapping.pop('size', None)
return super()._init_glyph(plot, mapping, properties, key)
def _update_glyphs(self, element, ranges, style):
self._sync_nodes()
super()._update_glyphs(element, ranges, style)
def _sync_nodes(self):
arc_renderer = self.handles['quad_1_glyph_renderer']
scatter_renderer = self.handles['scatter_1_glyph_renderer']
for gtype in ('selection_', 'nonselection_', 'muted_', 'hover_', ''):
glyph = getattr(scatter_renderer, gtype+'glyph')
arc_glyph = getattr(arc_renderer, gtype+'glyph')
if not glyph or not arc_glyph:
continue
scatter_props = glyph.properties_with_values(include_defaults=False)
styles = {k: v for k, v in scatter_props.items()
if k in arc_glyph.properties()}
arc_glyph.update(**styles)
def _compute_quads(self, element, data, mapping):
"""
Computes the node quad glyph data.x
"""
quad_mapping = {'left': 'x0', 'right': 'x1', 'bottom': 'y0', 'top': 'y1'}
quad_data = dict(data['scatter_1'])
quad_data.update({'x0': [], 'x1': [], 'y0': [], 'y1': []})
for node in element._sankey['nodes']:
quad_data['x0'].append(node['x0'])
quad_data['y0'].append(node['y0'])
quad_data['x1'].append(node['x1'])
quad_data['y1'].append(node['y1'])
data['scatter_1'].update(quad_data)
data['quad_1'] = data['scatter_1']
mapping['quad_1'] = quad_mapping
def _compute_labels(self, element, data, mapping):
"""
Computes labels for the nodes and adds it to the data.
"""
if element.vdims:
edges = Dataset(element)[element[element.vdims[0].name]>0]
nodes = list(np.unique([edges.dimension_values(i) for i in range(2)]))
nodes = element.nodes.select(**{element.nodes.kdims[2].name: nodes})
else:
nodes = element
label_dim = nodes.get_dimension(self.label_index)
labels = self.labels
if label_dim and labels:
if self.label_index not in [2, None]:
self.param.warning(
"Cannot declare style mapping for 'labels' option "
"and declare a label_index; ignoring the label_index.")
elif label_dim:
labels = label_dim
if isinstance(labels, str):
labels = element.nodes.get_dimension(labels)
if labels is None:
text = []
if isinstance(labels, dim):
text = labels.apply(element, flat=True)
else:
text = element.nodes.dimension_values(labels)
text = [labels.pprint_value(v) for v in text]
ys = nodes.dimension_values(1)
nodes = element._sankey['nodes']
if nodes:
offset = (nodes[0]['x1']-nodes[0]['x0'])/4.
else:
offset = 0
value_dim = element.vdims[0]
text_labels = []
is_outer_inner = self.label_position in ['outer', 'inner']
# initialize label x-locations
if self.label_position in ['right', 'outer']:
xs = np.array([node['x1'] for node in nodes]) + offset
else: # ['left', 'inner']
xs = np.array([node['x0'] for node in nodes]) - offset
for i, node in enumerate(nodes):
if len(text):
label = text[i]
else:
label = ''
if self.show_values:
value = value_dim.pprint_value(node['value'], print_unit=True)
if label:
label = f'{label} - {value}'
else:
label = value
if label:
text_labels.append(label)
align = 'left' if self.label_position in ['right', 'outer'] else 'right'
# handle labels that have different alignment due to inner/outer
if is_outer_inner:
xs_2 = xs.copy()
text_labels_2 = text_labels.copy()
for i, node in enumerate(nodes):
if self.label_position == 'outer' and node['x0'] == 0:
text_labels[i] = ''
xs_2[i] = node['x0'] - offset
elif self.label_position == 'inner' and node['x0'] == 0:
text_labels[i] = ''
xs_2[i] = node['x1'] + offset
else:
text_labels_2[i] = ''
align_2 = 'right' if align == 'left' else 'left'
data['text_2'] = dict(x=xs_2, y=ys,
text=[str(l) for l in text_labels_2])
mapping['text_2'] = dict(text='text', x='x', y='y',
text_baseline='middle', text_align=align_2)
data['text_1'] = dict(x=xs, y=ys, text=[str(l) for l in text_labels])
mapping['text_1'] = dict(text='text', x='x', y='y',
text_baseline='middle', text_align=align)
def _patch_hover(self, element, data):
"""
Replace edge start and end hover data with label_index data.
"""
if not (self.inspection_policy == 'edges' and 'hover' in self.handles):
return
lidx = element.nodes.get_dimension(self.label_index)
src, tgt = (dimension_sanitizer(kd.name) for kd in element.kdims[:2])
if src == 'start': src += '_values'
if tgt == 'end': tgt += '_values'
lookup = dict(zip(*(element.nodes.dimension_values(d) for d in (2, lidx))))
src_vals = data['patches_1'][src]
tgt_vals = data['patches_1'][tgt]
data['patches_1'][src] = [lookup.get(v, v) for v in src_vals]
data['patches_1'][tgt] = [lookup.get(v, v) for v in tgt_vals]
[docs] def get_extents(self, element, ranges, range_type='combined', **kwargs):
"""Return the extents of the Sankey box"""
if range_type == 'extents':
return element.nodes.extents
xdim, ydim = element.nodes.kdims[:2]
xpad = .05 if self.label_index is None else 0.25
x0, x1 = ranges[xdim.name][range_type]
y0, y1 = ranges[ydim.name][range_type]
xdiff = (x1-x0)
ydiff = (y1-y0)
if self.label_position in ['right', 'outer']:
x1 = x1 + xpad * xdiff
else: # ['left', 'inner']
x1 = x1 + (0.05 * xdiff)
if self.label_position in ['left', 'outer']:
x0 = x0 - xpad * xdiff
else: # ['right', 'inner']
x0 = x0 - (0.05 * xdiff)
x0, x1 = max_range([xdim.range, (x0, x1)])
y0, y1 = max_range([ydim.range, (y0-(0.05*ydiff), y1+(0.05*ydiff))])
return (x0, y0, x1, y1)
def _postprocess_hover(self, renderer, source):
if self.inspection_policy == 'edges':
if not isinstance(renderer.glyph, Patches):
return
elif isinstance(renderer.glyph, Patches):
return
super()._postprocess_hover(renderer, source)