import numpy as np
import param
from matplotlib import cm
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from packaging.version import Version
from ...core import Dimension
from ...core.options import abbreviated_exception
from ...util.transform import dim as dim_expr
from ..util import map_colors
from .chart import PointPlot
from .element import ColorbarPlot
from .path import PathPlot
from .util import mpl_version
class Plot3D(ColorbarPlot):
    """
    Plot3D provides a common baseclass for mplot3d based plots.

    It exposes camera (azimuth, elevation, distance), grid, background
    and per-axis visibility options and applies them to the mplot3d
    axis when the plot is finalized.
    """

    azimuth = param.Integer(default=-60, bounds=(-180, 180), doc="""
Azimuth angle in the x,y plane.""")

    elevation = param.Integer(default=30, bounds=(0, 180), doc="""
Elevation angle in the z-axis.""")

    distance = param.Integer(default=10, bounds=(7, 15), doc="""
Distance from the plotted object.""")

    disable_axes = param.Boolean(default=False, doc="""
Disable all axes.""")

    bgcolor = param.String(default='white', doc="""
Background color of the axis.""")

    labelled = param.List(default=['x', 'y', 'z'], doc="""
Whether to plot the 'x', 'y' and 'z' labels.""")

    projection = param.ObjectSelector(default='3d', objects=['3d'], doc="""
The projection of the matplotlib axis.""")

    show_grid = param.Boolean(default=True, doc="""
Whether to draw a grid in the figure.""")

    xaxis = param.ObjectSelector(default='fixed',
                                 objects=['fixed', None], doc="""
Whether and where to display the xaxis.""")

    yaxis = param.ObjectSelector(default='fixed',
                                 objects=['fixed', None], doc="""
Whether and where to display the yaxis.""")

    zaxis = param.ObjectSelector(default='fixed',
                                 objects=['fixed', None], doc="""
Whether and where to display the zaxis.""")

    def _finalize_axis(self, key, **kwargs):
        """
        Extends the ElementPlot _finalize_axis method to set appropriate
        labels, and axes options for 3D Plots.

        Applies the frame, grid, camera (elevation/azimuth/distance),
        per-axis visibility, axis-off and background-colour options to
        the mplot3d axis, then delegates to the parent implementation.
        """
        axis = self.handles['axis']
        self.handles['fig'].set_frameon(False)
        axis.grid(self.show_grid)
        axis.view_init(elev=self.elevation, azim=self.azimuth)
        try:
            axis._dist = self.distance
        except Exception:
            # axis.dist is deprecated see here:
            # https://github.com/matplotlib/matplotlib/pull/22084
            axis.dist = self.distance

        # Hide the line and label of every axis whose option is None.
        # The public xaxis/yaxis/zaxis attributes are used because the
        # w_xaxis/w_yaxis/w_zaxis aliases were removed in Matplotlib 3.8.
        for option, mpl_axis in ((self.xaxis, axis.xaxis),
                                 (self.yaxis, axis.yaxis),
                                 (self.zaxis, axis.zaxis)):
            if option is None:
                mpl_axis.line.set_lw(0.)
                mpl_axis.label.set_text('')

        if self.disable_axes:
            axis.set_axis_off()

        # Very old Matplotlib (<= 1.5.x) only offers set_axis_bgcolor.
        if mpl_version <= Version('1.5.9'):
            axis.set_axis_bgcolor(self.bgcolor)
        else:
            axis.set_facecolor(self.bgcolor)
        return super()._finalize_axis(key, **kwargs)

    def _draw_colorbar(self, element=None, dim=None, redraw=True):
        """
        Attach a colorbar for the colormapped dimension to the figure.

        Parameters
        ----------
        element : optional
            Element used to resolve the dimension; defaults to the last
            element in ``self.hmap``.
        dim : Dimension, dim expression, str, int or None
            Dimension being colormapped. When None, ``color_index`` is
            used if the plot defines one and it is set; otherwise the
            element's third dimension (index 2) is used.
        redraw : bool
            Not used by this implementation; kept for interface
            compatibility.
        """
        if element is None:
            element = self.hmap.last
        artist = self.handles.get('artist', None)
        fig = self.handles['fig']
        ax = self.handles['axis']

        # Resolve the dimension that supplies the colorbar label.
        if isinstance(dim, dim_expr):
            dim = dim.dimension
        if dim is None:
            # A color_index of None would resolve to no dimension and
            # crash on pprint_label below, so fall back to dimension 2.
            color_index = getattr(self, 'color_index', None)
            dim = element.get_dimension(2 if color_index is None else color_index)
        elif not isinstance(dim, Dimension):
            dim = element.get_dimension(dim)
        label = dim.pprint_label

        cbar = fig.colorbar(artist, shrink=0.7, ax=ax)
        self.handles['cbar'] = cbar
        self.handles['cax'] = cbar.ax
        self._adjust_cbar(cbar, label, dim)
class Scatter3DPlot(Plot3D, PointPlot):
    """
    Subclass of PointPlot allowing plotting of Points on a 3D axis.

    It also allows mapping color and size onto a particular Dimension
    of the data, via ``color_index`` / ``size_index``.
    """

    color_index = param.ClassSelector(default=None, class_=(str, int),
                                      allow_None=True, doc="""
Index of the dimension from which the color will the drawn""")

    size_index = param.ClassSelector(default=None, class_=(str, int),
                                     allow_None=True, doc="""
Index of the dimension from which the sizes will the drawn.""")

    _plot_methods = dict(single='scatter')

    def get_data(self, element, ranges, style):
        """
        Return ((xs, ys, zs), style, {}) for ``ax.scatter``.

        The first three dimensions of the element provide the coordinates.
        Style values are computed and transformed. An ``edgecolors`` value
        of the string 'none' is removed from the style.
        """
        xs, ys, zs = (element.dimension_values(i) for i in range(3))
        self._compute_styles(element, ranges, style)
        with abbreviated_exception():
            style = self._apply_transforms(element, ranges, style)
        # edgecolors may be a mapped ndarray; comparing that against a
        # string and testing truthiness would raise, so check type first.
        edgecolors = style.get('edgecolors')
        if isinstance(edgecolors, str) and edgecolors == 'none':
            style.pop('edgecolors')
        return (xs, ys, zs), style, {}

    def update_handles(self, key, axis, element, ranges, style):
        """
        Update the existing scatter artist in place.

        Sets the 3D offsets and, when color/size dimensions are mapped,
        the face colors and marker sizes.
        """
        import matplotlib

        artist = self.handles['artist']
        artist._offsets3d, style, _ = self.get_data(element, ranges, style)
        cdim = element.get_dimension(self.color_index)
        if cdim and 'cmap' in style:
            clim = style['vmin'], style['vmax']
            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; prefer
            # the colormap registry (get_cmap accepts names or Colormaps)
            # and fall back to cm.get_cmap on older releases.
            registry = getattr(matplotlib, 'colormaps', None)
            if registry is not None and hasattr(registry, 'get_cmap'):
                cmap = registry.get_cmap(style['cmap'])
            else:
                cmap = cm.get_cmap(style['cmap'])
            artist._facecolor3d = map_colors(style['c'], clim, cmap, hex=False)
        if element.get_dimension(self.size_index):
            artist.set_sizes(style['s'])
class Path3DPlot(Plot3D, PathPlot):
    """
    Allows plotting paths on a 3D axis, drawn as a Line3DCollection.
    """

    style_opts = ['alpha', 'color', 'linestyle', 'linewidth', 'visible', 'cmap']

    def get_data(self, element, ranges, style):
        """
        Return ((paths,), style, {}) for a Line3DCollection.

        Each path is an array whose columns are the element's key
        dimensions. Mapped style values are renamed to the keywords the
        collection expects: 'c' becomes 'array', an ndarray 'color'
        becomes 'colors', and vmin/vmax become 'clim'.
        """
        paths = element.split(datatype='array', dimensions=element.kdims)
        if self.invert_axes:
            # invert_axes swaps x and y only; reversing every column
            # (p[:, ::-1]) would also swap x with z.
            paths = [p[:, [1, 0] + list(range(2, p.shape[1]))] for p in paths]
        with abbreviated_exception():
            style = self._apply_transforms(element, ranges, style)
        if 'c' in style:
            style['array'] = style.pop('c')
        if isinstance(style.get('color'), np.ndarray):
            style['colors'] = style.pop('color')
        if 'vmin' in style:
            style['clim'] = style.pop('vmin', None), style.pop('vmax', None)
        return (paths,), style, {}

    def init_artists(self, ax, plot_args, plot_kwargs):
        """Create a Line3DCollection, add it to ``ax`` and return it."""
        line_segments = Line3DCollection(*plot_args, **plot_kwargs)
        ax.add_collection(line_segments)
        return {'artist': line_segments}

    def update_handles(self, key, axis, element, ranges, style):
        """Delegate handle updates explicitly to PathPlot."""
        PathPlot.update_handles(self, key, axis, element, ranges, style)
class SurfacePlot(Plot3D):
    """
    Plots surfaces, wireframes and contours in 3D space.

    The display type is selected via the ``plot_type`` parameter. A
    number of styling options are supported, including strides and
    colors.
    """

    colorbar = param.Boolean(default=False, doc="""
Whether to add a colorbar to the plot.""")

    plot_type = param.ObjectSelector(default='surface',
                                     objects=['surface', 'wireframe',
                                              'contour'], doc="""
Specifies the type of visualization for the Surface object.
Valid values are 'surface', 'wireframe' and 'contour'.""")

    style_opts = ['antialiased', 'cmap', 'color', 'shade',
                  'linewidth', 'facecolors', 'rstride', 'cstride',
                  'norm', 'edgecolor', 'rcount', 'ccount']

    def init_artists(self, ax, plot_data, plot_kwargs):
        """
        Draw the surface using the Axes3D method matching ``plot_type``.

        Raises
        ------
        ValueError
            If ``plot_type`` is not one of the supported values.
        """
        if self.plot_type == "wireframe":
            artist = ax.plot_wireframe(*plot_data, **plot_kwargs)
        elif self.plot_type == "surface":
            artist = ax.plot_surface(*plot_data, **plot_kwargs)
        elif self.plot_type == "contour":
            artist = ax.contour3D(*plot_data, **plot_kwargs)
        else:
            # Previously fell through and left `artist` unbound.
            raise ValueError(f"Unsupported plot_type {self.plot_type!r}; "
                             "expected 'surface', 'wireframe' or 'contour'.")
        return {'artist': artist}

    def get_data(self, element, ranges, style):
        """
        Return ([X, Y, Z], style, {}) for the Axes3D plotting methods.

        Z is the gridded value array, with non-finite entries masked.
        X and Y are the expanded (2D) coordinate grids of the key
        dimensions.
        """
        zdata = element.dimension_values(2, flat=False)
        data = np.ma.array(zdata, mask=np.logical_not(np.isfinite(zdata)))
        coords = [element.interface.coords(element, d, ordered=True,
                                           expanded=True)
                  for d in element.kdims]
        if self.invert_axes:
            # The expanded coordinate grids share data's shape, so all
            # three arrays are transposed together to keep X, Y and Z
            # aligned point-for-point (transposing only data misaligned
            # them, and broke non-square grids).
            coords = [c.T for c in coords[::-1]]
            data = data.T
        cmesh_data = coords + [data]
        if self.plot_type != 'wireframe' and 'cmap' in style:
            self._norm_kwargs(element, ranges, style, element.vdims[0])
        return cmesh_data, style, {}
class TriSurfacePlot(Plot3D):
    """
    Plots a trisurface given a TriSurface element, containing
    X, Y and Z coordinates.
    """

    colorbar = param.Boolean(default=False, doc="""
Whether to add a colorbar to the plot.""")

    style_opts = ['cmap', 'color', 'shade', 'linewidth', 'edgecolor',
                  'norm']

    _plot_methods = dict(single='plot_trisurf')

    def get_data(self, element, ranges, style):
        """
        Return ((x, y, z), style, {}) for ``ax.plot_trisurf``.

        Only the first three dimensions are used as coordinates. Color
        normalization is keyed on the third (z) dimension.
        """
        # Restrict to the first three dimensions so that any additional
        # value dimensions do not break the x, y, z unpacking below.
        dims = element.dimensions()[:3]
        self._norm_kwargs(element, ranges, style, dims[2])
        x, y, z = (element.dimension_values(d) for d in dims)
        return (x, y, z), style, {}