import inspect
import re
import warnings

import matplotlib as mpl
import numpy as np
from matplotlib import (
    ticker,
    units as munits,
)
from matplotlib.colors import Normalize, cnames
from matplotlib.lines import Line2D
from matplotlib.markers import MarkerStyle
from matplotlib.patches import Path, PathPatch
from matplotlib.rcsetup import validate_fontsize, validate_fonttype, validate_hatch
from matplotlib.transforms import Affine2D, Bbox, TransformedBbox
from packaging.version import Version

try:  # starting Matplotlib 3.4.0
    from matplotlib._enums import (
        CapStyle as validate_capstyle,
        JoinStyle as validate_joinstyle,
    )
except ImportError:  # before Matplotlib 3.4.0
    from matplotlib.rcsetup import validate_capstyle, validate_joinstyle

try:
    from nc_time_axis import CalendarDateTime, NetCDFTimeConverter
    nc_axis_available = True
except ImportError:
    from matplotlib.dates import DateConverter
    NetCDFTimeConverter = DateConverter
    nc_axis_available = False

from ...core.util import arraylike_types, cftime_types, is_number
from ...element import RGB, Polygons, Raster
from ..util import COLOR_ALIASES, RGB_HEX_REGEX

# Parsed version of the installed matplotlib, so that it can be compared
# against other Version objects (see get_old_rcparams).
mpl_version = Version(mpl.__version__)

def is_color(color):
    """
    Checks if supplied object is a valid color spec.

    Only strings can be valid: a string is accepted when it matches
    RGB_HEX_REGEX, or is a key of COLOR_ALIASES, or is one of
    matplotlib's named colors (``cnames``).
    """
    if not isinstance(color, str):
        return False
    matches_hex = bool(RGB_HEX_REGEX.match(color))
    return matches_hex or color in COLOR_ALIASES or color in cnames
def _valid_alpha(x):
    # Alpha must be a number within the closed interval [0, 1].
    return is_number(x) and (0 <= x <= 1)


def _valid_marker(x):
    # Accepts built-in marker keys, MarkerStyle/Path objects and
    # '$...$' delimited strings.
    return (
        x in Line2D.markers
        or isinstance(x, (MarkerStyle, Path))
        or (isinstance(x, str) and x.startswith('$') and x.endswith('$'))
    )


def _valid_size(x):
    # Sizes must be non-negative numbers.
    return is_number(x) and (x >= 0)


# Maps a style name (or style-name suffix) to a callable that validates
# a single value for that style.
validators = {
    'alpha': _valid_alpha,
    'capstyle': validate_capstyle,
    'color': is_color,
    'fontsize': validate_fontsize,
    'fonttype': validate_fonttype,
    'hatch': validate_hatch,
    'joinstyle': validate_joinstyle,
    'marker': _valid_marker,
    's': _valid_size,
}


def get_old_rcparams():
    """
    Returns a dictionary copy of the current matplotlib rcParams.

    For matplotlib versions from 3.0 onwards the known deprecated
    parameters listed below are left out of the result.
    """
    deprecated_rcparams = [
        'text.latex.unicode',
        '',
        'savefig.frameon', # deprecated in MPL 3.1, to be removed in 3.3
        'verbose.level', # deprecated in MPL 3.1, to be removed in 3.3
        'verbose.fileo', # deprecated in MPL 3.1, to be removed in 3.3
        'datapath', # deprecated in MPL 3.2.1, to be removed in 3.3
        'text.latex.preview', # deprecated in MPL 3.3.1
        'animation.avconv_args', # deprecated in MPL 3.3.1
        'animation.avconv_path', # deprecated in MPL 3.3.1
        'animation.html_args', # deprecated in MPL 3.3.1
        'keymap.all_axes', # deprecated in MPL 3.3.1
        'savefig.jpeg_quality' # deprecated in MPL 3.3.1
    ]
    current = mpl.rcParams.items()
    if mpl_version < Version('3.0'):
        # Old matplotlib: nothing needs to be filtered out.
        return dict(current)
    return {
        name: setting for name, setting in current
        if name not in deprecated_rcparams
    }


def get_validator(style):
    """
    Looks up the validator for a style name.

    A validator registered under key ``k`` applies when ``style`` ends
    with ``k``; a single-character style has to equal the key exactly.
    Returns None when no validator applies.
    """
    for name, validator in validators.items():
        if not style.endswith(name):
            continue
        if len(style) != 1 or style == name:
            return validator
    return None
def validate(style, value, vectorized=True):
    """
    Validates a style and associated value.

    Arguments
    ---------
    style: str
       The style to validate (e.g. 'color', 'size' or 'marker')
    value:
       The style value to validate
    vectorized: bool
       Whether validator should allow vectorized setting

    Returns
    -------
    valid: boolean or None
       If validation is supported returns boolean, otherwise None
    """
    check = get_validator(style)
    if check is None:
        # No validator registered for this style
        return None

    if vectorized and isinstance(value, arraylike_types+(list,)):
        # Vectorized values are valid only if every item is valid
        for item in value:
            if not check(item):
                return False
        return True

    try:
        outcome = check(value)
        # Validators may return the (coerced) value rather than a
        # boolean, so only a result equal to False counts as invalid.
        if outcome == False:  # noqa: E712
            return False
        return True
    except Exception:
        return False
def filter_styles(style, group, other_groups, blacklist=None):
    """
    Filters styles which are specific to a particular artist, e.g.
    for a GraphPlot this will filter options specific to the nodes and
    edges.

    Options prefixed with ``group + '_'`` are returned with the prefix
    stripped and take precedence over an unprefixed option of the same
    name; options starting with any of the other groups are dropped.

    Arguments
    ---------
    style: dict
        Dictionary of styles and values
    group: str
        Group within the styles to filter for
    other_groups: list
        Other groups to filter out
    blacklist: list (optional)
        List of options to filter out

    Returns
    -------
    filtered: dict
        Filtered dictionary of styles
    """
    excluded = [] if blacklist is None else blacklist
    prefix = group+'_'

    # Options not tied to any group and not blacklisted
    filtered = {
        key: val for key, val in style.items()
        if not (any(key.startswith(other) for other in other_groups)
                or key.startswith(prefix)
                or key in excluded)
    }
    # Options of the requested group, with the group prefix removed
    group_specific = {
        key[len(prefix):]: val for key, val in style.items()
        if key.startswith(prefix) and key not in excluded
    }
    filtered.update(group_specific)
    return filtered
def wrap_formatter(formatter):
    """
    Wraps formatting function or string in
    appropriate matplotlib formatter type.

    Formatter instances are returned unchanged, callables are wrapped
    in a FuncFormatter and strings in a StrMethodFormatter (when they
    contain a ``{name}`` style field) or a FormatStrFormatter.
    Any other input yields None.
    """
    if isinstance(formatter, ticker.Formatter):
        return formatter

    if callable(formatter):
        arg_names = [name for name in inspect.getfullargspec(formatter).args
                     if name != 'self']
        if len(arg_names) != 1:
            return ticker.FuncFormatter(formatter)

        # FuncFormatter calls with (value, position); adapt callables
        # that only accept the value.
        def value_only(val, pos=None):
            return formatter(val)
        return ticker.FuncFormatter(value_only)

    if isinstance(formatter, str):
        if re.search(r"\{(\w+)\}", formatter):
            return ticker.StrMethodFormatter(formatter)
        return ticker.FormatStrFormatter(formatter)

    return None
def unpack_adjoints(ratios):
    """
    Expands ratio entries that cover several adjoined items.

    ``ratios`` maps an integer key to a ``(num, ratio_values)`` pair.
    Every entry is split into ``num`` lists (one per position), padding
    with NaN where a ratio sequence is shorter than ``num``; keys are
    shifted so the expanded entries do not collide.
    """
    expanded = {}
    shift = 0
    for key, (count, values) in sorted(ratios.items()):
        for position in range(count):
            expanded[key+position+shift] = [
                row[position] if position < len(row) else np.nan
                for row in values
            ]
        # An entry of `count` items uses up count-1 additional keys
        shift += count-1
    return expanded


def normalize_ratios(ratios):
    """
    Transposes the ratio lists and divides each resulting array by its
    largest non-NaN value.
    """
    transposed = zip(*ratios.values())
    result = {}
    for index, column in enumerate(transposed):
        values = np.array(column)
        result[index] = values/float(np.nanmax(values))
    return result


def compute_ratios(ratios, normalized=True):
    """
    Unpacks the supplied ratios, optionally normalizes them and returns
    the element-wise NaN-ignoring maximum across all entries.
    """
    unpacked = unpack_adjoints(ratios)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', r'All-NaN (slice|axis) encountered')
        if normalized:
            unpacked = normalize_ratios(unpacked)
        ordered = [values for _, values in sorted(unpacked.items())]
        return np.nanmax(np.vstack(ordered), axis=0)
def axis_overlap(ax1, ax2):
    """
    Tests whether two axes overlap vertically
    """
    bottom1, top1 = ax1.get_position().intervaly
    bottom2, top2 = ax2.get_position().intervaly
    # The vertical intervals overlap when each one starts below the
    # end of the other; intervals that merely touch do not overlap.
    return bottom2 < top1 and bottom1 < top2
def resolve_rows(rows):
    """
    Repeatedly iterate over lists of axes merging them by their
    vertical overlap leaving a list of rows.

    Merging passes are applied until a pass leaves the rows unchanged.
    Note that the row lists are extended in place while merging.
    """
    while True:
        merged = []
        for row in rows:
            for existing in merged:
                if any(axis_overlap(ax1, ax2)
                       for ax1 in row for ax2 in existing):
                    existing += row
                    break
            else:
                # No overlap with any merged row, so start a new one
                merged.append(row)
        if rows == merged:
            return rows
        rows = merged
def fix_aspect(fig, nrows, ncols, title=None, extra_artists=None,
               vspace=0.2, hspace=0.2):
    """
    Calculate heights and widths of axes and adjust the size
    of the figure to match the aspect.

    The figure is drawn, its axes are grouped into rows by vertical
    overlap and the tight bounding boxes of the axes are used to
    derive an aspect ratio. The figure keeps its width while its
    height is set to match that aspect (plus room for the title).
    A figure without any axes is drawn but otherwise left untouched.

    Arguments
    ---------
    fig: matplotlib Figure
        Figure to resize (modified in place)
    nrows: int
        Number of rows, scales the vertical spacing
    ncols: int
        Number of columns, scales the horizontal spacing
    title: matplotlib Text (optional)
        Title artist to reposition above the resized plot
    extra_artists: list (optional)
        Additional artists to include when computing the tight
        bounding box used for placing the title
    vspace: float
        Vertical spacing per row (in inches)
    hspace: float
        Horizontal spacing per column (in inches)
    """
    if extra_artists is None:
        extra_artists = []
    fig.canvas.draw()
    if not fig.axes:
        # Without axes there is no aspect to compute; the max() calls
        # below would raise a ValueError on the empty sequences.
        return
    w, _ = fig.get_size_inches()

    # Compute maximum height and width of each row and columns
    rows = resolve_rows([[ax] for ax in fig.axes])
    max_cols = max(len(row) for row in rows)
    heights = [[] for _ in range(max_cols)]
    widths = [[] for _ in rows]
    renderer = fig.canvas.get_renderer()
    for r, row in enumerate(rows):
        for c, ax in enumerate(row):
            bbox = ax.get_tightbbox(renderer)
            heights[c].append(bbox.height)
            widths[r].append(bbox.width)
    height = max(sum(col) for col in heights) + nrows*vspace*fig.dpi
    width = max(sum(row) for row in widths) + ncols*hspace*fig.dpi

    # Compute aspect and set new size (in inches)
    aspect = height/width
    has_title = bool(title and title.get_text())
    offset = title.get_window_extent().height/fig.dpi if has_title else 0
    fig.set_size_inches(w, (w*aspect)+offset)

    # Redraw and adjust title position if defined
    fig.canvas.draw()
    if has_title:
        # The title itself must not influence the bounding box that
        # determines where it is placed.
        extra_artists = [a for a in extra_artists if a is not title]
        bbox = get_tight_bbox(fig, extra_artists)
        top = bbox.intervaly[1]
        title.set_y(top/(w*aspect))
def get_tight_bbox(fig, bbox_extra_artists=None, pad=None):
    """
    Compute a tight bounding box around all the artists in the figure.

    Starts from the figure's own tight bounding box and extends it by
    the (clipped) window extents of the supplied extra artists and of
    the figure's default extra artists. If ``pad`` is truthy the
    result is padded by that amount.
    """
    extras = [] if bbox_extra_artists is None else bbox_extra_artists
    renderer = fig.canvas.get_renderer()
    tight = fig.get_tightbbox(renderer)

    # Work on a copy so the caller's list is left untouched
    artists = extras[:]
    artists += fig.get_default_bbox_extra_artists()

    extents = []
    for artist in artists:
        extent = artist.get_window_extent(renderer)
        if isinstance(extent, tuple):
            continue
        if artist.get_clip_on():
            clip_box = artist.get_clip_box()
            if clip_box is not None:
                extent = Bbox.intersection(extent, clip_box)
            clip_path = artist.get_clip_path()
            if not (clip_path is None or extent is None):
                clip_path = clip_path.get_fully_transformed_path()
                extent = Bbox.intersection(extent, clip_path.get_extents())
        if extent is None:
            continue
        # Skip degenerate (zero-sized) and non-finite extents
        if extent.width == 0 and extent.height == 0:
            continue
        if np.isfinite(extent).all():
            extents.append(extent)

    if extents:
        merged = Bbox.union(extents)
        # Window extents are in pixels, the tight bbox is in inches
        to_inches = Affine2D().scale(1.0 / fig.dpi)
        tight = Bbox.union([tight, TransformedBbox(merged, to_inches)])

    if pad:
        return tight.padded(pad)
    return tight
def get_raster_array(image):
    """
    Return the array data from any Raster or Image type

    RGB elements yield their value dimensions stacked along the third
    axis with every channel flipped vertically. For other elements the
    values of dimension 2 are returned, transposed for exact Raster
    instances and flipped vertically for everything else.
    """
    if isinstance(image, RGB):
        rgb = image.rgb
        channels = [np.flipud(rgb.dimension_values(dim, flat=False))
                    for dim in rgb.vdims]
        return np.dstack(channels)

    values = image.dimension_values(2, flat=False)
    # Exact type check: only plain Raster elements are transposed
    if type(image) is Raster:
        return values.T
    return np.flipud(values)
def ring_coding(array):
    """
    Produces matplotlib Path codes for exterior and interior rings
    of a polygon geometry.
    """
    # Every vertex is a "LINETO" command, except for the "MOVETO" that
    # starts the subpath and the "CLOSEPOLY" that ends it.
    codes = np.full(len(array), Path.LINETO, dtype=Path.code_type)
    codes[0] = Path.MOVETO
    codes[-1] = Path.CLOSEPOLY
    return codes
def polygons_to_path_patches(element):
    """
    Converts Polygons into list of lists of matplotlib.patches.PathPatch
    objects including any specified holes. Each list represents one
    (multi-)polygon.
    """
    def close_ring(ring):
        # Repeat the first vertex at the end unless the ring is
        # already closed.
        if (ring[0] != ring[-1]).any():
            return np.append(ring, ring[:1], axis=0)
        return ring

    geometries = element.split(datatype='array', dimensions=element.kdims)
    has_holes = isinstance(element, Polygons) and element.interface.has_holes(element)
    holes = element.interface.holes(element) if has_holes else None

    mpl_paths = []
    for geom_index, coords in enumerate(geometries):
        # Rows containing NaN in the first two columns separate the
        # individual parts of a multi-geometry
        nan_rows = np.where(np.isnan(coords[:, :2].astype('float')).sum(axis=1))[0]
        parts = np.split(coords, nan_rows+1) if len(nan_rows) else [coords]
        last_index = len(parts)-1

        patches = []
        for part_index, exterior in enumerate(parts):
            if part_index != last_index:
                # Drop the trailing NaN separator row
                exterior = exterior[:-1]
            exterior = close_ring(exterior)
            if has_holes:
                interiors = [close_ring(hole)
                             for hole in holes[geom_index][part_index]]
            else:
                interiors = []
            rings = [exterior]+interiors
            vertices = np.concatenate(rings)
            codes = np.concatenate([ring_coding(ring) for ring in rings])
            patches.append(PathPatch(Path(vertices, codes)))
        mpl_paths.append(patches)
    return mpl_paths
class CFTimeConverter(NetCDFTimeConverter):
    """
    Defines conversions for cftime types by extending nc_time_axis.
    """

    @classmethod
    def convert(cls, value, unit, axis):
        """
        Wraps cftime values (single values or numpy arrays of them) in
        CalendarDateTime objects before delegating to the parent
        converter. Raises a ValueError if nc_time_axis is unavailable.
        """
        if not nc_axis_available:
            message = ('In order to display cftime types with '
                       'matplotlib install the nc_time_axis '
                       'library using pip or from conda-forge '
                       'using:\n\tconda install -c conda-forge '
                       'nc_time_axis')
            raise ValueError(message)

        def to_calendar(item):
            return CalendarDateTime(item.datetime, item.calendar)

        if isinstance(value, cftime_types):
            value = to_calendar(value)
        elif isinstance(value, np.ndarray):
            value = np.array([to_calendar(item) for item in value])
        return super().convert(value, unit, axis)
class EqHistNormalize(Normalize):
    """
    Normalize subclass which maps data onto the 0-1 range using
    histogram equalization.

    The bin edges are computed from the data whenever an array is
    normalized and are cached so that later scalar lookups and
    ``inverse`` can reuse them.
    """

    def __init__(self, vmin=None, vmax=None, clip=False, rescale_discrete_levels=True, nbins=256**2, ncolors=256):
        """
        Arguments
        ---------
        vmin, vmax, clip:
            Passed on to Normalize; vmin/vmax, when set, replace the
            data minimum/maximum as the histogram range.
        rescale_discrete_levels: bool
            Whether to shrink the used part of the output range when
            the data only has few distinct levels.
        nbins: int
            Number of equal-width bins of the underlying histogram.
        ncolors: int
            Number of output color bins.
        """
        super().__init__(vmin, vmax, clip)
        self._nbins = nbins
        self._bin_edges = None
        self._ncolors = ncolors
        self._color_bins = np.linspace(0, 1, ncolors+1)
        self._rescale = rescale_discrete_levels

    def binning(self, data, n=256):
        """
        Computes the n+1 data values (bin edges) which correspond to
        equally spaced positions of the cumulative distribution of
        the data.
        """
        low = data.min() if self.vmin is None else self.vmin
        high = data.max() if self.vmax is None else self.vmax
        nbins = self._nbins
        eq_bin_edges = np.linspace(low, high, nbins+1)
        full_hist, _ = np.histogram(data, eq_bin_edges)

        # Remove zeros, leaving extra element at beginning for rescale_discrete_levels
        nonzero = np.nonzero(full_hist)[0]
        nhist = len(nonzero)
        if nhist > 1:
            hist = np.zeros(nhist+1)
            hist[1:] = full_hist[nonzero]
            eq_bin_centers = np.concatenate([[0.], (eq_bin_edges[nonzero] + eq_bin_edges[nonzero+1]) / 2.])
            eq_bin_centers[0] = 2*eq_bin_centers[1] - eq_bin_centers[-1]
        else:
            hist = full_hist
            eq_bin_centers = np.convolve(eq_bin_edges, [0.5, 0.5], mode='valid')

        # CDF scaled from 0 to 1 except for first value
        cdf = np.cumsum(hist)
        lo = cdf[1]
        diff = cdf[-1] - lo
        with np.errstate(divide='ignore', invalid='ignore'):
            cdf = (cdf - lo) / diff
        cdf[0] = -1.0

        lower_span = 0
        if self._rescale:
            discrete_levels = nhist
            # Linear ramp: the multiple is 1.5 for 2 levels and drops
            # to 1 at 100 levels; above that no rescaling is applied.
            m = -0.5/98.0
            c = 1.5 - 2*m
            multiple = m*discrete_levels + c
            if (multiple > 1):
                lower_span = 1 - multiple

        cdf_bins = np.linspace(lower_span, 1, n+1)
        binning = np.interp(cdf_bins, cdf, eq_bin_centers)
        if not self._rescale:
            binning[0] = low
        binning[-1] = high
        return binning

    def __call__(self, data, clip=None):
        """
        Returns the normalized data; ``clip`` is accepted for
        compatibility with the Normalize signature but is not used.
        """
        return self.process_value(data)[0]

    def process_value(self, data):
        """
        Normalizes the data by interpolating between the equalized
        bin edges, recomputing the edges first if an array is given.

        Returns
        -------
        result: masked array
            The normalized values (scalars are wrapped in an array)
        isscalar: bool
            Whether the supplied data was a scalar
        """
        if isinstance(data, np.ndarray):
            self._bin_edges = self.binning(data, self._ncolors)
        isscalar = np.isscalar(data)
        data = np.array([data]) if isscalar else data
        interped = np.interp(data, self._bin_edges, self._color_bins)
        # Return the interpolated values as a masked array, matching
        # the return convention of Normalize.process_value.
        return np.ma.array(interped), isscalar

    def inverse(self, value):
        """
        Maps a normalized value back to the data range. Raises a
        ValueError if no bin edges have been computed yet.
        """
        if self._bin_edges is None:
            raise ValueError("Not invertible until eq_hist has been computed")
        return np.interp([value], self._color_bins, self._bin_edges)[0]
# Register a CFTimeConverter instance in matplotlib's units registry for
# every cftime type, so values of these types are converted by it.
for cft in cftime_types: munits.registry[cft] = CFTimeConverter()