# Source code for holoviews.plotting.seaborn

from __future__ import absolute_import

import matplotlib.pyplot as plt

try:
    import seaborn.apionly as sns
except:
    sns = None

import param

from ..interface.pandas import DFrame, DataFrameView
from ..interface.seaborn import Regression, TimeSeries, Bivariate, Distribution
from ..interface.seaborn import DFrame as SNSFrame
from ..core.options import Store
from .element import ElementPlot
from .pandas import DFrameViewPlot


class FullRedrawPlot(ElementPlot):
    """
    Abstract base class for plots that redraw everything on each update.

    Instead of updating existing artists in place, update_handles clears
    the axis (only when this plot's zorder is 0) and then delegates to the
    subclass's _update_plot method to draw the new element from scratch.
    """

    apply_databounds = param.Boolean(default=False, doc="""
        Enables computing the plot bounds from the data itself.
        Disabled by default since data is often preprocessed, before
        display, changing the bounds.""")

    aspect = param.Parameter(default='square', doc="""
        Aspect ratio defaults to square, 'equal' or numeric values
        are also supported.""")

    show_grid = param.Boolean(default=True, doc="""
        Enables the axis grid.""")

    _abstract = True

    def update_handles(self, axis, view, key, ranges=None):
        """
        Redraw ``view`` onto ``axis``.

        The axis is wiped with ``cla()`` only when this plot has zorder 0
        and an axis is present; otherwise the new drawing is added to the
        axis as-is. ``key`` and ``ranges`` are accepted for interface
        compatibility and are not used here.
        """
        is_bottom_layer = self.zorder == 0
        if is_bottom_layer and axis:
            axis.cla()
        self._update_plot(axis, view)
class RegressionPlot(FullRedrawPlot):
    """
    RegressionPlot visualizes Regression Views using the Seaborn regplot
    interface. It fits and draws a regression through a set of scatter
    points taken from the first two columns of the element's data.
    Options for the regplot function can be supplied via the opts magic.
    """

    style_opts = ['x_estimator', 'x_bins', 'x_ci', 'scatter', 'fit_reg',
                  'color', 'n_boot', 'order', 'logistic', 'lowess',
                  'robust', 'truncate', 'scatter_kws', 'line_kws', 'ci',
                  'dropna', 'x_jitter', 'y_jitter', 'x_partial', 'y_partial']

    def __call__(self, ranges=None):
        """Draw the most recent element and finalize the axis."""
        axis = self.handles['axis']
        element = self.map.last
        self._update_plot(axis, element)
        return self._finalize_axis(self.keys[-1])

    def _update_plot(self, axis, view):
        """Call seaborn's regplot on columns 0 and 1 of ``view.data``."""
        # Style is re-selected from the cycle on every draw.
        style = self.style[self.cyclic_index]
        # The element label is attached only when self.overlaid == 1.
        if self.overlaid == 1:
            label = view.label
        else:
            label = ''
        xs = view.data[:, 0]
        ys = view.data[:, 1]
        sns.regplot(xs, ys, ax=axis, label=label, **style)
class BivariatePlot(FullRedrawPlot):
    """
    Bivariate plot visualizes two-dimensional kernel density estimates
    using the Seaborn kdeplot function. When the joint option is enabled,
    seaborn's jointplot is used instead, drawing the marginal distributions
    alongside each axis (this mode does not animate or compose).
    """

    joint = param.Boolean(default=False, doc="""
        Whether to visualize the kernel density estimate with marginal
        distributions along each axis. Does not animate or compose
        when enabled.""")

    style_opts = ['color', 'alpha', 'err_style', 'interpolate', 'ci', 'kind',
                  'bw', 'kernel', 'cumulative', 'shade', 'vertical', 'cmap']

    def __call__(self, ranges=None):
        """
        Draw the most recent element and finalize the axis.

        Raises:
            Exception: if ``joint`` is enabled while this plot is a subplot.
        """
        element = self.map.last
        axis = self.handles['axis']
        # Replace the style cycle with the single style dict for this plot.
        self.style = self.style[self.cyclic_index]
        if self.subplot and self.joint:
            raise Exception("Joint plots can't be animated or laid out in a grid.")
        self._update_plot(axis, element)
        return self._finalize_axis(self.keys[-1])

    def _update_plot(self, axis, view):
        """Draw ``view`` either as a kdeplot on ``axis`` or as a jointplot."""
        if not self.joint:
            # The element label is attached only when self.overlaid == 1.
            label = view.label if self.overlaid == 1 else ''
            sns.kdeplot(view.data, ax=axis, label=label,
                        zorder=self.zorder, **self.style)
            return
        # jointplot is not given ``axis`` and its result's figure replaces
        # the stored figure handle; 'cmap' is dropped from the style first.
        self.style.pop('cmap', None)
        xs, ys = view.data[:, 0], view.data[:, 1]
        joint_grid = sns.jointplot(xs, ys, **self.style)
        self.handles['fig'] = joint_grid.fig
class TimeSeriesPlot(FullRedrawPlot):
    """
    TimeSeries visualizes sets of curves using the Seaborn tsplot
    function, which can draw error bars in various styles alongside the
    averaged curve.
    """

    show_frame = param.Boolean(default=False, doc="""
        Disabled by default for clarity.""")

    show_legend = param.Boolean(default=True, doc="""
        Whether to show legend for the plot.""")

    style_opts = ['color', 'alpha', 'err_style', 'interpolate', 'ci',
                  'n_boot', 'err_kws', 'err_palette', 'estimator', 'kwargs']

    def __call__(self, ranges=None):
        """Draw the most recent element and finalize the axis."""
        element = self.map.last
        axis = self.handles['axis']
        # Replace the style cycle with the single style dict for this plot.
        self.style = self.style[self.cyclic_index]
        self._update_plot(axis, element)
        return self._finalize_axis(self.keys[-1])

    def _update_plot(self, axis, view):
        """Call seaborn's tsplot with the element's data and x values."""
        curves, times = view.data, view.xdata
        # The element label is passed as tsplot's ``condition`` argument.
        sns.tsplot(curves, times, ax=axis, condition=view.label,
                   zorder=self.zorder, **self.style)

    def _axis_labels(self, view, subplots, xlabel, ylabel, zlabel):
        """
        Fill in missing axis labels from the element's dimensions.

        Empty/falsy ``xlabel``/``ylabel`` fall back to the first key
        dimension and first value dimension respectively; ``zlabel`` is
        returned unchanged.
        """
        default_x = str(view.kdims[0]) if not xlabel else None
        default_y = str(view.vdims[0]) if not ylabel else None
        return xlabel or default_x, ylabel or default_y, zlabel
class DistributionPlot(FullRedrawPlot):
    """
    DistributionPlot visualizes Distribution Views using the Seaborn
    distplot function, which can render a 1D array as a histogram, a
    kernel density estimate, or a rugplot.
    """

    apply_ranges = param.Boolean(default=False, doc="""
        Whether to compute the plot bounds from the data itself.""")

    show_frame = param.Boolean(default=False, doc="""
        Disabled by default for clarity.""")

    style_opts = ['bins', 'hist', 'kde', 'rug', 'fit', 'hist_kws',
                  'kde_kws', 'rug_kws', 'fit_kws', 'color']

    def __call__(self, ranges=None):
        """Draw the most recent element and finalize the axis."""
        element = self.map.last
        axis = self.handles['axis']
        # Replace the style cycle with the single style dict for this plot.
        self.style = self.style[self.cyclic_index]
        self._update_plot(axis, element)
        return self._finalize_axis(self.keys[-1])

    def _update_plot(self, axis, view):
        """Call seaborn's distplot on ``view.data``."""
        # The element label is attached only when self.overlaid == 1.
        if self.overlaid == 1:
            label = view.label
        else:
            label = ''
        sns.distplot(view.data, ax=axis, label=label, **self.style)
class SNSFramePlot(DFrameViewPlot):
    """
    SNSFramePlot takes an SNSFrame as input and plots the contained data
    using the set plot_type. This largely mirrors the way DFramePlot
    works. However, since most Seaborn plot types plot one dimension
    against another, it uses the x and y parameters, which can be set on
    the SNSFrame.
    """

    plot_type = param.ObjectSelector(default='scatter_matrix',
                                     objects=['interact', 'regplot', 'lmplot',
                                              'corrplot', 'plot', 'boxplot',
                                              'hist', 'scatter_matrix',
                                              'autocorrelation_plot',
                                              'pairgrid', 'facetgrid',
                                              'pairplot', 'violinplot',
                                              'factorplot'],
                                     doc="""
        Selects which Seaborn plot type to use, when visualizing the
        SNSFrame. The options that can be passed to the plot_type are
        defined in dframe_options.""")

    # Allowed style options per plot_type, extending the pandas ones.
    dframe_options = dict(DFrameViewPlot.dframe_options,
                          **{'regplot':   RegressionPlot.style_opts,
                             'factorplot': ['kind', 'col', 'aspect', 'row',
                                            'col_wrap', 'ci', 'linestyles',
                                            'markers', 'palette', 'dodge',
                                            'join', 'size', 'legend',
                                            'sharex', 'sharey', 'hue',
                                            'estimator'],
                             'boxplot':   [],
                             'violinplot': ['groupby', 'positions',
                                            'inner', 'join_rm', 'bw', 'cut'],
                             'lmplot':    ['hue', 'col', 'row', 'palette',
                                           'sharex', 'dropna', 'legend'],
                             'corrplot':  ['annot', 'sig_stars', 'sig_tail',
                                           'sig_corr', 'cmap', 'cmap_range',
                                           'cbar'],
                             'interact':  ['filled', 'cmap', 'colorbar',
                                           'levels', 'logistic',
                                           'contour_kws', 'scatter_kws'],
                             # Fixed: a missing comma previously fused
                             # 'y_vars' and 'size' into 'y_varssize'.
                             'pairgrid':  ['hue', 'hue_order', 'palette',
                                           'hue_kws', 'vars', 'x_vars',
                                           'y_vars', 'size', 'aspect',
                                           'despine', 'map', 'map_diag',
                                           'map_offdiag', 'map_upper',
                                           'map_lower'],
                             'pairplot':  ['hue', 'hue_order', 'palette',
                                           'vars', 'x_vars', 'y_vars',
                                           'diag_kind', 'kind', 'plot_kws',
                                           'diag_kws', 'grid_kws'],
                             'facetgrid': ['hue', 'row', 'col', 'col_wrap',
                                           'map', 'sharex', 'sharey', 'size',
                                           'aspect', 'palette', 'row_order',
                                           'col_order', 'hue_order', 'legend',
                                           'legend_out', 'xlim', 'ylim',
                                           'despine'],
                             })

    # Union of all per-plot-type options. Sorted so the order does not
    # depend on set iteration order (which varies with hash randomization).
    style_opts = sorted({opt for opts in dframe_options.values() for opt in opts})

    # Plot types drawn via figure-level seaborn grids (see _update_grid_plot).
    _grid_plot_types = ('pairgrid', 'pairplot', 'facetgrid')

    def __init__(self, view, **params):
        # Before super().__init__ applies ``params``, self.plot_type is still
        # the class default, so check an explicitly passed plot_type first.
        plot_type = params.get('plot_type', self.plot_type)
        if plot_type in self._grid_plot_types:
            # Grid plot types replace the stored figure after drawing;
            # _create_fig is presumably read by the base class to skip
            # creating its own figure -- confirm against ElementPlot.
            self._create_fig = False
        super(SNSFramePlot, self).__init__(view, **params)

    def __call__(self, ranges=None):
        """Validate and draw the most recent element, then finalize."""
        dfview = self.map.last
        axis = self.handles['axis']
        self._validate(dfview)
        self._update_plot(axis, dfview)
        # Seaborn may have drawn into a different figure; track the
        # current pyplot figure as this plot's figure.
        if 'fig' in self.handles and self.handles['fig'] != plt.gcf():
            self.handles['fig'] = plt.gcf()
        return self._finalize_axis(self.keys[-1])

    def _process_style(self, styles):
        """Drop 'figsize' for plot types not handled by DFrameViewPlot."""
        styles = super(SNSFramePlot, self)._process_style(styles)
        if self.plot_type not in DFrameViewPlot.params()['plot_type'].objects:
            styles.pop('figsize', None)
        return styles

    def _validate(self, dfview):
        """
        Raises:
            Exception: if a multi-dimensional lmplot is used as a subplot.
        """
        super(SNSFramePlot, self)._validate(dfview)
        multi_dim = dfview.ndims > 1
        if self.subplot and multi_dim and self.plot_type == 'lmplot':
            raise Exception("Multiple %s plots cannot be composed."
                            % self.plot_type)

    def update_frame(self, key, ranges=None):
        """Redraw the frame at ``key``, hiding the axis if it is missing."""
        view = self.map.get(key, None)
        axis = self.handles['axis']
        if axis:
            axis.set_visible(view is not None)
        axis_kwargs = self.update_handles(axis, view, key, ranges)
        if axis:
            self._finalize_axis(key, **(axis_kwargs or {}))

    def _update_plot(self, axis, view):
        """Dispatch drawing of ``view`` on the configured plot_type."""
        style = self._process_style(self.style[self.cyclic_index])
        plot_type = self.plot_type
        if plot_type == 'factorplot':
            opts = dict(style, **({'hue': view.x2} if view.x2 else {}))
            sns.factorplot(x=view.x, y=view.y, data=view.data, **opts)
        elif plot_type == 'regplot':
            sns.regplot(x=view.x, y=view.y, data=view.data, ax=axis, **style)
        elif plot_type == 'boxplot':
            style.pop('return_type', None)
            style.pop('figsize', None)
            sns.boxplot(view.data[view.y], view.data[view.x], ax=axis, **style)
        elif plot_type == 'violinplot':
            if view.x:
                sns.violinplot(view.data[view.y], view.data[view.x],
                               ax=axis, **style)
            else:
                sns.violinplot(view.data, ax=axis, **style)
        elif plot_type == 'interact':
            sns.interactplot(view.x, view.x2, view.y,
                             data=view.data, ax=axis, **style)
        elif plot_type == 'corrplot':
            sns.corrplot(view.data, ax=axis, **style)
        elif plot_type == 'lmplot':
            sns.lmplot(x=view.x, y=view.y, data=view.data, ax=axis, **style)
        elif plot_type in self._grid_plot_types:
            self._update_grid_plot(view, style)
        else:
            super(SNSFramePlot, self)._update_plot(axis, view)

    def _update_grid_plot(self, view, style):
        """
        Draw a pairplot, PairGrid or FacetGrid and adopt its figure.

        Style keys containing 'map' name grid methods (map, map_diag, ...).
        Each value is a sequence whose first item is the name of a plotting
        function, looked up on seaborn first and then on pyplot. Its
        remaining items are passed positionally to that grid method. These
        keys are removed from ``style`` before it is used as constructor
        kwargs.
        """
        map_opts = [(k, style.pop(k)) for k in list(style) if 'map' in k]
        if self.plot_type == 'pairplot':
            grid = sns.pairplot(view.data, **style)
        elif self.plot_type == 'pairgrid':
            grid = sns.PairGrid(view.data, **style)
        else:
            grid = sns.FacetGrid(view.data, **style)
        for method_name, args in map_opts:
            fn_name = args[0]
            if hasattr(sns, fn_name):
                plot_fn = getattr(sns, fn_name)
            else:
                plot_fn = getattr(plt, fn_name)
            getattr(grid, method_name)(plot_fn, *args[1:])
        # Close the previously stored figure and track the one the grid drew.
        plt.close(self.handles['fig'])
        self.handles['fig'] = plt.gcf()
# Register the plot class used to render each element type.
Store.registry.update({
    TimeSeries:    TimeSeriesPlot,
    Bivariate:     BivariatePlot,
    Distribution:  DistributionPlot,
    Regression:    RegressionPlot,
    SNSFrame:      SNSFramePlot,
    DFrame:        SNSFramePlot,
    DataFrameView: SNSFramePlot,
})