from __future__ import absolute_import
import matplotlib.pyplot as plt
try:
import seaborn as sns
except:
sns = None
import param
from ...interface.pandas import DFrame, DataFrameView
from ...interface.seaborn import Regression, TimeSeries, Bivariate, Distribution
from ...interface.seaborn import DFrame as SNSFrame
from ...core.options import Store
from ...core import config
from .element import ElementPlot
from .pandas import DFrameViewPlot
from .plot import MPLPlot, AdjoinedPlot, mpl_rc_context
[docs]class SeabornPlot(ElementPlot):
"""
SeabornPlot provides an abstract baseclass, defining an
update_frame method, which completely wipes the axis and
redraws the plot.
"""
aspect = param.Parameter(default='square', doc="""
Aspect ratio defaults to square, 'equal' or numeric values
are also supported.""")
show_grid = param.Boolean(default=True, doc="""
Enables the axis grid.""")
_abstract = True
def teardown_handles(self):
if self.zorder == 0:
self.handles['axis'].cla()
[docs]class RegressionPlot(SeabornPlot):
"""
RegressionPlot visualizes Regression Views using the Seaborn
regplot interface, allowing the user to perform and plot
linear regressions on a set of scatter points. Parameters
to the replot function can be supplied via the opts magic.
"""
style_opts = ['x_estimator', 'x_bins', 'x_ci', 'scatter',
'fit_reg', 'color', 'n_boot', 'order',
'logistic', 'lowess', 'robust', 'truncate',
'scatter_kws', 'line_kws', 'ci', 'dropna',
'x_jitter', 'y_jitter', 'x_partial', 'y_partial']
def init_artists(self, ax, plot_data, plot_kwargs):
plot_kwargs.pop('zorder')
return {'axis': sns.regplot(*plot_data, ax=ax, **plot_kwargs)}
def get_data(self, element, ranges, style):
xs, ys = (element[d] for d in element.dimensions()[:2])
return (xs, ys), style, {}
[docs]class BivariatePlot(SeabornPlot):
"""
Bivariate plot visualizes two-dimensional kernel density
estimates using the Seaborn kdeplot function. Additionally,
by enabling the joint option, the marginals distributions
can be plotted alongside each axis (does not animate or
compose).
"""
joint = param.Boolean(default=False, doc="""
Whether to visualize the kernel density estimate with marginal
distributions along each axis. Does not animate or compose
when enabled.""")
style_opts = ['color', 'alpha', 'err_style', 'interpolate',
'ci', 'kind', 'bw', 'kernel', 'cumulative',
'shade', 'vertical', 'cmap']
def init_artists(self, ax, plot_data, plot_kwargs):
if self.joint:
if self.joint and self.subplot:
raise Exception("Joint plots can't be animated or laid out in a grid.")
return {'fig': sns.jointplot(*plot_data, **plot_kwargs).fig}
else:
return {'axis': sns.kdeplot(*plot_data, ax=ax, **plot_kwargs)}
def get_data(self, element, ranges, style):
xs, ys = (element[d] for d in element.dimensions()[:2])
if self.joint:
style.pop('cmap', None)
style.pop('zorder', None)
return (xs, ys), style, {}
[docs]class TimeSeriesPlot(SeabornPlot):
"""
TimeSeries visualizes sets of curves using the Seaborn
tsplot function. This provides functionality to plot
error bars with various styles alongside the averaged
curve.
"""
show_legend = param.Boolean(default=True, doc="""
Whether to show legend for the plot.""")
style_opts = ['color', 'alpha', 'err_style', 'interpolate',
'ci', 'n_boot', 'err_kws', 'err_palette',
'estimator', 'kwargs']
def get_data(self, element, ranges, style):
style.pop('zorder', None)
if 'label' in style:
style['condition'] = style.pop('label')
axis_kwargs = {'xlabel': element.kdims[0].pprint_label,
'ylabel': element.vdims[0].pprint_label}
return (element.data, element.xdata), style, axis_kwargs
def init_artists(self, ax, plot_data, plot_kwargs):
return {'axis': sns.tsplot(*plot_data, ax=ax, **plot_kwargs)}
[docs]class DistributionPlot(SeabornPlot):
"""
DistributionPlot visualizes Distribution Views using the
Seaborn distplot function. This allows visualizing a 1D
array as a histogram, kernel density estimate, or rugplot.
"""
apply_ranges = param.Boolean(default=False, doc="""
Whether to compute the plot bounds from the data itself.""")
style_opts = ['bins', 'hist', 'kde', 'rug', 'fit', 'hist_kws',
'kde_kws', 'rug_kws', 'fit_kws', 'color']
def get_data(self, element, ranges, style):
style.pop('zorder', None)
if self.invert_axes:
style['vertical'] = True
vdim = element.vdims[0]
axis_kwargs = dict(dimensions=[vdim])
return (element.dimension_values(vdim),), style, axis_kwargs
def init_artists(self, ax, plot_data, plot_kwargs):
return {'axis': sns.distplot(*plot_data, ax=ax, **plot_kwargs)}
class SideDistributionPlot(AdjoinedPlot, DistributionPlot):
border_size = param.Number(default=0.2, doc="""
The size of the border expressed as a fraction of the main plot.""")
[docs]class SNSFramePlot(DFrameViewPlot):
"""
SNSFramePlot takes an SNSFrame as input and plots the
contained data using the set plot_type. This largely mirrors
the way DFramePlot works, however, since most Seaborn plot
types plot one dimension against another it uses the x and y
parameters, which can be set on the SNSFrame.
"""
plot_type = param.ObjectSelector(default='scatter_matrix',
objects=['interact', 'regplot',
'lmplot', 'corrplot',
'plot', 'boxplot',
'hist', 'scatter_matrix',
'autocorrelation_plot',
'pairgrid', 'facetgrid',
'pairplot', 'violinplot',
'factorplot'
],
doc="""
Selects which Seaborn plot type to use, when visualizing the
SNSFrame. The options that can be passed to the plot_type are
defined in dframe_options. Valid options are 'interact', 'regplot',
'lmplot', 'corrplot', 'plot', 'boxplot', 'hist', 'scatter_matrix',
'autocorrelation_plot', 'pairgrid', 'facetgrid', 'pairplot',
'violinplot' and 'factorplot'""")
dframe_options = dict(DFrameViewPlot.dframe_options,
**{'regplot': RegressionPlot.style_opts,
'factorplot': ['kind', 'col', 'aspect', 'row',
'col_wrap', 'ci', 'linestyles',
'markers', 'palette', 'dodge',
'join', 'size', 'legend',
'sharex', 'sharey', 'hue', 'estimator'],
'boxplot': ['order', 'hue_order', 'orient', 'color',
'palette', 'saturation', 'width', 'fliersize',
'linewidth', 'whis', 'notch'],
'violinplot':['groupby', 'positions',
'inner', 'join_rm', 'bw', 'cut', 'split'],
'lmplot': ['hue', 'col', 'row', 'palette',
'sharex', 'dropna', 'legend'],
'interact': ['filled', 'cmap', 'colorbar',
'levels', 'logistic', 'contour_kws',
'scatter_kws'],
'pairgrid': ['hue', 'hue_order', 'palette',
'hue_kws', 'vars', 'x_vars', 'y_vars'
'size', 'aspect', 'despine', 'map',
'map_diag', 'map_offdiag',
'map_upper', 'map_lower'],
'pairplot': ['hue', 'hue_order', 'palette',
'vars', 'x_vars', 'y_vars', 'diag_kind',
'kind', 'plot_kws', 'diag_kws', 'grid_kws'],
'facetgrid': ['hue', 'row', 'col', 'col_wrap',
'map', 'sharex', 'sharey', 'size',
'aspect', 'palette', 'row_order',
'col_order', 'hue_order', 'legend',
'legend_out', 'xlim', 'ylim', 'despine'],
})
style_opts = list({opt for opts in dframe_options.values() for opt in opts})
def __init__(self, view, **params):
if self.plot_type in ['pairgrid', 'pairplot', 'facetgrid']:
self._create_fig = False
super(SNSFramePlot, self).__init__(view, **params)
@mpl_rc_context
def initialize_plot(self, ranges=None):
dfview = self.hmap.last
axis = self.handles['axis']
self._validate(dfview)
style = self._process_style(self.style[self.cyclic_index])
self._update_plot(axis, dfview, style)
if 'fig' in self.handles and self.handles['fig'] != plt.gcf():
self.handles['fig'] = plt.gcf()
return self._finalize_axis(self.keys[-1], element=dfview)
def _process_style(self, styles):
styles = super(SNSFramePlot, self)._process_style(styles)
if self.plot_type not in DFrameViewPlot.params()['plot_type'].objects:
styles.pop('figsize', None)
return styles
def _validate(self, dfview):
super(SNSFramePlot, self)._validate(dfview)
multi_dim = dfview.ndims > 1
if self.subplot and multi_dim and self.plot_type == 'lmplot':
raise Exception("Multiple %s plots cannot be composed."
% self.plot_type)
@mpl_rc_context
def update_frame(self, key, ranges=None):
element = self.hmap.get(key, None)
axis = self.handles['axis']
if axis:
axis.set_visible(element is not None)
style = dict(zorder=self.zorder, **self.style[self.cyclic_index])
if self.show_legend:
style['label'] = element.label
axis_kwargs = self.update_handles(key, axis, element, ranges, style)
if axis:
self._finalize_axis(key, element=element, **(axis_kwargs if axis_kwargs else {}))
def _update_plot(self, axis, view, style):
style.pop('zorder', None)
if self.plot_type == 'factorplot':
opts = dict(style, **({'hue': view.x2} if view.x2 else {}))
sns.factorplot(x=view.x, y=view.y, data=view.data, **opts)
elif self.plot_type == 'regplot':
sns.regplot(x=view.x, y=view.y, data=view.data,
ax=axis, **style)
elif self.plot_type == 'boxplot':
style.pop('return_type', None)
style.pop('figsize', None)
sns.boxplot(view.data[view.y], view.data[view.x], ax=axis,
**style)
elif self.plot_type == 'violinplot':
if view.x:
sns.violinplot(view.data[view.y], view.data[view.x], ax=axis,
**style)
else:
sns.violinplot(view.data, ax=axis, **style)
elif self.plot_type == 'interact':
sns.interactplot(view.x, view.x2, view.y,
data=view.data, ax=axis, **style)
elif self.plot_type == 'lmplot':
sns.lmplot(x=view.x, y=view.y, data=view.data,
ax=axis, **style)
elif self.plot_type in ['pairplot', 'pairgrid', 'facetgrid']:
style_keys = list(style.keys())
map_opts = [(k, style.pop(k)) for k in style_keys if 'map' in k]
if self.plot_type == 'pairplot':
g = sns.pairplot(view.data, **style)
elif self.plot_type == 'pairgrid':
g = sns.PairGrid(view.data, **style)
elif self.plot_type == 'facetgrid':
g = sns.FacetGrid(view.data, **style)
for opt, args in map_opts:
plot_fn = getattr(sns, args[0]) if hasattr(sns, args[0]) else getattr(plt, args[0])
getattr(g, opt)(plot_fn, *args[1:])
if self._close_figures:
plt.close(self.handles['fig'])
self.handles['fig'] = plt.gcf()
else:
super(SNSFramePlot, self)._update_plot(axis, view, style)
Store.register({TimeSeries: TimeSeriesPlot,
Bivariate: BivariatePlot,
Distribution: DistributionPlot,
Regression: RegressionPlot,
SNSFrame: SNSFramePlot,
DFrame: SNSFramePlot,
DataFrameView: SNSFramePlot}, 'matplotlib')
MPLPlot.sideplots.update({Distribution: SideDistributionPlot})
if config.style_17:
for framelesscls in [TimeSeriesPlot, DistributionPlot]:
framelesscls.show_frame = False