import copy
import re
import numpy as np
from plotly import colors
from ...core.util import isfinite, max_range
from ..util import color_intervals, process_cmap
# Constants
# ---------
# Trace types that are individually positioned with their own domain.
# These are traces that don't overlay on top of each other in a shared subplot,
# so they are positioned individually. All other trace types are associated
# with a layout subplot type (xaxis/yaxis, polar, scene etc.)
#
# Each of these trace types has a `domain` property with `x`/`y` properties
_domain_trace_types = {'parcoords', 'pie', 'table', 'sankey', 'parcats'}
# Subplot types that are each individually positioned with a domain
#
# Each of these subplot types has a `domain` property with `x`/`y` properties.
# Note that this set does not contain `xaxis`/`yaxis` because these behave a
# little differently.
_subplot_types = {'scene', 'geo', 'polar', 'ternary', 'mapbox'}
# For most subplot types, a trace is associated with a particular subplot
# using a trace property with a name that matches the subplot type. For
# example, a `scatter3d.scene` property set to `'scene2'` associates a
# scatter3d trace with the second `scene` subplot in the figure.
#
# There are a few subplot types that don't follow this pattern, and instead
# the trace property is just named `subplot`. For example setting
# the `scatterpolar.subplot` property to `polar3` associates the scatterpolar
# trace with the third polar subplot in the figure
_subplot_prop_named_subplot = {'polar', 'ternary', 'mapbox'}
# Mapping from trace type to subplot type(s).
_trace_to_subplot = {
# xaxis/yaxis
'bar': ['xaxis', 'yaxis'],
'box': ['xaxis', 'yaxis'],
'candlestick': ['xaxis', 'yaxis'],
'carpet': ['xaxis', 'yaxis'],
'contour': ['xaxis', 'yaxis'],
'contourcarpet': ['xaxis', 'yaxis'],
'heatmap': ['xaxis', 'yaxis'],
'heatmapgl': ['xaxis', 'yaxis'],
'histogram': ['xaxis', 'yaxis'],
'histogram2d': ['xaxis', 'yaxis'],
'histogram2dcontour': ['xaxis', 'yaxis'],
'ohlc': ['xaxis', 'yaxis'],
'pointcloud': ['xaxis', 'yaxis'],
'scatter': ['xaxis', 'yaxis'],
'scattercarpet': ['xaxis', 'yaxis'],
'scattergl': ['xaxis', 'yaxis'],
'violin': ['xaxis', 'yaxis'],
# scene
'cone': ['scene'],
'mesh3d': ['scene'],
'scatter3d': ['scene'],
'streamtube': ['scene'],
'surface': ['scene'],
# geo
'choropleth': ['geo'],
'scattergeo': ['geo'],
# polar
'barpolar': ['polar'],
'scatterpolar': ['polar'],
'scatterpolargl': ['polar'],
# ternary
'scatterternary': ['ternary'],
# mapbox
'scattermapbox': ['mapbox']
}
# Trace types that support legends.
legend_trace_types = set("""
    scatter bar box histogram histogram2dcontour contour scatterternary
    violin waterfall pie scatter3d scattergeo scattergl splom pointcloud
    scattermapbox scattercarpet contourcarpet ohlc candlestick scatterpolar
    scatterpolargl barpolar area
""".split())
# Aliases: map alternative style option names (e.g. `alpha`) onto the names
# used by this backend (e.g. `opacity`).
STYLE_ALIASES = dict(
    alpha='opacity',
    cell_height='height',
    marker='symbol',
    max_zoom='maxzoom',
    min_zoom='minzoom',
)
# Regular expression to extract any trailing digits from a subplot-style
# string.
_subplot_re = re.compile(r'\D*(\d+)')
def _get_subplot_number(subplot_val):
"""
Extract the subplot number from a subplot value string.
'x3' -> 3
'polar2' -> 2
'scene' -> 1
'y' -> 1
Note: the absence of a subplot number (e.g. 'y') is treated by plotly as
a subplot number of 1
Parameters
----------
subplot_val: str
Subplot string value (e.g. 'scene4')
Returns
-------
int
"""
match = _subplot_re.match(subplot_val)
if match:
subplot_number = int(match.group(1))
else:
subplot_number = 1
return subplot_number
def _get_subplot_val_prefix(subplot_type):
"""
Get the subplot value prefix for a subplot type. For most subplot types
this is equal to the subplot type string itself. For example, a
`scatter3d.scene` value of `scene2` is used to associate the scatter3d
trace with the `layout.scene2` subplot.
However, the `xaxis`/`yaxis` subplot types are exceptions to this pattern.
For example, a `scatter.xaxis` value of `x2` is used to associate the
scatter trace with the `layout.xaxis2` subplot.
Parameters
----------
subplot_type: str
Subplot string value (e.g. 'scene4')
Returns
-------
str
"""
if subplot_type == 'xaxis':
subplot_val_prefix = 'x'
elif subplot_type == 'yaxis':
subplot_val_prefix = 'y'
else:
subplot_val_prefix = subplot_type
return subplot_val_prefix
def _get_subplot_prop_name(subplot_type):
    """
    Return the name of the trace property that links a trace to a subplot
    of the given type.

    Usually the property is named after the subplot type itself (e.g.
    `scatter3d.scene` selects a `scene` subplot). Subplot types listed in
    `_subplot_prop_named_subplot` are the exception: their traces use a
    property literally called `subplot` (e.g. `scatterpolar.subplot`
    selects a `polar` subplot).

    Parameters
    ----------
    subplot_type: str
        Subplot type (e.g. 'scene', 'polar')

    Returns
    -------
    str
    """
    if subplot_type not in _subplot_prop_named_subplot:
        return subplot_type
    return 'subplot'
def _normalize_subplot_ids(fig):
"""
Make sure a layout subplot property is initialized for every subplot that
is referenced by a trace in the figure.
For example, if a figure contains a `scatterpolar` trace with the `subplot`
property set to `polar3`, this function will make sure the figure's layout
has a `polar3` property, and will initialize it to an empty dict if it
does not
Note: This function mutates the input figure dict
Parameters
----------
fig: dict
A plotly figure dict
"""
layout = fig.setdefault('layout', {})
for trace in fig.get('data', None):
trace_type = trace.get('type', 'scatter')
subplot_types = _trace_to_subplot.get(trace_type, [])
for subplot_type in subplot_types:
subplot_prop_name = _get_subplot_prop_name(subplot_type)
subplot_val_prefix = _get_subplot_val_prefix(subplot_type)
subplot_val = trace.get(subplot_prop_name, subplot_val_prefix)
# extract trailing number (if any)
subplot_number = _get_subplot_number(subplot_val)
if subplot_number > 1:
layout_prop_name = subplot_type + str(subplot_number)
else:
layout_prop_name = subplot_type
if layout_prop_name not in layout:
layout[layout_prop_name] = {}
def _get_max_subplot_ids(fig):
    """
    Find the highest subplot number used for every subplot type in a figure.

    Subplot references of traces are inspected together with the
    `xref`/`yref` of layout annotations, shapes and images (references to
    'paper' are ignored). Subplot types that are never referenced map to 0.

    Parameters
    ----------
    fig: dict
        A plotly figure dict

    Returns
    -------
    dict
        Mapping from subplot type string to the largest subplot number of
        that type found in the figure
    """
    max_ids = dict.fromkeys(_subplot_types, 0)
    max_ids['xaxis'] = max_ids['yaxis'] = 0

    def bump(subplot_type, ref):
        # Raise the running maximum for `subplot_type` if `ref` is higher
        max_ids[subplot_type] = max(max_ids[subplot_type],
                                    _get_subplot_number(ref))

    # Subplot references made by traces
    for trace in fig.get('data', []):
        for subplot_type in _trace_to_subplot.get(
                trace.get('type', 'scatter'), []):
            ref = trace.get(_get_subplot_prop_name(subplot_type),
                            _get_subplot_val_prefix(subplot_type))
            bump(subplot_type, ref)

    # Axis references made by annotations/shapes/images
    layout = fig.get('layout', {})
    for container in ('annotations', 'shapes', 'images'):
        for item in layout.get(container, []):
            for ref_prop, default, axis_type in (('xref', 'x', 'xaxis'),
                                                 ('yref', 'y', 'yaxis')):
                ref = item.get(ref_prop, default)
                if ref != 'paper':
                    bump(axis_type, ref)
    return max_ids
def _offset_subplot_ids(fig, offsets):
"""
Apply offsets to the subplot id numbers in a figure.
Note: This function mutates the input figure dict
Note: This function assumes that the normalize_subplot_ids function has
already been run on the figure, so that all layout subplot properties in
use are explicitly present in the figure's layout.
Parameters
----------
fig: dict
A plotly figure dict
offsets: dict
A dict from subplot types to the offset to be applied for each subplot
type. This dict matches the form of the dict returned by
get_max_subplot_ids
"""
# Offset traces
for trace in fig.get('data', None):
trace_type = trace.get('type', 'scatter')
subplot_types = _trace_to_subplot.get(trace_type, [])
for subplot_type in subplot_types:
subplot_prop_name = _get_subplot_prop_name(subplot_type)
# Compute subplot value prefix
subplot_val_prefix = _get_subplot_val_prefix(subplot_type)
subplot_val = trace.get(subplot_prop_name, subplot_val_prefix)
subplot_number = _get_subplot_number(subplot_val)
offset_subplot_number = (
subplot_number + offsets.get(subplot_type, 0))
if offset_subplot_number > 1:
trace[subplot_prop_name] = (
subplot_val_prefix + str(offset_subplot_number))
else:
trace[subplot_prop_name] = subplot_val_prefix
# layout subplots
layout = fig.setdefault('layout', {})
new_subplots = {}
for subplot_type in offsets:
offset = offsets[subplot_type]
if offset < 1:
continue
for layout_prop in list(layout.keys()):
if layout_prop.startswith(subplot_type):
subplot_number = _get_subplot_number(layout_prop)
new_subplot_number = subplot_number + offset
new_layout_prop = subplot_type + str(new_subplot_number)
new_subplots[new_layout_prop] = layout.pop(layout_prop)
layout.update(new_subplots)
# xaxis/yaxis anchors
x_offset = offsets.get('xaxis', 0)
y_offset = offsets.get('yaxis', 0)
for layout_prop in list(layout.keys()):
if layout_prop.startswith('xaxis'):
xaxis = layout[layout_prop]
anchor = xaxis.get('anchor', 'y')
anchor_number = _get_subplot_number(anchor) + y_offset
if anchor_number > 1:
xaxis['anchor'] = 'y' + str(anchor_number)
else:
xaxis['anchor'] = 'y'
elif layout_prop.startswith('yaxis'):
yaxis = layout[layout_prop]
anchor = yaxis.get('anchor', 'x')
anchor_number = _get_subplot_number(anchor) + x_offset
if anchor_number > 1:
yaxis['anchor'] = 'x' + str(anchor_number)
else:
yaxis['anchor'] = 'x'
# Axis matches references
for layout_prop in list(layout.keys()):
if layout_prop[1:5] == 'axis':
axis = layout[layout_prop]
matches_val = axis.get('matches', None)
if matches_val:
if matches_val[0] == 'x':
matches_number = _get_subplot_number(matches_val) + x_offset
elif matches_val[0] == 'y':
matches_number = _get_subplot_number(matches_val) + y_offset
else:
continue
suffix = str(matches_number) if matches_number > 1 else ""
axis['matches'] = matches_val[0] + suffix
# annotations/shapes/images
for layout_prop in ['annotations', 'shapes', 'images']:
for obj in layout.get(layout_prop, []):
if x_offset:
xref = obj.get('xref', 'x')
if xref != 'paper':
xref_number = _get_subplot_number(xref)
obj['xref'] = 'x' + str(xref_number + x_offset)
if y_offset:
yref = obj.get('yref', 'y')
if yref != 'paper':
yref_number = _get_subplot_number(yref)
obj['yref'] = 'y' + str(yref_number + y_offset)
def _scale_translate(fig, scale_x, scale_y, translate_x, translate_y):
    """
    Scale a figure and translate it to sub-region of the original
    figure canvas.

    Note: If the input figure has a title, this title is converted into an
    annotation and scaled along with the rest of the figure. The title may
    be a plain string (with an optional legacy `layout.titlefont` dict) or a
    dict with a `text` entry and an optional `font` entry.

    Note: This function mutates the input fig dict

    Note: This function assumes that the normalize_subplot_ids function has
    already been run on the figure, so that all layout subplot properties in
    use are explicitly present in the figure's layout.

    Parameters
    ----------
    fig: dict
        A plotly figure dict
    scale_x: float
        Factor by which to scale the figure in the x-direction. This will
        typically be a value < 1. E.g. a value of 0.5 will cause the
        resulting figure to be half as wide as the original.
    scale_y: float
        Factor by which to scale the figure in the y-direction. This will
        typically be a value < 1
    translate_x: float
        Factor by which to translate the scaled figure in the x-direction in
        normalized coordinates.
    translate_y: float
        Factor by which to translate the scaled figure in the y-direction in
        normalized coordinates.
    """
    data = fig.setdefault('data', [])
    layout = fig.setdefault('layout', {})

    # Domain endpoints are clamped at 1 (the right/top edge of the canvas)
    def scale_translate_x(x):
        return [min(x[0] * scale_x + translate_x, 1),
                min(x[1] * scale_x + translate_x, 1)]

    def scale_translate_y(y):
        return [min(y[0] * scale_y + translate_y, 1),
                min(y[1] * scale_y + translate_y, 1)]

    def perform_scale_translate(obj):
        domain = obj.setdefault('domain', {})
        x = domain.get('x', [0, 1])
        y = domain.get('y', [0, 1])
        domain['x'] = scale_translate_x(x)
        domain['y'] = scale_translate_y(y)

    # Scale/translate traces that are positioned by their own domain
    for trace in data:
        trace_type = trace.get('type', 'scatter')
        if trace_type in _domain_trace_types:
            perform_scale_translate(trace)

    # Scale/translate subplot containers (scene, polar, ...)
    for prop in layout:
        for subplot_type in _subplot_types:
            if prop.startswith(subplot_type):
                perform_scale_translate(layout[prop])

    # Cartesian axes carry a single domain range along their own direction
    for prop in layout:
        if prop.startswith('xaxis'):
            xaxis = layout[prop]
            x_domain = xaxis.get('domain', [0, 1])
            xaxis['domain'] = scale_translate_x(x_domain)
        elif prop.startswith('yaxis'):
            yaxis = layout[prop]
            y_domain = yaxis.get('domain', [0, 1])
            yaxis['domain'] = scale_translate_y(y_domain)

    # Convert the title to an annotation so that it is scaled with the rest
    # of the figure
    annotations = layout.get('annotations', [])
    title = layout.pop('title', None)
    if isinstance(title, dict):
        # dict form: {'text': ..., 'font': {...}}
        title_text = title.get('text', None)
        title_font = title.get('font', None)
    else:
        title_text = title
        title_font = None
    if title_text:
        legacy_font = layout.pop('titlefont', None)
        # Copy so the caller's font dict is not modified in place
        titlefont = dict(title_font or legacy_font or {})
        title_fontsize = titlefont.get('size', 17)
        min_fontsize = 12
        # Interpolate between 12 and the original size according to scale_x
        titlefont['size'] = round(min_fontsize +
                                  (title_fontsize - min_fontsize) * scale_x)
        annotations.append({
            'text': title_text,
            'showarrow': False,
            'xref': 'paper',
            'yref': 'paper',
            'x': 0.5,
            'y': 1.01,
            'xanchor': 'center',
            'yanchor': 'bottom',
            'font': titlefont
        })
        layout['annotations'] = annotations

    # annotations
    for obj in layout.get('annotations', []):
        if obj.get('xref', None) == 'paper':
            obj['x'] = obj.get('x', 0.5) * scale_x + translate_x
        if obj.get('yref', None) == 'paper':
            obj['y'] = obj.get('y', 0.5) * scale_y + translate_y

    # shapes
    for obj in layout.get('shapes', []):
        if obj.get('xref', None) == 'paper':
            obj['x0'] = obj.get('x0', 0.25) * scale_x + translate_x
            obj['x1'] = obj.get('x1', 0.75) * scale_x + translate_x
        if obj.get('yref', None) == 'paper':
            obj['y0'] = obj.get('y0', 0.25) * scale_y + translate_y
            obj['y1'] = obj.get('y1', 0.75) * scale_y + translate_y

    # images
    for obj in layout.get('images', []):
        if obj.get('xref', None) == 'paper':
            obj['x'] = obj.get('x', 0.5) * scale_x + translate_x
            obj['sizex'] = obj.get('sizex', 0) * scale_x
        if obj.get('yref', None) == 'paper':
            obj['y'] = obj.get('y', 0.5) * scale_y + translate_y
            obj['sizey'] = obj.get('sizey', 0) * scale_y
def merge_layout(obj, subobj):
    """
    Merge layout objects recursively

    Note: This function mutates the input obj dict, but it does not mutate
    the subobj dict. Every value taken from subobj is deep-copied, so later
    changes to obj never leak back into subobj.

    Merge rules, applied per key of subobj:

    * a dict merged into an existing dict is merged recursively
    * a list merged into an existing non-empty list of dicts is appended
    * a `style` of "white-bg" never overrides an existing style
    * any other non-None value replaces whatever obj held for that key

    Parameters
    ----------
    obj: dict
        dict into which the sub-figure dict will be merged
    subobj: dict
        dict that will be copied and merged into `obj`
    """
    for prop, val in subobj.items():
        existing = obj.get(prop, None)
        if isinstance(val, dict) and isinstance(existing, dict):
            # recursion (only when both sides are dicts; otherwise the
            # value is simply overwritten below)
            merge_layout(existing, val)
        elif (isinstance(val, list) and isinstance(existing, list) and
              existing and isinstance(existing[0], dict)):
            # append copies so obj never aliases subobj's nested dicts
            existing.extend(copy.deepcopy(val))
        elif prop == "style" and val == "white-bg" and existing:
            # Handle special cases
            # Don't let layout.mapbox.style of "white-bg" override other
            # background
            pass
        elif val is not None:
            # init/overwrite
            obj[prop] = copy.deepcopy(val)
def _compute_subplot_domains(widths, spacing):
"""
Compute normalized domain tuples for a list of widths and a subplot
spacing value
Parameters
----------
widths: list of float
List of the desired widths of each subplot. The length of this list
is also the specification of the number of desired subplots
spacing: float
Spacing between subplots in normalized coordinates
Returns
-------
list of tuple of float
"""
# normalize widths
widths_sum = float(sum(widths))
total_spacing = (len(widths) - 1) * spacing
total_width = widths_sum + total_spacing
relative_spacing = spacing / (widths_sum + total_spacing)
relative_widths = [(w / total_width) for w in widths]
domains = []
for c in range(len(widths)):
domain_start = c * relative_spacing + sum(relative_widths[:c])
domain_stop = min(1, domain_start + relative_widths[c])
domains.append((domain_start, domain_stop))
return domains
def get_colorscale(cmap, levels=None, cmin=None, cmax=None):
    """Converts a cmap spec to a plotly colorscale

    Args:
        cmap: A recognized colormap by name or list of colors
        levels: A list or integer declaring the color-levels
        cmin: The lower bound of the color range
        cmax: The upper bound of the color range

    Returns:
        A valid plotly colorscale

    Raises:
        ValueError: if `levels` is a list and `cmap` is a list whose length
            differs from the number of intervals (len(levels) - 1).
    """
    ncolors = levels if isinstance(levels, int) else None
    if isinstance(levels, list):
        ncolors = len(levels) - 1
        if isinstance(cmap, list) and len(cmap) != ncolors:
            raise ValueError('The number of colors in the colormap '
                             'must match the intervals defined in the '
                             'color_levels, expected %d colors found %d.'
                             % (ncolors, len(cmap)))
    try:
        palette = process_cmap(cmap, ncolors)
    except Exception as e:
        # Fall back to plotly's named colorscales. An unhashable spec (e.g. a
        # list of colors) can't be looked up; in that case re-raise the
        # original error instead of a misleading "unhashable type" TypeError.
        try:
            colorscale = colors.PLOTLY_SCALES.get(cmap)
        except TypeError:
            colorscale = None
        if colorscale is None:
            raise e
        return colorscale

    if isinstance(levels, int):
        # Stepped colorscale: each of the `levels` bins is a flat band of one
        # palette color, so every interior boundary appears twice (end of
        # the previous color, start of the next one).
        colorscale = []
        scale = np.linspace(0, 1, levels+1)
        for i in range(levels+1):
            if i == 0:
                colorscale.append((scale[0], palette[i]))
            elif i == levels:
                colorscale.append((scale[-1], palette[-1]))
            else:
                colorscale.append((scale[i], palette[i-1]))
                colorscale.append((scale[i], palette[i]))
        return colorscale
    elif isinstance(levels, list):
        # Only the resampled palette is needed; the clipped range is unused
        palette, _ = color_intervals(
            palette, levels, clip=(cmin, cmax))
    return colors.make_colorscale(palette)