try:
    # framework is running
    from .startup_choice import *
except ImportError as _excp:
    # class is imported by itself
    if (
        'attempted relative import with no known parent package' in str(_excp)
        or 'No module named \'omfit_classes\'' in str(_excp)
        or "No module named '__main__.startup_choice'" in str(_excp)
    ):
        from startup_choice import *
    else:
        raise
if framework:
    print('Loading plot utility functions...')
import collections
from signal import pause
from omfit_classes.utils_base import _available_to_user_plot, compare_version
from omfit_classes.utils_math import is_uncertain, interp1e, unsorted_unique, RectBivariateSplineNaN, point_to_line
from omfit_classes.sortedDict import SortedDict
# explicit imports
import numpy as np
from xarray import Dataset, DataArray
import uncertainties
import inspect
from scipy import interpolate, integrate
import matplotlib
from matplotlib import pyplot, gridspec, cm
from matplotlib.widgets import RectangleSelector, RadioButtons, AxesWidget
from mpl_toolkits.mplot3d import Axes3D
import platform
import re
default_colorblind_line_cycle = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
]
default_matplotlib_line_cycle = eval(re.sub(r'.*(\[.*\]).*', r'\1', str(matplotlib.rcParams['axes.prop_cycle'])))
default_colorblind_line_cycle = matplotlib.cycler('color', default_colorblind_line_cycle)
default_matplotlib_cmap_cycle = 'viridis'
matplotlib.style.core.USER_LIBRARY_PATHS.append(OMFITsrc + '/extras/styles')
matplotlib.style.core.reload_library()
# Colormap definition for "Standard Gamma-II" colormap from IDL
# fmt: off
_r = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               4, 9, 14, 19, 23, 28, 33, 38, 42, 47, 52, 57, 61, 66, 71, 76,
               81, 81, 81, 81, 81, 81, 81, 81, 80, 80, 80, 80, 80, 80, 80, 79,
               84, 89, 94, 99, 104, 109, 114, 119, 124, 129, 134, 139, 144, 149, 154, 159,
               164, 169, 174, 180, 185, 190, 196, 201, 206, 212, 217, 222, 228, 233, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 248, 240, 232, 225, 217, 209, 202, 194, 186, 179, 171, 163, 168,
               173, 178, 183, 188, 193, 198, 203, 209, 214, 219, 224, 229, 234, 239, 244, 249,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255])
_g = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 5, 10, 16, 21, 27, 32, 37, 43, 48, 54, 59, 64, 70, 75,
               81, 85, 90, 95, 100, 105, 109, 114, 119, 124, 129, 134, 138, 143, 148, 153,
               158, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163,
               163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163, 163,
               163, 169, 175, 181, 187, 193, 199, 205, 212, 218, 224, 230, 236, 242, 248, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
               255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255])
_b = np.array([0, 5, 10, 15, 20, 26, 31, 36, 41, 46, 52, 57, 62, 67, 72, 78,
               83, 88, 93, 98, 104, 109, 114, 119, 124, 130, 135, 140, 145, 150, 156, 161,
               166, 171, 176, 182, 187, 192, 197, 202, 208, 213, 218, 223, 228, 234, 239, 244,
               249, 255, 250, 245, 239, 234, 228, 223, 218, 212, 207, 201, 196, 190, 185, 180,
               174, 169, 163, 158, 152, 147, 142, 136, 131, 125, 120, 114, 109, 104, 98, 93,
               87, 82, 76, 71, 66, 60, 55, 49, 44, 38, 33, 28, 22, 17, 11, 6,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 4, 9, 14, 19, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72,
               77, 82, 77, 71, 65, 59, 53, 47, 41, 36, 30, 24, 18, 12, 6, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 3, 6, 9, 12, 16, 19, 22, 25, 29, 32, 35, 38, 41, 45, 48,
               51, 54, 58, 61, 64, 67, 71, 74, 77, 80, 83, 87, 90, 93, 96, 100,
               103, 106, 109, 112, 116, 119, 122, 125, 129, 132, 135, 138, 142, 145, 148, 151,
               154, 158, 161, 164, 167, 171, 174, 177, 180, 183, 187, 190, 193, 196, 200, 203,
               206, 209, 213, 216, 219, 222, 225, 229, 232, 235, 238, 242, 245, 248, 251, 255])
# fmt: on
_rgb = np.column_stack((_r, _g, _b)) / 255.0
cm.register_cmap(cmap=matplotlib.colors.ListedColormap(_rgb, name='Standard Gamma-II'))
matplotlib.rcParams['image.cmap'] = default_matplotlib_cmap_cycle
matplotlib.rcParamsDefault.update(matplotlib.rcParams)
if platform.system() == 'Darwin':
    rightClickMPLindex = 2
    middleClickMPLindex = 3
else:
    rightClickMPLindex = 3
    middleClickMPLindex = 2
[docs]@_available_to_user_plot
def autofmt_sharexy(trim_xlabel=True, trim_ylabel=True, fig=None):
    # todo: docstring
    if fig is None:
        fig = pyplot.gcf()
    autofmt_sharex(trim_ylabel=trim_ylabel, fig=fig)
    autofmt_sharey(trim_xlabel=trim_xlabel, fig=fig) 
[docs]@_available_to_user_plot
def autofmt_sharey(trim_xlabel=True, fig=None, wspace=0):
    """
    Prunes y-tick labels and y-axis labels from all but the first cols axes
    and moves cols (optionally) closer together.
    :param trim_xlabel: bool. prune right ytick label to prevent overlap.
    :param fig: Figure. Defaults to current figure.
    :param wspace: Horizontal spacing between axes.
    """
    if fig is None:
        fig = pyplot.gcf()
    for a in fig.axes:
        if hasattr(a, 'is_first_col') and not a.is_first_col():
            for yt in a.get_yticklabels():
                yt.set_visible(False)
            a.set_ylabel('')
        if len(a.get_xticklabels()) and trim_xlabel:
            try:
                a.xaxis.major.locator.set_params(prune='upper')
            except Exception:
                pass
    fig.subplots_adjust(wspace=wspace) 
[docs]@_available_to_user_plot
def autofmt_sharex(trim_ylabel=True, fig=None, hspace=0):
    """
    Prunes x-tick labels and x-axis labels from all but the last row axes
    and moves rows (optionally) closer together.
    :param trim_ylabel: bool. prune top ytick label to prevent overlap.
    :param fig: Figure. Defaults to current figure.
    :param hspace: Vertical spacing between axes.
    """
    if fig is None:
        fig = pyplot.gcf()
    for a in fig.axes:
        if hasattr(a, 'is_last_row') and not a.is_last_row():
            for xt in a.get_xticklabels():
                xt.set_visible(False)
            a.set_xlabel('')
            a.xaxis.get_offset_text().set_visible(False)
        if len(a.get_yticklabels()) and trim_ylabel:
            try:
                a.yaxis.major.locator.set_params(prune='upper')
            except Exception:
                pass
    fig.subplots_adjust(hspace=hspace) 
[docs]@_available_to_user_plot
def uerrorbar(x, y, ax=None, **kwargs):
    r"""
    Given arguments y or x,y where x and/or y have uncertainties, feed the
    appropriate terms to matplotlib's errorbar function.
    If y or x is more than 1D, it is flattened along every dimension but the last.
    :param x: array of independent axis values
    :param y: array of values with uncertainties, for which shaded error band is plotted
    :param ax: The axes instance into which to plot (default: pyplot.gca())
    :param \**kwargs: Passed to ax.errorbar
    :return: list. A list of ErrorbarContainer objects containing the line, bars, and caps of each (x,y) along the last dimension.
    """
    result = []
    # set default key word arguments
    if ax is None:
        ax = pyplot.gca()
    kwargs.setdefault('marker', 'o')
    kwargs.pop('axes', None)  # don't let them use this because we used ax
    if 'linestyle' not in kwargs and 'ls' not in kwargs:
        kwargs['linestyle'] = ''
    if np.all(std_devs(y) == 0) and np.all(std_devs(x) == 0):
        kwargs.setdefault('capsize', 0)
    # enable combinations of 1D and 2D x's and y's
    y = np.array(y)
    y = y.reshape(-1, y.shape[-1])
    x = np.array(x)
    x = x.reshape(-1, x.shape[-1])
    if x.shape[0] == 1 and y.shape[0] > 1:  # one x for all y's
        x = np.tile(x[0, :], y.shape[0]).reshape(-1, x.shape[-1])
    # plot each (x,y) and collect container objects
    for xi, yi in zip(x, y):
        tmp = ax.errorbar(nominal_values(xi), nominal_values(yi), xerr=std_devs(xi), yerr=std_devs(yi), **kwargs)
        result.append(tmp)
    return result 
[docs]class Uband(object):
    """
    This class wraps the line and PollyCollection(s) associated with a banded
    errorbar plot for use in the uband function.
    """
    def __init__(self, line, bands):
        """
        :param line: Line2D
            A line of the x,y nominal values
        :param bands: list of PolyCollections
            The fill_between and/or fill_betweenx PollyCollections spanning the std_devs of the x,y data
        """
        from matplotlib.cbook import flatten
        self.line = line  # matplotlib.lines.Line2D
        self.bands = list(flatten([bands]))  # matplotlib.collections.PolyCollection(s)
    def __getattr__(self, attr):
        if attr in ['set_color', 'set_lw', 'set_linewidth', 'set_dashes', 'set_linestyle']:
            def _band_line_method(method, *args, **kw):
                """
                Call the same method for line and band.
                Returns Line2D method call result.
                """
                for band in self.bands:
                    getattr(band, method)(*args, **kw)
                return getattr(self.line, method)(*args, **kw)
            return lambda *args, **kw: _band_line_method(attr, *args, **kw)
        else:
            return getattr(self.line, attr) 
[docs]@_available_to_user_plot
def uband(x, y, ax=None, fill_kwargs=None, **kwargs):
    r"""
    Given arguments x,y where either or both have uncertainties, plot x,y using pyplot.plot
    of the nominal values and surround it with with a shaded error band using matplotlib's
    fill_between and/or fill_betweenx.
    If y or x is more than 1D, it is flattened along every dimension but the last.
    :param x: array of independent axis values
    :param y: array of values with uncertainties, for which shaded error band is plotted
    :param ax: The axes instance into which to plot (default: pyplot.gca())
    :param fill_kwargs: dict. Passed to pyplot.fill_between
    :param \**kwargs: Passed to pyplot.plot
    :return: list. A list of Uband objects containing the line and bands of each (x,y) along
             the last dimension.
    """
    result = []
    if ax is None:
        ax = pyplot.gca()
    if fill_kwargs is None:
        fill_kwargs = {}
    fill_kwargs.setdefault('alpha', 0.25)
    # enable combinations of 1D and 2D x's and y's
    y = np.array(y)
    y = y.reshape(-1, y.shape[-1])
    x = np.array(x)
    x = x.reshape(-1, x.shape[-1])
    if x.shape[0] == 1 and y.shape[0] > 1:  # one x for all y's
        x = np.tile(x[0, :], y.shape[0]).reshape(-1, x.shape[-1])
    # plot each (x,y) and collect the lines/bands into a single object
    for xi, yi in zip(x, y):
        xnom = np.atleast_1d(np.squeeze(uncertainties.unumpy.nominal_values(xi)))
        xerr = np.atleast_1d(np.squeeze(uncertainties.unumpy.std_devs(xi)))
        ynom = np.atleast_1d(np.squeeze(uncertainties.unumpy.nominal_values(yi)))
        yerr = np.atleast_1d(np.squeeze(uncertainties.unumpy.std_devs(yi)))
        (l,) = ax.plot(xnom, ynom, **kwargs)
        fkwargs = copy.copy(fill_kwargs)  # changes to fill_kwargs propagate to the next call of uband!
        fkwargs.setdefault('color', l.get_color())
        bands = []
        if np.any(yerr != 0):
            bandy = ax.fill_between(xnom, ynom - yerr, ynom + yerr, **fkwargs)
            bands.append(bandy)
        if np.any(xerr != 0):
            bandx = ax.fill_betweenx(ynom, xnom - xerr, xnom + xerr, **fkwargs)
            bands.append(bandx)
        tmp = Uband(l, bands)
        result.append(tmp)
    return result 
[docs]@_available_to_user_plot
def hardcopy(fn, bbox_inches='tight', fig=None, **keyw):
    # todo: docstring
    if fig is None:
        fig = pyplot.gcf()
    tmp = {}
    tmp['ps.usedistiller'] = matplotlib.rcParams['ps.usedistiller']
    tmp['text.usetex'] = matplotlib.rcParams['text.usetex']
    matplotlib.rcParams['ps.usedistiller'] = 'ghostscript'
    matplotlib.rcParams['text.usetex'] = True
    os.system('rm %s/.config/matplotlib/tex.cache/*' % os.environ['HOME'])
    try:
        fig.savefig(fn, bbox_inches=bbox_inches, **keyw)
        printi('Saved to %s' % (fn))
    except Exception:
        printe('Error in saving file')
    matplotlib.rcParams['ps.usedistiller'] = tmp['ps.usedistiller']
    matplotlib.rcParams['text.usetex'] = tmp['text.usetex']
    return fn 
[docs]@_available_to_user_plot
def set_fontsize(fig=None, fontsize='+0'):
    """
    For each text object of a figure fig, set the font size to fontsize
    :param fig: matplotlib.figure object
    :param fontsize: can be an absolute number (e.g 10) or a relative number (-2 or +2)
    :return: None
    """
    def match(artist):
        return artist.__module__ == "matplotlib.text"
    if is_int(fig) or isinstance(fig, str):
        fontsize = fig
        fig = pyplot.gcf()
    if fig is None:
        fig = pyplot.gcf()
    fontS = fontsize
    for textobj in fig.findobj(match=match):
        if '+' in fontsize or '-' in fontsize:
            fontS = textobj.get_fontsize() + eval(fontsize)
        textobj.set_fontsize(fontS) 
[docs]def user_lines_cmap_cycle():
    """
    return colormap chosen by the user for representation of lines
    """
    mapper = {'blind': 'viridis'}
    try:
        tmp = OMFIT['MainSettings']['SETUP']['PlotAppearance']['lines.cmap'][0]
        return mapper.get(tmp, tmp)
    except KeyError:
        pass
    except Exception as _excp:
        printe(_excp)
    return default_matplotlib_cmap_cycle 
[docs]def user_image_cmap_cycle():
    """
    return colormap chosen by the user for representation of images
    """
    try:
        return OMFIT['MainSettings']['SETUP']['PlotAppearance']['image.cmap']
    except KeyError:
        pass
    except Exception as _excp:
        printe(_excp)
    return default_matplotlib_cmap_cycle 
[docs]@_available_to_user_plot
def color_cycle(n=10, k=None, cmap_name=None):
    """
    Utility function to conveniently return the color of an index in a colormap cycle
    :param n: number of uniformly spaced colors, or array defining the colors' spacings
    :param k: index of the color (if None an array of colors of length n will be returned)
    :param cmap_name: name of the colormap
    :return: color of index k from colormap cmap_name made of n colors, or array of colors of length n if k is None
             Note: if n is an array, then the associated ScalarMappable object is also returned (e.g. for use in a colorbar)
    """
    if cmap_name is None:
        cmap_name = user_lines_cmap_cycle()
    if np.iterable(n):
        nm = matplotlib.colors.Normalize(np.nanmin(n), np.nanmax(n))
        sm = matplotlib.cm.ScalarMappable(cmap=cmap_name, norm=nm)
        sm.set_array(n)
        colors = sm.cmap(sm.norm(n))
        if k is not None:
            return colors[k], sm
        else:
            return colors, sm
    else:
        cmap = getattr(cm, cmap_name)
        if k is not None:
            return cmap(k / max([1.0, float(n - 1)]))
        else:
            tmp = []
            for k in range(n):
                tmp.append(cmap(k / max([1.0, float(n - 1)])))
            return tmp 
[docs]@_available_to_user_plot
def cycle_cmap(length=50, cmap=None, start=None, stop=None, ax=None):
    """
    Set default color cycle of matplotlib based on colormap
    Note that the default color cycle is not changed if ax parameter is set; only the axes's color cycle will be changed
    :param length: The number of colors in the cycle
    :param cmap: Name of a matplotlib colormap
    :param start: Limit colormap to this range (0 < start < stop 1)
    :param stop: Limit colormap to this range (0 < start < stop 1)
    :param ax: If ax is not None, then change the axes's color cycle instead of the default color cycle
    :return: color_cycle
    """
    if cmap is None:
        cmap = user_lines_cmap_cycle()
    cmap = getattr(pyplot.cm, cmap)
    crange = [0, 1]
    if start is not None:
        crange[0] = start
    if stop is not None:
        crange[1] = stop
    assert 0 <= crange[0] <= 1
    assert 0 <= crange[1] <= 1
    color_cycle = [RGB_to_HEX(*rgb[:3]) for rgb in cmap(np.linspace(crange[0], crange[1], num=length))]
    if ax is None:
        from cycler import cycler
        pyplot.rc('axes', prop_cycle=cycler('color', color_cycle))
    else:
        ax.set_prop_cycle(color_cycle)
    return color_cycle 
[docs]@_available_to_user_plot
def contrasting_color(line_or_color):
    """
    Given a matplotlib color specification or a line2D instance or a list with a line2D instance as the first element,
    pick and return a color that will contrast well. More complicated than just inversion as inverting blue gives
    yellow, which doesn't display well on a white background.
    :param line_or_color: matplotlib color spec, line2D instance, or list w/ line2D instance as the first element
    :return: 4 element array
        RGBA color specification for a contrasting color
    """
    # Get the RGBA values for the color to transform
    if isinstance(line_or_color, matplotlib.lines.Line2D):
        color = line_or_color.get_color()
    elif np.iterable(line_or_color) and isinstance(line_or_color[0], matplotlib.lines.Line2D):
        color = line_or_color[0].get_color()
    else:
        color = line_or_color
    color = np.array(matplotlib.colors.to_rgba(color))
    # Split RGB from A and transform to HSV
    alpha = color[3]
    color0 = color[0:3]
    color = matplotlib.colors.rgb_to_hsv(color0)
    # Make sure black and white change and that colors are fairly bright
    value_lims = [0.5, 0.9]
    sat_lims = [0.75, 0.95]
    color[1] = min([max([color[1], value_lims[0]]), value_lims[1]])
    color[2] = min([max([color[2], sat_lims[0]]), sat_lims[1]])
    # Rotate the hue. 1/3 seems to look better than 1/2.
    color[0] = (color[0] + 0.333) % 1
    # Transform back to RGB, put the A back in, and return
    return np.append(matplotlib.colors.hsv_to_rgb(color), alpha) 
[docs]@_available_to_user_plot
def associated_color(line_or_color):
    """
    Given a matplotlib color specification or a line2D instance or a list with a line2D instance as the first element,
    pick and return a color that will look thematically linked to the first color, but still distinguishable.
    :param line_or_color: matplotlib color spec, line2D instance, or list w/ line2D instance as the first element
    :return: 4 element array
        RGBA color specification for a related, similar (but distinguishable) color
    """
    # Get the RGBA values for the color to transform
    if isinstance(line_or_color, matplotlib.lines.Line2D):
        color = line_or_color.get_color()
    elif np.iterable(line_or_color) and isinstance(line_or_color[0], matplotlib.lines.Line2D):
        color = line_or_color[0].get_color()
    else:
        color = line_or_color
    color = np.array(matplotlib.colors.to_rgba(color))
    # Split RGB from A and transform to HSV
    alpha = color[3]
    color0 = color[0:3]
    color = matplotlib.colors.rgb_to_hsv(color0)
    # Change saturation
    if color[1] <= 0.4:
        color[1] *= 2
    else:
        color[1] *= 0.75
    # Change brightness
    if color[2] <= 0.25:
        color[2] = 0.5 + color[2]
    elif color[2] <= 0.4:
        color[2] *= 1.8
    else:
        color[2] *= 0.45
    # Shift the hue slightly.
    color[0] = (color[0] - 0.1) % 1
    # Transform back to RGB, put the A back in, and return
    return np.append(matplotlib.colors.hsv_to_rgb(color), alpha) 
[docs]@_available_to_user_plot
def blur_image(im, n, ny=None):
    """blurs the image by convolving with a gaussian kernel of typical
    size n. The optional keyword argument ny allows for a different
    size in the y direction.
    """
    def gauss_kern(size, sizey=None):
        """Returns a normalized 2D gauss kernel array for convolutions"""
        size = int(size)
        if not sizey:
            sizey = size
        else:
            sizey = int(sizey)
        x, y = np.mgrid[-size : size + 1, -sizey : sizey + 1]
        g = np.exp(-(x**2 / float(size) + y**2 / float(sizey)))
        return g / g.sum()
    from scipy import signal
    g = gauss_kern(n, sizey=ny)
    improc = signal.convolve(im, g, mode='same')
    return improc 
[docs]@_available_to_user_plot
def pcolor2(*args, fast=False, **kwargs):
    r"""
    Plots 2D data as a patch collection.
    Differently from matplotlib.pyplot.pcolor the mesh is extended by one element so that the number of tiles equals
    the number of data points in the Z matrix.
    The X,Y grid does not have to be rectangular.
    :param \*args: Z or X,Y,Z data to be plotted
    :param fast: bool
        Use pcolorfast instead of pcolor. Speed improvements may be dramatic.
        However, pcolorfast is marked as experimental and may produce unexpected behavior.
    :param \**kwargs: these arguments are passed to matplotlib.pyplot.pclor
    :return: None
    """
    ax = kwargs.pop('ax', None)
    if ax is None:
        ax = pylab.gca()
    cmap = kwargs.pop('cmap', None)
    if cmap is None:
        cmap = cm.get_cmap()
    cmap.set_bad(color='w', alpha=1.0)
    kwargs['cmap'] = cmap
    naughty_list = ['interpolation']
    kw2 = {k: v for k, v in kwargs.items() if k not in naughty_list}
    if len(args) == 1:
        if fast:
            obj = ax.pcolorfast(np.ma.masked_invalid(args[0]), **kw2)
        else:
            obj = ax.pcolor(np.ma.masked_invalid(args[0]), **kw2)
    else:
        if len(args[0].shape) == 1:
            xy1d = True
            xdata, ydata = np.meshgrid(args[0], args[1])
        else:
            xy1d = False
            xdata = args[0]
            ydata = args[1]
        masked = np.ma.masked_invalid(args[2])
        if fast:
            if xy1d:
                # pcolorfast doesn't print the z coordinate if x,y are 2D, so try to reduce to 1d if possible
                xdata = xdata[0, :]
                ydata = ydata[:, 0]
                inner_edges = (xdata[:-1] + xdata[1:]) / 2.0
                xdata = np.append(np.append(inner_edges[0] - xdata[1] + xdata[0], inner_edges), inner_edges[-1] + xdata[-1] - xdata[-2])
                inner_edges = (ydata[:-1] + ydata[1:]) / 2.0
                ydata = np.append(np.append(inner_edges[0] - ydata[1] + ydata[0], inner_edges), inner_edges[-1] + ydata[-1] - ydata[-2])
                # pcolorfast says it's even faster if given a tuple of (min, max) for x and y, for even spacing
                if (np.max(abs(np.diff(np.diff(xdata)))) == 0) and (np.max(abs(np.diff(np.diff(ydata)))) == 0):
                    xdata = (np.min(xdata), np.max(xdata))
                    ydata = (np.min(ydata), np.max(ydata))
                obj = ax.pcolorfast(xdata, ydata, masked, **kw2)
            else:
                xdata, ydata = meshgrid_expand(xdata, ydata)
                obj = ax.pcolorfast(xdata, ydata, masked, **kw2)
        else:
            xdata, ydata = meshgrid_expand(xdata, ydata)
            obj = ax.pcolor(xdata, ydata, masked, **kw2)
    return obj 
[docs]@_available_to_user_plot
def image(*args, **kwargs):
    r"""
    Plots 2D data as an image.
    Much faster than pcolor/pcolor2(fast=False), but the data have to be on a rectangular X,Y grid
    :param \*args: Z or X,Y,Z data to be plotted
    :param \**kwargs: these arguments are passed to pcolorfast
    """
    from omfit_classes.omfit_dmp import mpl_dump, mpl_dump_enable
    mpl_dump_enable[0] = False
    try:
        # todo: add clipping
        ax = kwargs.pop('ax', None)
        if ax is None:
            ax = pylab.gca()
        cmap = kwargs.pop('cmap', None)
        if cmap is None:
            cmap = cm.get_cmap()
        cmap.set_bad(color='w', alpha=1.0)
        kwargs['cmap'] = cmap
        zlim = kwargs.pop('zlim', None)
        if len(args) == 1:
            zdata = args[0]
            xdata = np.linspace(0, zdata.shape[1] - 1, zdata.shape[1])
            ydata = np.linspace(0, zdata.shape[0] - 1, zdata.shape[0])
        elif len(args) >= 3:
            xdata = args[0]
            ydata = args[1]
            zdata = args[2]
            if xdata is None:
                xdata = np.linspace(0, zdata.shape[1] - 1, zdata.shape[1])
            if ydata is None:
                ydata = np.linspace(0, zdata.shape[0] - 1, zdata.shape[0])
        zdata = np.ma.masked_invalid(zdata)
        # accounts for data where x,ymax is at min
        if xdata[0] > xdata[-1]:
            xdata = xdata[::-1]
            zdata = zdata[:, ::-1]
        if ydata[0] > ydata[-1]:
            ydata = ydata[::-1]
            zdata = zdata[::-1, :]
        extent = [np.nanmin(xdata), np.nanmax(xdata), np.nanmin(ydata), np.nanmax(ydata)]
        ax.plot([extent[0], extent[1]], [extent[2], extent[3]])
        del ax.lines[-1]
        kwargs.setdefault('fast', True)
        obj = pcolor2(xdata, ydata, zdata, ax=ax, **kwargs)
        if zlim is not None:
            norm = matplotlib.colors.Normalize(vmin=zlim[0], vmax=zlim[1])
            obj.set_norm(norm)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
    finally:
        mpl_dump_enable[0] = True
    dumpDict = {'args': [ax], 'func': 'image'}
    args = list(args)
    # for DMP purposes reduce color scale to 16bits
    if len(args) == 1 and args[0].dtype.name.startswith('float'):
        args[0] = args[0].astype(np.float16)
    if len(args) == 3 and args[2].dtype.name.startswith('float'):
        args[2] = args[2].astype(np.float16)
    dumpDict['args'].extend(args)
    kwargs.pop('ax', None)
    dumpDict['kwargs'] = kwargs
    mpl_dump(dumpDict)
    return obj 
[docs]@_available_to_user_plot
def meshgrid_expand(xdata, ydata):
    """
    returns the veritices of the mesh, if the xdata and ydata were the centers of the mesh
    xdata and ydata are 2D matrices, which could for example be generated by np.meshgrid
    :param xdata: center of the mesh
    :param ydata: center of the mesh
    :return:
    """
    dxdx, dxdy = np.gradient(xdata)
    dydx, dydy = np.gradient(ydata)
    xdata2 = np.vstack((xdata, xdata[-1, :] + dxdx[-1, :]))
    dxdy2 = np.vstack((dxdy, dxdy[-1, :]))
    xdata3 = np.hstack((xdata2, xdata2[:, -1:] + dxdy2[:, -1:]))
    ydata2 = np.hstack((ydata, ydata[:, -1:] + dydy[:, -1:]))
    dydx2 = np.hstack((dydx, dydx[:, -1:]))
    ydata3 = np.vstack((ydata2, ydata2[-1, :] + dydx2[-1, :]))
    dxdx3, dxdy3 = np.gradient(xdata3)
    dydx3, dydy3 = np.gradient(ydata3)
    dx3 = dxdx3 + dxdy3
    dy3 = dydx3 + dydy3
    xdata4 = xdata3 - dx3 / 2.0
    ydata4 = ydata3 - dy3 / 2.0
    return xdata4, ydata4 
[docs]@_available_to_user_plot
def map_HBS_to_RGB(H, B, S=1.0, cmap=None):
    """
    map to a RGB colormap separate HUE, BRIGHTNESS and SATURATIONS arrays
    :param H: HUE data (any shape array)
    :param B: BRIGHTNESS data (any shape array)
    :param S: SATURATION data (any shape array)
    :param cmap: matplotlib.colormap to be used
    :return: RGB array (shape of input array with one more dimension of size 3 (RGB) )
    """
    if cmap is None:
        cmap = user_image_cmap_cycle()
    if isinstance(cmap, str):
        cmap = cm.get_cmap(cmap, 256)
    x = np.linspace(0, 1, cmap.N)
    shapes = list(H.shape)
    w = np.zeros((H.size, 3))
    for n in range(H.size):
        v = H.flatten()[n]
        t = cmap(np.nanargmin(abs(x - v)))
        v = colorsys.rgb_to_hls(t[0], t[1], t[2])
        if is_float(S * 1.0):
            w[n] = colorsys.hls_to_rgb(v[0], B.flatten()[n], S)
        else:
            w[n] = colorsys.hls_to_rgb(v[0], B.flatten()[n], S.flatten()[n])
    shapes.append(3)
    return w.reshape(shapes) 
[docs]@_available_to_user_plot
def RGB_to_HEX(R, G, B):
    """
    Convert color from numerical RGB to hexadecimal representation
    :param R: integer 0<x<255 or float 0.0<x<1.0
    :param G: integer 0<x<255 or float 0.0<x<1.0
    :param B: integer 0<x<255 or float 0.0<x<1.0
    :return: hexadecimal representation of the color
    """
    if not isinstance(R, int):
        R = int(round(255 * R))
    if not isinstance(G, int):
        G = int(round(255 * G))
    if not isinstance(B, int):
        B = int(round(255 * B))
    return '#%02x%02x%02x' % (R, G, B) 
[docs]@_available_to_user_plot
def plotc(*args, **kwargs):
    r"""
    Plot the various curves defined by the arguments [X],Y,[Z]
    where X is the x value, Y is the y value, and Z is the color.
    If one argument is given it is interpreted as Y; if two, then X, Y; if three
    then X, Y, Z.  If all three are given, then it is passed to plotc and the
    labels are discarded.
    If Z is omitted, a rainbow of colors is used, with blue for the first curve
    and red for the last curve.  A different color map can be given with the
    cmap keyword (see http://wiki.scipy.org/Cookbook/Matplotlib/Show_colormaps
    for other options). If X is omitted, then the (ith) index of Y is used as the x value.
    :param \*args:
    :param \**kwargs:
    :return:
    """
    ax = kwargs.pop('ax', None)
    if ax is None:
        ax = pylab.gca()
    args = list(args)
    X = args.pop(0)
    Y = None
    Z = None
    if len(args) >= 1 and not isinstance(args[0], str):
        Y = args.pop(0)
        if len(args) >= 1 and not isinstance(args[0], str):
            Z = args.pop(0)
    # from list to matrix, this is done to work nicely with across function of SortedDict
    def list2mat(inv):
        outv = inv
        if isinstance(inv, list):
            outv = np.array(inv).T
        return outv
    X = list2mat(X)
    Y = list2mat(Y)
    Z = list2mat(Z)
    if Y is None and Z is None:
        Y = X
        X = np.linspace(0, Y.shape[0] - 1, Y.shape[0])
    if len(X.shape) == len(Y.shape) == 1 and Z is None:
        return ax.plot(X, Y, *args, **kwargs)
    Y = np.reshape(Y, (Y.shape[0], -1))
    if len(X.shape) == 1:
        X = np.tile(X.flatten(), (Y.shape[1], 1)).T
    try:
        tmp_prop_cycle = matplotlib.rcParams['axes.prop_cycle']
        nlines = Y.shape[1]
        pyplot.ioff()
        LC = []
        if Z is not None:
            vmax = kwargs.pop('vmax', None)
            vmin = kwargs.pop('vmin', None)
            kwargs.setdefault('norm', [vmin, vmax])
            if len(Z.shape) == 1:
                Z = np.tile(Z.flatten(), (Y.shape[1], 1)).T
            # lets plot it normally, since this sets the axes and other stuff that I do not know about
            ax.plot(X.flatten(), Y.flatten(), '.')
            del ax.lines[-1]
            def _plotc(x, y, z, *args, **kwargs):
                points = np.array([x, y]).T.reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                if kwargs['norm'][0] is None:
                    kwargs['norm'] = [np.nanmin(Z), kwargs['norm'][1]]
                if kwargs['norm'][1] is None:
                    kwargs['norm'] = [kwargs['norm'][0], np.nanmax(Z)]
                if not isinstance(kwargs['norm'], pyplot.Normalize):
                    kwargs['norm'] = pyplot.Normalize(kwargs['norm'][0], kwargs['norm'][1])
                lc = matplotlib.collections.LineCollection(segments, *args, **kwargs)
                lc.set_array(z)
                ax.add_collection(lc)
                return lc
            for k in range(nlines):
                x = X[:, k]
                y = Y[:, k]
                z = Z[:, k]
                LC.append(_plotc(x, y, z, *args, **kwargs))
        else:
            labels = kwargs.pop('labels', ['_line' + str(k) for k in range(1, nlines + 1)])
            cmap = getattr(cm, kwargs.pop('cmap', default_matplotlib_cmap_cycle))
            for k in range(nlines):
                x = X[:, k]
                y = Y[~np.isnan(x), k]
                x = x[~np.isnan(x)]
                kwargs['label'] = labels[k]
                kwargs['color'] = cmap(k / max([1, float(nlines - 1)]))
                ax.plot(x, y, *args, **kwargs)
                LC.append(ax.lines[-1])
    finally:
        matplotlib.rcParams['axes.prop_cycle'] = tmp_prop_cycle
        pyplot.draw()
        pyplot.ion()
    return LC 
[docs]@_available_to_user_plot
def title_inside(string, x=0.5, y=0.9, ax=None, **kwargs):
    r"""
    Write the title of a figure inside the figure axis rather than outside
    :param string: title string
    :param x: x location of the title string (default 0.5, that is centered)
    :param y: y location of the title string (default 0.875)
    :param ax: axes to operate on
    :param \**kwargs: additional keywords passed to pyplot.title
    :return: pyplot.title text object
    """
    kwargs.setdefault('va', 'top')
    if ax is None:
        return pyplot.title(string, x=x, y=y, **kwargs)
    else:
        return ax.set_title(string, x=x, y=y, **kwargs) 
[docs]@_available_to_user_plot
def increase_resolution(*args, **kwargs):
    """
    This function takes 1 (Z) or 3 (X,Y,Z) 2D tables and interpolates them to higher resolution by bivariate spline interpolation.
    If 1 (3) table(s) is(are) provided, then the second(fourth) argument is the resolution increase,
    which can be a positive or negative integer integer: res=res0*2^n
    or a float which sets the grid size in the units provided by the X and Y tables
    """
    oneInput = False
    args = list(args)
    if len(args) in [1, 2]:
        args.insert(1, np.arange(args[0].shape[1]))
        args.insert(2, np.arange(args[0].shape[0]))
        oneInput = True
    Qin = args[0]
    Rin = args[1]
    Zin = args[2]
    if len(Rin.shape) == 2:
        Rin = Rin[0, :]
    if len(Zin.shape) == 2:
        Zin = Zin[:, 0]
    if len(args) == 4:
        kwargs['resolution'] = args[3]
    resolution = kwargs['resolution']
    quiet = kwargs.get('quiet', True)
    if resolution == 0:
        Q = Qin
        R = Rin
        Z = Zin
    elif is_int(resolution):
        if resolution > 0:
            if not quiet:
                printi('Increasing tables resolution by factor of ' + str(abs(resolution) + 1) + ' ...')
            nr = len(Rin)
            nz = len(Zin)
            for k in range(resolution):
                nr = nr + nr - 1
                nz = nz + nz - 1
            R = np.linspace(min(Rin), max(Rin), nr)
            Z = np.linspace(min(Zin), max(Zin), nz)
            Q = RectBivariateSplineNaN(Zin, Rin, Qin)(Z, R)
        elif resolution < 0:
            if not quiet:
                printi('Decreasing tables resolution by factor of ' + str(abs(resolution) + 1) + ' ...')
            R = Rin[::-resolution]
            Z = Zin[::-resolution]
            Q = Qin[::-resolution, ::-resolution]
    elif is_float(resolution):
        if not quiet:
            printi('Interpolating tables to ' + str(resolution) + ' m resolution ...')
        R = np.linspace(min(Rin), max(Rin), int(np.ceil((max(Rin) - min(Rin)) / resolution)))
        Z = np.linspace(min(Zin), max(Zin), int(np.ceil((max(Zin) - min(Zin)) / resolution)))
        Q = RectBivariateSplineNaN(Zin, Rin, Qin)(Z, R)
    if oneInput:
        return Q
    else:
        return Q, R, Z 
[docs]@_available_to_user_plot
class infoScatter(object):
    r"""
    improved version of: http://wiki.scipy.org/Cookbook/Matplotlib/Interactive_Plotting
    Callback for matplotlib to display an annotation when points are clicked on
    :param x: x of the annotations
    :param y: y of the annotations
    :param annotes: list of string annotations
    :param axis: axis on which to operate on (default to current axis)
    :param tol: vicinity in pixels where to look for annotations
    :param func: function to call with signature: func(x,y,annote,visible,axis)
    :param all_on: Make all of the text visible to begin with
    :param suppress_canvas_draw: Do not actively draw the canvas if all_on is True, makes plotting faster there are many subplots
    :param \**kw: extra keywords passed to matplotlib text class
    """
    def __init__(self, x, y, annotes, axis=None, tol=5, func=None, all_on=False, suppress_canvas_draw=False, **kw):
        self.data = list(zip(x, y, annotes))
        self.tol = tol
        if axis is None:
            self.axis = pylab.gca()
        else:
            self.axis = axis
        self.drawnAnnotations = {}
        self.func = func
        self.kw = kw
        self.kw.setdefault('clip_on', True)
        if all_on:
            for cur_x, cur_y, cur_annote in zip(x, y, annotes):
                self.drawAnnote(self.axis, cur_x, cur_y, cur_annote, redraw_canvas=False)
            if not suppress_canvas_draw:
                self.axis.figure.canvas.draw()
        self.axis.figure.canvas.mpl_connect('button_press_event', self)
    def __call__(self, event):
        if event.inaxes:
            clickX = event.x
            clickY = event.y
            if self.axis is None or self.axis == event.inaxes:
                winner = [None, None]
                for k, (xd, yd, a) in enumerate(self.data):
                    x, y = self.axis.transData.transform((xd, yd))
                    d = np.sqrt((x - clickX) ** 2 + (y - clickY) ** 2)
                    if d < self.tol:
                        if winner[0] is None or d < winner[0]:
                            winner = [d, k]
                if winner[0] is not None:
                    self.drawAnnote(event.inaxes, self.data[winner[1]][0], self.data[winner[1]][1], self.data[winner[1]][2])
[docs]    def drawAnnote(self, axis, x, y, annote, redraw_canvas=True):
        """
        Draw the annotation on the plot
        """
        if (x, y) in self.drawnAnnotations:
            markers = self.drawnAnnotations[(x, y)]
            for m in markers:
                m.set_visible(not m.get_visible())
            if redraw_canvas:
                self.axis.figure.canvas.draw()
        else:
            printi(annote)
            t = axis.text(x, y, annote, **self.kw)
            m = axis.scatter([x], [y], marker='+', c='r', zorder=100)
            self.drawnAnnotations[(x, y)] = (t, m)
            if redraw_canvas:
                self.axis.figure.canvas.draw()
        if self.func is not None:
            self.func(x, y, annote, m.get_visible(), axis) 
[docs]    def drawSpecificAnnote(self, annote):
        annotesToDraw = [(x, y, a) for x, y, a in self.data if a == annote]
        for x, y, a in annotesToDraw:
            self.drawAnnote(self.axis, x, y, a)  
[docs]@_available_to_user_plot
def infoPoint(fig=None):
    """
    print x,y coordinates where the user clicks
    :param fig: matplotlib figure
    """
    from matplotlib import pyplot
    if fig is None:
        fig = pyplot.gcf()
    def onclick(event):
        print(event.xdata, event.ydata)
    fig.canvas.mpl_connect('button_press_event', onclick) 
[docs]@_available_to_user_plot
def XKCDify(
    ax,
    mag=1.0,
    f1=50,
    f2=0.01,
    f3=15,
    bgcolor='w',
    xaxis_loc=None,
    yaxis_loc=None,
    xaxis_arrow='+',
    yaxis_arrow='+',
    ax_extend=0.1,
    expand_axes=False,
    ylabel_rot=78,
):
    """
    XKCD plot generator by, Jake Vanderplas; Modified by Sterling Smith
    This is a script that will take any matplotlib line diagram, and convert it
    to an XKCD-style plot.  It will work for plots with line & text elements,
    including axes labels and titles (but not axes tick labels).
    The idea for this comes from work by Damon McDougall
      http://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg25499.html
    This adjusts all lines, text, legends, and axes in the figure to look
    like xkcd plots.  Other plot elements are not modified.
    :param ax: Axes instance
        the axes to be modified.
    :param mag: float
        the magnitude of the distortion
    :param f1, f2, f3: int, float, int
        filtering parameters.  f1 gives the size of the window, f2 gives
        the high-frequency cutoff, f3 gives the size of the filter
    :param xaxis_loc, yaxis_log: float
        The locations to draw the x and y axes.  If not specified, they
        will be drawn from the bottom left of the plot
    :param xaxis_arrow: str
        where to draw arrows on the x axes.  Options are '+', '-', '+-', or ''
    :param yaxis_arrow: str
        where to draw arrows on the y axes.  Options are '+', '-', '+-', or ''
    :param ax_extend: float
        How far (fractionally) to extend the drawn axes beyond the original
        axes limits
    :param expand_axes: bool
        if True, then expand axes to fill the figure (useful if there is only
        a single axes in the figure)
    :param ylabel_rot: float
        number of degrees to rotate the y axis label
    """
    import pylab as pl
    import matplotlib.font_manager as fm
    fontfn = os.sep.join([os.path.dirname(__file__), '..', 'extras', 'graphics', 'fonts', 'Humor-Sans.ttf'])
    def xkcd_line(x, y, xlim=None, ylim=None, mag=1.0, f1=30, f2=0.05, f3=15):
        """
        Mimic a hand-drawn line from (x, y) data
        Parameters
        ----------
        x, y : array_like
            arrays to be modified
        xlim, ylim : data range
            the assumed plot range for the modification.  If not specified,
            they will be guessed from the  data
        mag : float
            magnitude of distortions
        f1, f2, f3 : int, float, int
            filtering parameters.  f1 gives the size of the window, f2 gives
            the high-frequency cutoff, f3 gives the size of the filter
        Returns
        -------
        x, y : np.ndarrays
            The modified lines
        """
        from scipy import interpolate, signal
        x = np.asarray(x)
        y = np.asarray(y)
        # get limits for rescaling
        if xlim is None:
            xlim = (x.min(), x.max())
        if ylim is None:
            ylim = (y.min(), y.max())
        if xlim[1] == xlim[0]:
            xlim = ylim
        if ylim[1] == ylim[0]:
            ylim = xlim
        # scale the data
        x_scaled = (x - xlim[0]) * 1.0 / (xlim[1] - xlim[0])
        y_scaled = (y - ylim[0]) * 1.0 / (ylim[1] - ylim[0])
        # compute the total distance along the path
        dx = x_scaled[1:] - x_scaled[:-1]
        dy = y_scaled[1:] - y_scaled[:-1]
        dist_tot = np.sum(np.sqrt(dx * dx + dy * dy))
        # number of interpolated points is proportional to the distance
        Nu = int(200 * dist_tot)
        u = np.arange(-1, Nu + 1) * 1.0 / (Nu - 1)
        # interpolate curve at sampled points
        k = min(3, len(x) - 1)
        res = interpolate.splprep([x_scaled, y_scaled], s=0, k=k)
        x_int, y_int = interpolate.splev(u, res[0])
        # we'll perturb perpendicular to the drawn line
        dx = x_int[2:] - x_int[:-2]
        dy = y_int[2:] - y_int[:-2]
        dist = np.sqrt(dx * dx + dy * dy)
        # create a filtered perturbation
        coeffs = mag * np.random.normal(0, 0.01, len(x_int) - 2)
        b = signal.firwin(f1, f2 * dist_tot, window=('kaiser', f3))
        response = signal.lfilter(b, 1, coeffs)
        x_int[1:-1] += response * dy / dist
        y_int[1:-1] += response * dx / dist
        # un-scale data
        x_int = x_int[1:-1] * (xlim[1] - xlim[0]) + xlim[0]
        y_int = y_int[1:-1] * (ylim[1] - ylim[0]) + ylim[0]
        return x_int, y_int
    # Get axes aspect
    ext = ax.get_window_extent().extents
    aspect = (ext[3] - ext[1]) / (ext[2] - ext[0])
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    xspan = xlim[1] - xlim[0]
    yspan = ylim[1] - xlim[0]
    xax_lim = (xlim[0] - ax_extend * xspan, xlim[1] + ax_extend * xspan)
    yax_lim = (ylim[0] - ax_extend * yspan, ylim[1] + ax_extend * yspan)
    if xaxis_loc is None:
        xaxis_loc = ylim[0]
    if yaxis_loc is None:
        yaxis_loc = xlim[0]
    # Draw axes
    xaxis = pl.Line2D([xax_lim[0], xax_lim[1]], [xaxis_loc, xaxis_loc], linestyle='-', color='k')
    yaxis = pl.Line2D([yaxis_loc, yaxis_loc], [yax_lim[0], yax_lim[1]], linestyle='-', color='k')
    # ttk.Label axes3, 0.5, 'hello', fontsize=14)
    ax.text(xax_lim[1], xaxis_loc - 0.02 * yspan, ax.get_xlabel(), fontsize=14, ha='right', va='top', rotation=12)
    ax.text(yaxis_loc - 0.02 * xspan, yax_lim[1], ax.get_ylabel(), fontsize=14, ha='right', va='top', rotation=ylabel_rot)
    ax.set_xlabel('')
    ax.set_ylabel('')
    # Add title
    ax.text(0.5 * (xax_lim[1] + xax_lim[0]), yax_lim[1], ax.get_title(), ha='center', va='bottom', fontsize=16)
    ax.set_title('')
    Nlines = len(ax.lines)
    lines = [xaxis, yaxis] + [ax.lines.pop(0) for i in range(Nlines)]
    for line in lines:
        x, y = line.get_data()
        x_int, y_int = xkcd_line(x, y, xlim, ylim, mag, f1, f2, f3)
        # create foreground and background line
        lw = line.get_linewidth()
        line.set_linewidth(2 * lw)
        line.set_data(x_int, y_int)
        # don't add background line for axes
        if (line is not xaxis) and (line is not yaxis):
            line_bg = pl.Line2D(x_int, y_int, color=bgcolor, linewidth=8 * lw)
            ax.add_line(line_bg)
        ax.add_line(line)
    # Draw arrow-heads at the end of axes lines
    arr1 = 0.03 * np.array([-1, 0, -1])
    arr2 = 0.02 * np.array([-1, 0, 1])
    arr1[::2] += np.random.normal(0, 0.005, 2)
    arr2[::2] += np.random.normal(0, 0.005, 2)
    x, y = xaxis.get_data()
    if '+' in str(xaxis_arrow):
        ax.plot(x[-1] + arr1 * xspan * aspect, y[-1] + arr2 * yspan, color='k', lw=2)
    if '-' in str(xaxis_arrow):
        ax.plot(x[0] - arr1 * xspan * aspect, y[0] - arr2 * yspan, color='k', lw=2)
    x, y = yaxis.get_data()
    if '+' in str(yaxis_arrow):
        ax.plot(x[-1] + arr2 * xspan * aspect, y[-1] + arr1 * yspan, color='k', lw=2)
    if '-' in str(yaxis_arrow):
        ax.plot(x[0] - arr2 * xspan * aspect, y[0] - arr1 * yspan, color='k', lw=2)
    # Change all the fonts to humor-sans.
    prop = fm.FontProperties(fname=fontfn, size=16)
    for text in ax.texts:
        text.set_fontproperties(prop)
    # modify legend
    leg = ax.get_legend()
    if leg is not None:
        leg.set_frame_on(False)
        for child in leg.get_children():
            if isinstance(child, pl.Line2D):
                x, y = child.get_data()
                child.set_data(xkcd_line(x, y, mag=10, f1=100, f2=0.001))
                child.set_linewidth(2 * child.get_linewidth())
            if isinstance(child, pl.Text):
                child.set_fontproperties(prop)
    # Set the axis limits
    ax.set_xlim(xax_lim[0] - 0.1 * xspan, xax_lim[1] + 0.1 * xspan)
    ax.set_ylim(yax_lim[0] - 0.1 * yspan, yax_lim[1] + 0.1 * yspan)
    # adjust the axes
    ax.set_xticks([])
    ax.set_yticks([])
    if expand_axes:
        ax.figure.set_facecolor(bgcolor)
        ax.set_axis_off()
        ax.set_position([0, 0, 1, 1])
    return ax 
[docs]@_available_to_user_plot
def autoscale_y(ax, margin=0.1):
    """
    Rescales the y-axis based on the data that is visible given the current xlim of the axis.
    Created by eldond at 2017 Mar 23  20:26
    This function was taken from an answer by DanKickstein on stackoverflow.com
    http://stackoverflow.com/questions/29461608/matplotlib-fixing-x-axis-scale-and-autoscale-y-axis
    http://stackoverflow.com/users/1078391/danhickstein
    I don't think this function considers shaded bands such as would be used to display error bars. Increasing the
    margin may be a good idea when dealing with such plots.
    :param ax: a matplotlib axes object
    :param margin: The fraction of the total height of the y-data to pad the upper and lower ylims
    """
    printd('autoscale_y()...')
    def get_bottom_top(line2):
        xd = line2.get_xdata()
        yd = line2.get_ydata()
        lo, hi = ax.get_xlim()
        with warnings.catch_warnings(record=False) as w:
            # Ignore RuntimeWarnings for these calculations (could be caused by NaNs in xd or yd)
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            if (len(xd) == 2) and (xd[0] == 0.0) and (xd[1] == 1.0):
                # Special case to handle axhline - added by D. Eldon
                y_displayed = yd
            else:
                y_displayed = np.array(yd)[((xd >= lo) & (xd <= hi))]
        if len(y_displayed) > 1:
            h = np.nanmax(y_displayed) - np.nanmin(y_displayed)
            bot2 = np.nanmin(y_displayed) - margin * h
            top2 = np.nanmax(y_displayed) + margin * h
            return bot2, top2
        else:
            printd('Could not find any elements of y_displayed in range')
            return np.inf, -np.inf
    lines = ax.get_lines()
    bot, top = np.inf, -np.inf
    for j, line in enumerate(lines):
        printd(' line', j)
        new_bot, new_top = get_bottom_top(line)
        printd('   new_bot = {:}, new_top = {:}'.format(new_bot, new_top))
        if new_bot < bot:
            bot = new_bot
        if new_top > top:
            top = new_top
    if (bot < top) and (bot > (-np.inf)) and (top < np.inf):
        printd('Set ylim to ', bot, top)
        ax.set_ylim(bot, top)
    else:
        printd('Did not set ylim because one of these was False:')
        printd('   bot > (-np.inf) = {:}, bot = {:}'.format(bot > (-np.inf), bot))
        printd('   top < np.inf = {:}, top = {:}'.format(top < np.inf, top))
        printd('   bot < top = {:}'.format(bot < top))
    printd('Done with autoscale_y().')
    return 
[docs]@_available_to_user_plot
def set_linearray(lines, values=None, cmap=default_matplotlib_cmap_cycle, vmin=None, vmax=None):
    """
    Set colors of lines to colormapping of values.
    Other good sequential colormaps are YlOrBr and autumn.
    A good diverging colormap is bwr.
    :param lines:  Lines to set colors.
    :type lines: list
    :param values: Values corresponding to each line. Default is indexing.
    :type values: array like
    :param cmap: Valid matplotlib colormap name.
    :type cmap: str
    :param vmax: Upper bound of colormapping.
    :type vmax: float
    :param vmin: Lower bound of colormapping.
    :type vmin: float
    :return: ScalarMappable. A mapping object used for colorbars.
    """
    from matplotlib.cbook import flatten
    if values is None:
        values = list(range(len(lines)))
    nm = matplotlib.colors.Normalize(vmin, vmax)
    sm = matplotlib.cm.ScalarMappable(cmap=cmap, norm=nm)
    sm.set_array(values)
    colors = sm.cmap(sm.norm(values))
    for l, c in zip(lines, colors):
        if isinstance(l, matplotlib.container.ErrorbarContainer):
            for li in flatten(l.lines):
                li.set_color(c)
        else:
            l.set_color(c)
    return sm 
[docs]@_available_to_user_plot
def pi_multiple(x, pos=None):
    """
    Provides a string representation of x that is a multiple of the
    fraction pi/'denominator'.
    See multiple_formatter documentation for more info.
    """
    func = multiple_formatter()
    return func(x, pos) 
[docs]@_available_to_user_plot
def convert_ticks_to_pi_multiple(axis=None, major=2, minor=4):
    """
    Given an axis object, force its ticks to be at multiples of pi, with the
    labels formatted nicely [...,-2pi,-pi,0,pi,2pi,...]
    :param axis: An axis object, such as pyplot.gca().xaxis
    :param major: int
        Denominator of pi for major tick marks. 2: major ticks at 0, pi/2., pi, ...
        Can't be greater than 24.
    :param minor: int
        Denominator of pi for minor tick marks. 4: minor ticks at 0, pi/4., pi/2., ...
    :return: None
    """
    axis = pyplot.gca().xaxis if axis is None else axis
    major = min([major, 24])
    axis.set_major_locator(MultipleLocator(np.pi / major))
    axis.set_minor_locator(MultipleLocator(np.pi / minor))
    axis.set_major_formatter(FuncFormatter(pi_multiple)) 
[docs]@_available_to_user_plot
def is_colorbar(ax):
    """
    Guesses whether a set of Axes is home to a colorbar
    https://stackoverflow.com/a/53568035/6605826
    :param ax: Axes instance
    :return: bool
        True if the x xor y axis satisfies all of the following and thus looks like it's probably a colorbar:
        No ticks, no tick labels, no axis label, and range is (0, 1)
    """
    xcb = (len(ax.get_xticks()) == 0) and (len(ax.get_xticklabels()) == 0) and (len(ax.get_xlabel()) == 0) and (ax.get_xlim() == (0, 1))
    ycb = (len(ax.get_yticks()) == 0) and (len(ax.get_yticklabels()) == 0) and (len(ax.get_ylabel()) == 0) and (ax.get_ylim() == (0, 1))
    return xcb != ycb  # != is effectively xor in this case, since xcb and ycb are both bool 
    # xor trick from https://stackoverflow.com/a/433161/6605826
[docs]@_available_to_user_plot
def tag_plots_abc(
    fig=None,
    axes=None,
    corner=[1, 1],
    font_size=matplotlib.rcParams['xtick.labelsize'],
    skip_suspected_colorbars=True,
    start_at=0,
    **annotate_kw,
):
    """
    Tag plots with (a), (b), (c), ...
    :param fig: Specify a figure instance instead of letting the function pick the most recent one
    :param axes: Specify a plot axes instance or list/array of plot axes instances instead of letting the function use
        fig.get_axes()
    :param corner: Which corner does the tag go in? [0, 0] for bottom left, [1, 0] for bottom right, etc.
    :param font_size: Font size of the annotation.
    :param skip_suspected_colorbars: bool
        Try to detect axes which are home to colorbars and skip tagging them. An Axes instance is suspected of having a
        colorbar if either the xaxis or yaxis satisfies all of these conditions:
        - Length of tick list is 0
        - Length of tick label list is 0
        - Length of axis label is 0
        - Axis range is (0,1)
    :param start_at: int
        Offset value for skipping some numbers. Useful if you aren't doing real subfigs, but two separate plots and
        placing them next to each other in a publication. Set to 1 to start at (b) instead of (a), for example.
    :param annotate_kw: dict
        Additional keywords passed to annotate(). Keywords used by settings such as corner, etc. will be overriden.
    """
    from matplotlib.cbook import flatten
    if fig is None:
        fig = pyplot.gcf()
    if axes is None:
        axes = fig.get_axes()
    letters = string.ascii_lowercase + string.ascii_uppercase
    tag_x = abs(corner[0] - 0.005)
    tag_y = abs(corner[1] - 0.025)
    tag_ha = ['left', 'right'][corner[0]]
    tag_va = ['bottom', 'top'][corner[1]]
    printd('Plot tag information: ', tag_x, tag_y, tag_ha, tag_va)
    valid_axes = [ax for ax in flatten(np.atleast_1d(axes)) if not (skip_suspected_colorbars and is_colorbar(ax))]
    # Assign all keywords into annotate_kw to prevent duplicates
    annotate_kw['xycoords'] = 'axes fraction'
    annotate_kw['ha'] = tag_ha
    annotate_kw['va'] = tag_va
    annotate_kw['fontsize'] = font_size
    annotate_kw['xy'] = (tag_x, tag_y)
    j = 0 + start_at
    for i, ax in enumerate(valid_axes):
        if ax.axison:
            # Thanks to Suever for pointing out ax.axison: http://stackoverflow.com/a/42660787/6605826
            ax.annotate(' ({:}) '.format(letters[j]), **annotate_kw).draggable(True)
            j += 1
    return 
[docs]@_available_to_user_plot
def mark_as_interactive(ax, interactive=True):
    """
    Mark an axis as interactive or not
    :param ax: axis
    :param interactive: boolean
    :return: axis
    """
    ax.is_interactive = interactive
    return ax 
[docs]@_available_to_user_plot
class View1d(object):
    """
    Plot 2D or 3D data as line-plots with interactive navigation through the
    alternate dimensions. Navigation uses the 4 arrow keys to traverse up to 2 alternate
    dimensions.
    The data must be on a regular grid, and is formed into a xarray DataArray if not already.
    Uses matplotlib line plot for float/int data, OMFIT uerrrorbar for uncertainty variables.
    Examples:
    The view1d can be used to interactively explore data. For usual arrays it draws line slices.
    >> t = np.arange(20)
    >> s = np.linspace(0,2*np.pi,60)
    >> y = np.sin(np.atleast_2d(s).T+np.atleast_2d(t))
    >> da = xarray.DataArray(y,coords=SortedDict([('space',s),('time',t)]),name='sine')
    >> v = View1d(da.transpose('time','space'),dim='space',time=10)
    For uncertainties arrays, it draws errorbars using the uerrorbar function. Multiple views
    with the same dimensions can be linked for increased speed (eliminate redundant calls to redraw).
    >> y_u = unumpy.uarray(y+(random(y.shape)-0.5),random(y.shape))
    >> da_u = xarray.DataArray(y_u,coords=SortedDict([('space',s),('time',t)]),name='measured')
    >> v_u = View1d(da_u,dim='space',time=10,axes=pyplot.gca())
    >> v.link(v_u) # v will remain connected to keypress events and drive vu
    Variable dependent axis data can be viewed if x and y share a regular grid in some coordinates,
    >> x = np.array([s+(random(s.shape)-0.5)*0.2 for i in t]).T
    >> da_x = xarray.DataArray(x,coords=SortedDict([('space',s),('time',t)]),name='varspace')
    >> ds = da_u.to_dataset().merge(da_x.to_dataset())
    >> v_x = View1d(ds,name='measured',dim='varspace',time=10,axes=pyplot.gca())
    >> v.link(v_x)
    """
    def __init__(
        self,
        data,
        coords=None,
        dims=None,
        name=None,
        dim=None,
        axes=None,
        dynamic_ylim=False,
        use_uband=False,
        cornernote_options=None,
        plot_options=None,
        **indexers,
    ):
        r"""
        :param data: DataArray or array-like
            2D or 3D data values to be viewed.
        :param coords: dict-like
            Dictionary of Coordinate objects that label values along each dimension.
        :param dims: tuple
            Dimension names associated with this array.
        :param name: string
            Label used in legend. Empty or beginning with '_' produces no legend label.
            If the data is a DataArray it will be renamed before plotting.
            If the data is a Dataset, the name specifies which of its existing data_vars to plot.
        :param dim: string, DataArray
            Dimension plotted on x-axis. If DataArray, must have same dims as data.
        :param axes: Axes instance
            The axes plotting is done in.
        :param dynamic_ylim: bool
            Re-scale y limits of axes when new slices are plotted.
        :param use_uband: bool
            Use uband instead of uerrorbar to plot uncertainties variables.
        :param cornernote_options: dict
            Key word arguments passed to cornernote (such as root, shot, device). If this is present, then cornernote
            will be updated with the new time if there is only one time to show, or the time will be erased from the
            cornernote if more than one time is shown by this View1d instance (such as by freezing one slice).
        :param plot_options: dict
            Key word arguments passed to plot/uerrorbar/uband.
        :param \**indexers: dict
            Dictionary with keys given by dimension names and values given by
            arrays of coordinate index values. Must include all dimensions other,
            than the fundamental.
        """
        # check args and make Dataset if needed
        if isinstance(data, DataArray):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a DataArray")
            if name:
                data.name = name
            elif not data.name:  # require name
                name = 'viewdata'
                data.name = name
            else:
                name = data.name
            data = data.to_dataset()
        elif isinstance(data, Dataset):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a Dataset")
        else:
            if name is None:
                name = 'viewdata'
            data = DataArray(data, coords=coords, dims=dims, name=name).to_dataset()
        # check key words
        self.dynamic_ylim = dynamic_ylim
        self.cornernote_options = cornernote_options
        if plot_options:
            self.plot_options = plot_options
        else:
            self.plot_options = {}
        # Make figure & axes if needed
        if not axes:
            axes = pyplot.gca()
        self.figure, self.axes = axes.get_figure(), axes
        # make view aware of full data, the slice data, and slice indexes
        self.links = []
        self.is_uncertain = is_uncertain(data[name].values)
        self.data = data
        self.name = name
        if dim:
            if isinstance(dim, DataArray):
                if not dim.name:
                    dim.name = 'dim'
                data = data.merge(dim.to_dataset())
                dim = dim.name
            self.dim = dim
            # reduce any large datasets just to the things we need
            self.data = Dataset(dict(((self.name, self.data[self.name]), (dim, self.data[dim]))))
            # put dimension at front of indexing
            if dim in self.data.dims:
                dim0 = dim
            elif dim in self.data.data_vars:
                # it will return the first dimension which is not in indexers
                dim0 = [d for d in self.data.dims if d not in indexers][0]
            else:
                raise ValueError("Argument `dim` must be in DataArray dims or in Dataset")
            dims = list(self.data.dims.keys())
            dims.insert(0, dims.pop(dims.index(dim0)))
            self.data = self.data.transpose(*dims)
        else:
            self.dim = list(self.data.dims.keys())[0]
        self.values = self.data.isel(**indexers)
        self.indexes = SortedDict(indexers)
        for key, val in list(self.indexes.items()):
            self.indexes[key] = np.atleast_1d(val)
        # use metadata for labeling
        self.axes.set_xlabel(dim)
        self.label_prefix = self.plot_options.pop('label', '')
        self._label_indexers_visible = True
        lbl = self._get_label()
        # actual plotting
        self.use_uband(use_uband)
        self._inactive = []
        if self.is_uncertain:
            self._active = self._uplot(self.values[self.dim], self.values[self.name], ax=self.axes, label=lbl, **self.plot_options)
        else:
            self._active = self.axes.plot(self.values[self.dim], self.values[self.name], label=lbl, **self.plot_options)
        # Update annotations
        if self.cornernote_options is not None:
            self._update_cornernote()
        legend = self.axes.legend(loc=0)
        legend.draggable(True)
        self.figure.canvas.draw_idle()
        # store all views within the axes for safe keeping (persistence)
        if not hasattr(self.axes, 'views'):
            self.axes.views = []
        self.axes.views.append(self)
        self.axes.is_interactive = True
        # connect up the key/link driven navigation
        self._accelerate = False
        self.cid = self.figure.canvas.mpl_connect('key_press_event', self.key_command)
        if not hasattr(self.figure, 'shortcuts'):
            self.figure.shortcuts = SortedDict()
        self.figure.shortcuts['left'] = 'Reduce index of first non-axial dimension'
        self.figure.shortcuts['right'] = 'Increase index of first non-axial dimension'
        self.figure.shortcuts['up'] = 'Reduce index of second non-axial dimension'
        self.figure.shortcuts['down'] = 'Increase index of second non-axial dimension'
        self.figure.shortcuts['shift+up/down/left/right'] = 'Accelerate: Change index by one fifth the range'
        self.figure.shortcuts['a'] = 'Toggle accelerate all on/off'
        self.figure.shortcuts['w'] = 'Persist current slice of viewed data'
        self.figure.shortcuts['e'] = 'Erase all persisting slices of viewed data except the active one'
[docs]    def use_uband(self, use=True):
        """Toggle use of uband instead of uerrorbar for plotting function"""
        self._use_uband = use
        if use:
            self._uplot = uband
        else:
            self._uplot = uerrorbar 
[docs]    def set_label_indexers_visible(self, visible=True):
        """Include the current indexers in line labels."""
        self._label_indexers_visible = visible
        # update plot
        self.isel(draw=False, **self.indexes) 
    def _get_label(self):
        """(Re-)Set the view Line's label indicating the current indexes slicing"""
        if self.label_prefix:
            lbl = self.label_prefix + ' ' + self.name
        else:
            lbl = self.name
        if not self.name:
            return lbl
        for k, i in list(self.indexes.items()):
            i = np.unique(np.clip(i, 0, self.data[k].size - 1))
            if self._label_indexers_visible and len(i) < self.data[k].shape[0]:  # don't label if not sliced at all
                if len(i) == 1:
                    lbl += ', {key} = {vals} {units}'.format(
                        name=self.name, key=k, vals=self.data[k][i].values[0], units=self.data[k][i].attrs.get('units', '')
                    )
                else:
                    lbl += ', {key} = {vals}'.format(name=self.name, key=k, vals=self.data[k][i].values)
        return lbl.rstrip(', ').lstrip(', ')
[docs]    def key_command(self, event, draw=True, **plot_options):
        """
        Use arrows to navigate up to 2 extra dimensions by incrementing
        the slice indexes.
        Use w/e to write/erase slices to persist outside on navigation.
        """
        # prevent infinite loops
        if not hasattr(self, '_lastevent'):
            self._lastevent = None
        if self._lastevent == event:
            return
        self._lastevent = event
        # drivers can override plot_options
        self.plot_options.update(plot_options)
        # get standard key press name
        if not event.key:
            return
        key_press = event.key.lower()
        # had to add this because mac Tk doesn't register shift+key
        if key_press == 'a':
            self._accelerate = not self._accelerate
        # slice navigation
        if key_press in ['left', 'right', 'down', 'up', 'shift+left', 'shift+right', 'shift+up', 'shift+down']:
            # select dimension being changed
            if key_press in ['left', 'right', 'shift+left', 'shift+right']:
                dim = list(self.indexes.keys())[0]
            if key_press in ['down', 'up', 'shift+up', 'shift+down']:
                i = min(len(self.indexes) - 1, 1)
                dim = list(self.indexes.keys())[i]
            # apply change in right direction
            if key_press in ['left', 'down']:
                self.indexes[dim] -= 1 + self._accelerate * (int(self.data[dim].size // 5) - 1)
            if key_press in ['shift+left', 'shift+down']:
                self.indexes[dim] -= int(self.data[dim].size // 5)
            if key_press in ['right', 'up']:
                self.indexes[dim] += 1 + self._accelerate * (int(self.data[dim].size // 5) - 1)
            if key_press in ['shift+right', 'shift+up']:
                self.indexes[dim] += int(self.data[dim].size // 5)
            # prevent stepping out of the dimension range
            self.indexes[dim] = np.clip(self.indexes[dim], 0, self.data.dims[dim] - 1)
            # update plot
            self.isel(draw=False, **self.indexes)
        # snapshot write / erase
        if key_press == 'w':
            self._add_lines()
            if isinstance(self._inactive[-1][0], matplotlib.container.ErrorbarContainer):
                plot_options['color'] = self._inactive[-1][0].lines[0].get_color()
            else:
                plot_options['color'] = self._inactive[-1][0].get_color()
        if key_press == 'e':
            self._clear_lines()
        # drive linked views, then draw once
        for link in self.links:
            link.key_command(event, draw=False, **plot_options)
        if draw:
            self.figure.canvas.draw_idle()
            # if linked to views in other figures draw those too
            otherfigs = unsorted_unique([self.figure] + [v.figure for v in self.links])[1:]
            for f in otherfigs:
                f.canvas.draw_idle()
        return 
    def _add_lines(self):
        """Plot a persistent snapshot of the current slice"""
        lbl = self._get_label()
        if self.is_uncertain:
            ys = self.values[self.name].values.reshape(self.values[self.name].shape[0], -1)
            xs = self.values[self.dim].values.reshape(self.values[self.dim].shape[0], -1)
            if xs.shape[1] == 1:
                xs = np.tile(xs, ys.shape[1])
            for x, y in zip(xs.T, ys.T):
                if self._use_uband:
                    self._inactive.append(uband(x, y, ax=self.axes, label=lbl, **self.plot_options))
                else:
                    self._inactive.append(uerrorbar(x, y, ax=self.axes, label=lbl, **self.plot_options))
        else:
            self._inactive.append(self.axes.plot(self.values[self.dim], self.values[self.name], label=lbl, **self.plot_options))
        # Update annotations
        if self.cornernote_options is not None:
            self._update_cornernote()
    def _clear_lines(self):
        """Remove otherwise persistent snapshots"""
        from matplotlib.cbook import flatten
        for lines in self._inactive:
            if lines in self.axes.containers:  # bug in matplotlib ErrorbarContainer.remove
                for l in flatten(lines.lines):
                    l.remove()
                self.axes.containers.remove(lines)
            else:
                for l in lines:
                    l.remove()
        self._inactive = []
        self.isel(**self.indexes)  # redraw current slice
[docs]    def isel(self, draw=True, **indexers):
        """
        Re-slice the data along its extra dimensions using indexes.
        """
        from matplotlib.cbook import flatten
        self.indexes = SortedDict(indexers)
        # only plot valid indexers (but allow keeping track of others)
        for dim in indexers:
            indexers[dim] = np.unique(np.clip(self.indexes[dim], 0, self.data[dim].size - 1))
        self.values = self.data.isel(**indexers).transpose(*self.data[self.name].dims)  # make sure dim order stays consistent
        lbl = self._get_label()
        # update errorbar container lines
        if self.is_uncertain:
            # can't animate PollyCollection - have to delete and draw new
            if self._use_uband:
                colors = []
                for obj in self._active:
                    colors.append(obj.get_color())
                    if isinstance(obj, Uband):
                        if obj.line in self.axes.lines:
                            self.axes.lines.remove(obj.line)
                        for band in obj.bands:
                            if band in self.axes.collections:
                                self.axes.collections.remove(band)
                    else:  # assume errorbar container
                        for l in flatten(obj.lines):
                            if l in self.axes.lines:
                                self.axes.lines.remove(l)
                            if l in self.axes.collections:
                                self.axes.collections.remove(l)
                self._active = []
                ys = self.values[self.name].values.reshape(self.values[self.name].shape[0], -1)
                xs = self.values[self.dim].values.reshape(self.values[self.dim].shape[0], -1)
                if xs.shape[0] == 1 and ys.shape[0] == 1:
                    xs = xs.T
                    ys = ys.T
                if xs.shape[1] == 1:
                    xs = np.tile(xs, ys.shape[1])
                lim = (np.nanmin(nominal_values(ys)), np.nanmax(nominal_values(ys)))
                match_colors = not 'color' in self.plot_options  # keep same color even if not explicitly set the first time
                for i, (xi, yi) in enumerate(zip(xs.T, ys.T)):
                    if match_colors and i < len(colors):
                        self.plot_options['color'] = colors[i]
                    self._active.append(self._uplot(xi, yi, ax=self.axes, label=lbl, **self.plot_options)[0])
                    if match_colors:
                        self.plot_options.pop('color', None)
            # painfully update only the individual pieces of a ErrorbarContainer for faster redraw
            else:
                ys = self.values[self.name].values.reshape(self.values[self.name].shape[0], -1)
                xs = self.values[self.dim].values.reshape(self.values[self.dim].shape[0], -1)
                if xs.shape[1] == 1:
                    xs = np.tile(xs, ys.shape[1])
                lim = (np.nanmin(nominal_values(ys)), np.nanmax(nominal_values(ys)))
                for i, (xi, yi) in enumerate(zip(xs.T, ys.T)):
                    container = self._active[i]
                    y, ye = nominal_values(yi), std_devs(yi)
                    x, xe = nominal_values(xi), std_devs(xi)
                    # matplotlib error bar does not do multiple lines at once, so safe to simply flatten
                    x, xe = x.flatten(), xe.flatten()
                    line, capinfo, barinfo = container.lines
                    if len(barinfo) == 2:  # X and Y errors
                        (xbars, ybars) = barinfo
                        if len(capinfo):  # With end-caps
                            xtop, xbot, ytop, ybot = capinfo
                        else:  # Without end-caps
                            xtop, xbot, ytop, ybot = None, None, None, None
                    else:  # Y errors only
                        (ybars,) = barinfo[0]
                        if len(capinfo):  # With end-caps
                            ytop, ybot = capinfo
                        else:  # Without end-caps
                            ytop, ybot = None, None
                        xtop, xbot = None, None
                    yt, yb = y + ye, y - ye
                    xt, xb = x + xe, x - xe
                    line.set_data([x, y])
                    if ybot:
                        ytop.set_data([x, yt])
                        ybot.set_data([x, yb])
                    new_segments_y = [np.array([[xn, yts], [xn, ybs]]) for xn, yts, ybs in zip(x, yt, yb)]
                    ybars.set_segments(new_segments_y)
                    if xbot:
                        xtop.set_data([xt, y])
                        xbot.set_data([xb, y])
                        new_segments_x = [np.array([[xts, yn], [xbs, yn]]) for xts, xbs, yn in zip(xt, xb, y)]
                        xbars.set_segments(new_segments_x)
                    container.set_label(lbl)
        # update simple lines
        else:
            lim = (np.nanmin(self.values[self.name]), np.nanmax(self.values[self.name]))
            xlen = len(self._active[0].get_xdata())
            ys = self.values[self.name].values.reshape(xlen, -1).T
            if self.dim in self.data.dims:  # usual case where x is a 1D dimension
                xs = [self.values[self.dim]] * ys.shape[0]
            else:  # special case where x is a dataset variable
                xs = self.values[self.dim].values.reshape(xlen, -1).T
            # update each line
            for line, x, y in zip(self._active, xs, ys):
                line.set_data([x, y])
                line.set_label(lbl)
            # remove lines that got crunched into upper/lower bounds
            for line in self._active[len(ys) :]:
                self.axes.lines.remove(line)
                self._active.remove(line)
        # make sure the new labels get into the legend
        if self.axes.get_legend():
            visible = self.axes.get_legend().get_visible()
            leg = self.axes.legend(loc=0)
            leg.draggable(True)
            leg.set_visible(visible)
        # expand axis limits if necessary
        if self.dynamic_ylim:
            ylim = self.axes.get_ylim()
            if ylim[0] > lim[0]:
                self.axes.set_ylim(ymin=lim[0])
            if ylim[1] < lim[1]:
                self.axes.set_ylim(ymax=lim[1])
        # Update annotations
        if self.cornernote_options is not None:
            self._update_cornernote()
        if draw:
            self.figure.canvas.draw_idle() 
    def _update_cornernote(self):
        """
        If indexing by time and if there is only one index, use it to look up the time for the current index and then
        update the cornernote.
        """
        if len(self.indexes) == 1 and 'time' in list(self.indexes.keys()):
            time_index = np.unique(np.clip(self.indexes['time'], 0, self.data['time'].size - 1))
            times = self.data.coords['time'].values[time_index]
            co = self.cornernote_options
            root = co.get('root', None)
            device = co.get('device', None)
            shot = co.get('shot', None)
            if len(time_index) == 1 and not self._inactive:
                cornernote(root=root, device=device, shot=shot, time=times[0], ax=self.axes)
            elif len(time_index) == 2 and not self._inactive:
                cornernote(root=root, device=device, shot=shot, time='{}, {}'.format(times[0], times[1]), ax=self.axes)
            else:
                cornernote(root=root, device=device, shot=shot, time='', ax=self.axes)
[docs]    def link(self, view, disconnect=True):
        """
        Link all actions in this view to the given view.
        :param view: View1d. A view that will be driven by this View's key press responses.
        :param disconnect: bool. Disconnect key press events from driven view.
        """
        self.links.append(view)
        # disconnect the key/link driven navigation
        if disconnect:
            view.figure.canvas.mpl_disconnect(view.cid)
        return 
[docs]    def unlink(self, view):
        """
        Unlink actions in this view from controlling the given view.
        """
        if view in self.links:
            self.links.remove(view)
            # reconnect the key driven navigation if no other link
            reconnect = True
            for a in self.figure.axes:
                if hasattr(a, 'views'):
                    for v in a.views:
                        if self in v.links:
                            reconnect = False
            if reconnect:
                view.cid = view.figure.canvas.mpl_connect('key_press_event', view.key_command)
        return  
[docs]@_available_to_user_plot
class View2d(object):
    """
    Plot 2D data with interactive slice viewers attached to the 2D Axes.
    Left clicking on the 2D plot refreshes the line plot slices, right clicking
    overplots new slices.
    The original design of this viewer was for data on a rectangular grid, for which
    x and y are 1D arrays defining the axes but may be irregularly spaced. In this case,
    the line plot points correspond to the given data. If x or y is a 2D array,
    the data is assumed irregular and interpolated to a regular grid using
    scipy.interpolate.griddata.
    **Example:**
    Explore a basic 2D np array without labels,
    >> x = np.linspace(-1, 1, 200)
    >> y = np.linspace(-2, 2, 200)
    >> xx, yy = meshgrid(x, y)
    >> z = np.exp(-xx**2 - yy**2)
    >> v = View2d(z)
    To add more meaningful labels to the axes and data do,
    >> v = View2d(z, coords={'x':x, 'y':y}, dims=('x', 'y'), name='wow')
    or use a DataArray,
    >> d = DataArray(z, coords={'x':x, 'y':y}, dims=('x', 'y'), name='wow')
    >> v = View2d(d)
    Note that the coordinates should be 1D. Initializing a view with regular grid,
    2D coordinates will result in an attempt to slice them appropriately.
    This is done for consistency with some matplotlib 2D plotting routines,
    but is not recommended.
    >> v = View2d(z, coords=dict(x=x, y=y), dims=('x', 'y'))
    If you have irregularly distributed 2D data, it is recomended that you first interpolate
    it to a 2D grid in whatever way is most applicable. If you do not, initializing a view
    will result in an attempt to linearly interpolate to a automatically chosen grid.
    >> x = np.random.rand(1000)
    >> y = np.random.rand(1000) * 2
    >> z = np.exp(-x**2 - y**2)
    >> v = View2d(z, coords=dict(x=x, y=y), dims=('x', 'y'))
    The same applies for 2D collections of irregular points and values.
    >> x = x.reshape((50, 20))
    >> y = y.reshape((50, 20))
    >> z = z.reshape((50, 20))
    >> v = View2d(z, coords=[('x', x), ('y', y)], dims=('x', 'y'))
    """
    def __init__(
        self,
        data,
        coords=None,
        dims=None,
        name=None,
        axes=None,
        quiet=False,
        use_uband=False,
        contour_levels=0,
        imag_options={},
        plot_options={'marker': '', 'ls': '-'},
        **indexers,
    ):
        r"""
        :param data: DataArray or array-like
            2D or 3D data values to be viewed.
        :param coords: dict-like
            Dictionary of Coordinate objects that label values along each dimension.
        :param dims: tuple
            Dimension names associated with this array.
        :param name: string
            Label used in legend. Empty or begining with '_' produces no legend label.
        :param dim: string, DataArray
            Dimension plotted on x-axis. If DataArray, must have same dims as data.
        :param axes: Axes instance
            The axes plotting is done in.
        :param quiet: bool
            Suppress printed messages.
        :param use_uband: bool
            Use uband for 1D slice plots instead of uerrorbar.
        :param contour_levels: int or np.ndarray
            Number of or specific levels used to draw black contour lines over the 2D image.
        :param imag_options: dict
            Key word arguments passed to the DataArray plot method (modified pcolormesh).
        :param plot_options: dict
            Key word arguments passed to plot or uerrorbar. Color will be determined by cmap variable.
        :param \**indexers: dict
            Dictionary with keys given by dimension names and values given by
            arrays of coordinate index values.
        """
        # check args and make Dataset if needed
        if isinstance(data, DataArray):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a DataArray")
            if name:
                data.name = name
            elif not data.name:  # require name
                name = 'viewdata'
                data.name = name
            else:
                name = data.name
            data = data
        elif isinstance(data, Dataset):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a Dataset")
            data = data[list(data.data_vars.keys())]
        else:
            if name is None:
                name = 'viewdata'
            try:
                data = DataArray(data, coords=coords, dims=dims, name=name)
            except Exception as error:
                # turn the flexible input styles into the standard coords dict, dims list
                if coords is not None:
                    coords = OrderedDict(coords)
                if isinstance(dims, dict):
                    x0, x1 = list(dims.values())[:2]
                    dims = list(dims.keys())[:2]
                elif dims is not None:
                    x0, x1 = [coords[k] for k in dims[:2]]
                else:
                    x0, x1 = list(coords.values())[:2]
                    dims = list(coords.keys())[:2]
                # check if called with 2D coordinates of a regular grid
                if (
                    (x0.shape == data.shape and x1.shape == data.shape)
                    and np.all(x0 == np.roll(x0, -1, axis=0))
                    and np.all(x1 == np.roll(x1, -1, axis=1))
                ):
                    coords_1d = {dims[0]: x0[0], dims[1]: x1[:, 0]}
                    data = DataArray(data, coords=coords_1d, dims=dims, name=name)
                # interpolate weird inputs to a nice 2D grid
                else:
                    if not quiet:
                        printe("WARNING: Could not directly form a DataArray from inputs.")
                        printe("  {:}".format(error))
                        printe("  > Interpolating to a regular grid.")
                    # make a regular grid
                    x0, x1 = np.ravel(x0), np.ravel(x1)
                    dx0, dx1 = np.diff(sorted(x0)), np.diff(sorted(x1))
                    n0 = int((np.nanmax(x0) - np.nanmin(x0)) / np.median(dx0[np.where(dx0 != 0)]))
                    n1 = int((np.nanmax(x1) - np.nanmin(x1)) / np.median(dx1[np.where(dx1 != 0)]))
                    i0 = np.linspace(np.nanmin(x0), np.nanmax(x0), min(500, n0))
                    i1 = np.linspace(np.nanmin(x1), np.nanmax(x1), min(500, n1))
                    data = interpolate.griddata((x0, x1), np.ravel(data), (i0[None, :], i1[:, None]), method='linear')
                    data = data.reshape((i1.shape[0], i0.shape[0])).T
                    data = DataArray(data, coords=[(k, x) for k, x in zip(dims, (i0, i1))], dims=dims, name=name)
        # pass on key words
        if not np.shape(contour_levels) and contour_levels == 0:
            contour_levels = None  # captures 0.0, 0, None, False, etc.
        self._contour_levels = contour_levels
        self._use_uband = use_uband
        self.plot_options = plot_options
        self.plot_cmap = plot_options.get('cmap', None)
        self.imag_options = imag_options
        # store basic inputs
        self.data = data
        # make alarm for timing
        self.alarm = None
        # empty list of links
        self.links = []
        # Make figure & main axes if needed
        if axes == None:
            fig = pyplot.figure(figsize=(10, 8))
            gs = matplotlib.gridspec.GridSpec(4, 4)
            gs.update(wspace=0.05, hspace=0.05)
            axm = fig.add_subplot(gs[1:, :-1])  # main 2D plot
        else:  # make a grid within the original axes location
            fig = axes.get_figure()
            axm = axes
            rect = axm.get_position().bounds
            gs = matplotlib.gridspec.GridSpec(4, 4, left=rect[0], bottom=rect[1], right=rect[0] + rect[2], top=rect[1] + rect[3])
            axm.set_position(gs[1:, :-1].get_position(fig))  # main 2D plot
        # add surrounding axes
        axx = fig.add_subplot(gs[0, :-1], sharex=axm)
        axy = fig.add_subplot(gs[1:, -1], sharey=axm)
        axr = fig.add_subplot(gs[0, -1])
        # add to class
        self.fig = fig
        self.axm = axm
        self.axx = axx
        self.axy = axy
        self.axr = axr
        self.old_xmin = 0.0
        self.old_xmax = 0.0
        self.old_ymin = 0.0
        self.old_ymax = 1.0
        # autoscaling behaves poorly
        for ax in [axm, axy, axx, axr]:
            ax.autoscale(False)
        # store within the axes for safe keeping (persistence)
        if not hasattr(self.axm, 'views'):
            self.axm.views = []
        self.axm.views.append(self)
        self.axm.is_interactive = True
        # 2D plot with colorbar
        axm.set_xlabel(data.dims[0])
        axm.set_ylabel(data.dims[1])
        self._set_values(self.data, **imag_options)
        axm.axis('tight')
        self.colorbar = self.fig.colorbar(self.image, ax=[axm, axy, axx, axr], use_gridspec=True, shrink=0.75, anchor=(0, 0.0))
        self.colorbar.set_label(self.data.name)
        if not self.plot_cmap:
            self.plot_cmap = self.axm.collections[-1].get_cmap()
        # 1D plots aesthetics
        axx.set_ylabel(data.name)
        # axx.set_ylim(np.nanmin(data),np.nanmax(data))
        loc = pyplot.MaxNLocator(5)
        axx.yaxis.set_major_locator(loc)
        axy.set_xlabel(data.name)
        # axy.set_xlim(np.nanmin(data),np.nanmax(data))
        loc = pyplot.MaxNLocator(5)
        axy.xaxis.set_major_locator(loc)
        pyplot.setp(axy.xaxis.get_majorticklabels(), rotation=90)
        for tl in axx.xaxis.get_ticklabels() + axy.yaxis.get_ticklabels():
            tl.set_visible(False)
        # connect axes to events
        self.RectangleSelector = RectangleSelector(
            self.axm,
            self.line_select_callback,
            drawtype='box',
            useblit=True,  # minspanx=0, minspany=0,
            button=[1, 3],  # don't use middle button
            spancoords='data',
        )
        self.fig.canvas.mpl_connect('button_press_event', self.line_select_callback)
        self.fig.canvas.mpl_connect('key_press_event', self.toggle_selector)
        # connect radio buttons
        self.radio = RadioButtons(
            axr,
            ('data', 'd/d-' + self.data.dims[0], 'd/d-' + self.data.dims[1], 'int-' + self.data.dims[0], 'int-' + self.data.dims[1]),
            activecolor='black',
        )
        self.radio.on_clicked(self._radio_select)
        self.radio.selected = 'data'
        self._log_toggled = False
        if hasattr(self.fig, 'shortcuts'):
            self.fig.shortcuts['j'] = 'Toggle log scaling of 2D image data.'
        else:
            printe('No shortcuts attribute so no log scaling')
        self.fig.canvas.draw_idle()
[docs]    def use_uband(self, use=True):
        """Toggle use of uband instead of uerrorbar in 1D slice plots"""
        self._use_uband = use 
[docs]    def key_navigate_cuts(self, key_press):
        if key_press in ['left', 'right']:
            dat = self.data[self.data.dims[0]]
            xmin, xmax, std = self.get_vslice_args()
        else:
            dat = self.data[self.data.dims[1]]
            xmin, xmax, std = self.get_hslice_args()
        imin = np.abs(dat.values - xmin).argmin()
        imax = np.abs(dat.values - xmax).argmin()
        # imax = max(imax,min(imin,len(dat)-2))
        if key_press in ['right', 'up']:
            imax_new = imax + 1
            imin_new = imin + 1
        else:
            imax_new = imax - 1
            imin_new = imin - 1
        # Bound the results
        max_index = min(max(imax_new, 0), len(dat) - 1)
        min_index = min(max(imin_new, 0), len(dat) - 1)
        if key_press in ['left', 'right']:
            self.vslice(dat[min_index].values, dat[max_index].values, std)
        else:
            self.hslice(dat[min_index].values, dat[max_index].values, std) 
        # if self.alarm:
        #     OMFITaux['rootGUI'].after_cancel(self.alarm)
        #     self.alarm=None
        # self.alarm=OMFITaux['rootGUI'].after(100,None)
    def _set_values(self, values, draw=True, log_toggled=False, **kw):
        r"""
        Plots the main 2D image in axm.
        Optionally re-sets displayed values and axes, but does not change original data.
        :param values: 2D DataArray of values for display/interaction.
        :type values: DataArray.
        :param draw: Re-draw the figure when done.
        :type draw; bool
        :param log_toggled: Default value of False resets 'j' hotkey binding.
        :type log_toggled; bool
        :param \**kw: Additional key word arguments passed to the DataArray plot method (modified pcolormesh).
        :return: The main 2d :class:`matplotlib.collections.QuadMesh` from pcolormesh
        """
        self._log_toggled = log_toggled
        self.values = values
        # mask values (allows for nan's)
        zim = nominal_values(values)
        zim = np.ma.masked_where(np.isnan(zim), zim)
        if isinstance(values[values.dims[0]].values[0], str):
            x0 = np.arange(len(values[values.dims[0]]))
        else:
            x0 = nominal_values(values[values.dims[0]])
        if isinstance(values[values.dims[1]].values[0], str):
            x1 = np.arange(len(values[values.dims[1]]))
        else:
            x1 = nominal_values(values[values.dims[1]])
        # limits
        xlim = (np.nanmin(x0), np.nanmax(x0))
        ylim = (np.nanmin(x1), np.nanmax(x1))
        vlim = (np.nanmin(zim), np.nanmax(zim))
        # default key words
        dkw = {}
        dkw.update(kw)
        # first call creates the base image from the original data
        if not hasattr(self, 'image'):
            if 1 in self.data.shape:
                raise ValueError("Must have 2D data to use View2D. This data is {:}x{:}".format(*self.data.shape))
            self.image = self.data.T.plot(ax=self.axm, add_colorbar=False, **dkw)
        # updated image
        self.image.set_array(zim.T.flatten())
        self.image.autoscale()
        for k, v in list(kw.items()):
            try:
                getattr(self.image, 'set_' + k)(v)
            except Exception:
                printe('Cannot reset {:}, change manually.'.format(k))
        # contour
        if hasattr(self, 'contour'):
            for c in self.contour.collections:
                if c in self.axm.collections:
                    self.axm.collections.remove(c)
        if not np.all(zim.T == zim.flat[0]) and self._contour_levels is not None:
            self.contour = self.axm.contour(x0, x1, zim.T, self._contour_levels, colors='k')
        # aesthetics
        self.axm.set_xlim(*xlim)
        self.axm.set_ylim(*ylim)
        self.axx.set_ylim(*vlim)
        self.axy.set_xlim(*vlim)
        # re-slice
        if hasattr(self, 'get_vslice_args') and hasattr(self, 'get_hslice_args'):
            self.vslice(*self.get_vslice_args(), draw=draw, force=True)
            self.hslice(*self.get_hslice_args(), draw=draw, force=True)
        else:
            self.vslice(*[np.nanmean(x0)] * 2, draw=draw, force=True)
            self.hslice(*[np.nanmean(x1)] * 2, draw=draw, force=True)
        # drawing is slow
        if draw:
            self.fig.canvas.draw()  # Why does this change the axm limits????
        return
[docs]    def set_data(self, data=None, **kw):
        """
        Set base data and axes, as well as displayed values.
        Call with no arguments re-sets to interactively displayed values (use after :func:`der`, or :func:`int`).
        :param data: 2D array of data.
        :type data: DataArray
        :return: The main 2d :class:`matplotlib.collections.QuadMesh` from pcolormesh
        """
        if not data is None:
            self.data = data
        # keep the radio selection
        self._radio_select(self.radio.selected)
        return 
[docs]    def der(self, axis=0):
        """
        Set 2D values to derivative of data along specified axis.
        :param axis: Axis along which derivative is taken (0=horizontal, 1=vertical).
        :type axis: int
        :return: The main 2d :class:`matplotlib.collections.QuadMesh` from pcolormesh
        """
        if not isinstance(axis, int):
            axis = self.data.get_axis_num(axis)
        dx = np.gradient(nominal_values(self.data[self.data.dims[axis]])) + 0 * self.data[self.data.dims[axis]]
        dy = np.gradient(nominal_values(self.data), edge_order=2)[axis] + 0 * self.data
        data = dy / dx
        if np.allclose(nominal_values(self.values), nominal_values(data)):
            return
        self._set_values(data, **self.imag_options)
        # force action in linked views
        for v in self.links:
            v.der(axis=axis)
        return 
[docs]    def int(self, axis=0):
        """
        Set 2D values to derivative of data along specified axis.
        :param axis: Axis along which integration is taken (0=horizontal, 1=vertical).
        :type axis: int
        :return: The main 2d :class:`matplotlib.collections.QuadMesh` from pcolormesh
        """
        if isinstance(axis, int):
            x = self.data[self.data.dims[axis]]
        else:
            x = self.data[axis]
        data = integrate.cumtrapz(np.nan_to_num(nominal_values(self.data)), x=x, axis=axis, initial=0) + 0 * self.data
        if np.allclose(nominal_values(self.values), nominal_values(data)):
            return
        self._set_values(data, **self.imag_options)
        # force action in linked views
        for v in self.links:
            v.int(axis=axis)
        return 
[docs]    def toggle_log(self):
        """
        Toggle log/linear scaling of data in 2D plot.
        """
        if self._log_toggled:
            self._set_values(10**self.values, **self.imag_options)
            lbl = self.colorbar.ax.get_ylabel().lstrip('log ')
            self.colorbar.set_label(lbl)
            self.axx.set_ylabel(lbl)
            self.axy.set_xlabel(lbl)
            self.axx.set_yscale('linear')
            self.axy.set_xscale('linear')
            self._log_toggled = False
        else:
            self._set_values(np.log10(self.values), **self.imag_options)
            # relabel all axes
            lbl = 'log ' + self.colorbar.ax.get_ylabel()
            self.colorbar.set_label(lbl)
            self.axx.set_ylabel(lbl)
            self.axy.set_xlabel(lbl)
            # log scale line plots
            self.axx.set_yscale('log')
            self.axy.set_xscale('log')
            # robust log scale axes should not go to 0
            zim = nominal_values(self.values)
            zim = np.ma.masked_where(np.isnan(zim), zim)
            zim = np.ma.masked_where(zim <= 0, zim)
            vlim = (np.nanmin(zim), np.nanmax(zim))
            self.axx.set_ylim(*vlim)
            self.axy.set_xlim(*vlim)
            self._log_toggled = True
        return 
[docs]    def vslice(self, xmin, xmax, std=False, draw=True, force=False, **kw):
        r"""
        Plot line collection of x slices.
        :param xmin: Lower bound of slices dispayed in line plot.
        :type xmin: float
        :param xmax: Upper bound of slices dispayed in line plot.
        :type xmax: float
        :param std: Display mean and standard deviation instead of individual slices.
        :type std: bool
        :param draw: Redraw the figure canvas
        :type draw: bool
        :param Force: Re-slice even if arguments are identical to last slice.
        :type draw: bool
        :param \**kw: Extra key words are passed to the 1D :func:`plot` function
        :return: Possibly modified (xmin,xmax,std)
        :rtype: tuple
        """
        # plot options
        self.plot_options.update(kw)
        # slices
        if isinstance(self.data[self.data.dims[0]].values[0], str):
            x0 = np.arange(len(self.data[self.data.dims[0]]))
        else:
            x0 = nominal_values(self.data[self.data.dims[0]])
        if isinstance(self.data[self.data.dims[1]].values[0], str):
            x1 = np.arange(len(self.data[self.data.dims[1]]))
        else:
            x1 = nominal_values(self.data[self.data.dims[1]])
        if isinstance(xmin, DataArray):
            xmin = xmin.values
        if isinstance(xmax, DataArray):
            xmax = xmax.values
        imin = np.abs(x0 - xmin).argmin()
        imax = np.abs(x0 - xmax).argmin()
        imax = max(imax, min(imin, len(x0) - 2))
        x = self.values[imin : imax + 1, :]
        y = self.data[self.data.dims[1]]
        limits = self.axm.axis()
        # Stop infinite loops
        if hasattr(self, 'get_vslice_args'):
            if not force and (x0[imin], x0[imax], std) == self.get_vslice_args():
                return self.get_vslice_args()
        self.get_vslice_args = lambda: (x0[imin], x0[imax], std)
        self.axy.containers = []
        if compare_version(matplotlib.__version__, '3.3.4') <= 0:
            self.axy.collections = []
            self.axy.lines = []
        else:
            for collection in self.axy.collections:
                collection.remove()
            for line in self.axy.lines:
                line.remove()
        xm, sd = np.mean(nominal_values(x), axis=0), np.std(nominal_values(x), axis=0)
        self._vslice_mean = xm
        self._vslice_std = sd
        self._vslice_xdata = y
        if std:
            (l,) = self.axy.plot(nominal_values(xm), nominal_values(y), color='r')
            self.axy.fill_betweenx(nominal_values(y), nominal_values(xm - sd), nominal_values(xm + sd), alpha=0.3, color='r')
        else:
            y = x1.repeat(x.shape[0]).reshape((len(x1), -1)).T
            lines = []
            for i, (xi, yi) in enumerate(zip(x, y)):
                if self._use_uband:
                    (l,) = uband(xi, yi, ax=self.axy, **self.plot_options)
                else:
                    (l,) = uerrorbar(xi, yi, ax=self.axy, **self.plot_options)
                if i == 0 or i == imax - imin:
                    l.set_label('{:}={:6.2g}'.format(self.axm.get_xlabel(), x0[imin + i]))
                lines.append(l)
            sm = set_linearray(lines[:], x0[imin : imax + 1], cmap=self.plot_cmap)
            # self.fig.colorbar(sm,cax=self.axy,orientation='horizontal',use_gridspec=True)
            # lc = plotslices(x,y,colorbar='None',axes=self.axy,cmap='autumn_r')
        self.axy.legend(numpoints=1, prop={'size': 'xx-small'}).draggable(True)
        # self.axy.texts = []
        # self.axy.text(0.02,0.95,'{:}=[{:.2},{:.2}]'.format(self.axm.get_xlabel(),*self.get_vslice_args()),
        #                    verticalalignment='top',transform=self.axy.transAxes)
        if hasattr(self.axm, '_vspan'):
            if self.axm._vspan in self.axm.patches:
                self.axm.patches.pop(self.axm.patches.index(self.axm._vspan))
        self.axm._vspan = self.axm.axvspan(x0[imin], x0[imax], facecolor='w', edgecolor='w', alpha=0.3)
        if draw:
            self.fig.canvas.draw()
        # force action in linked views
        for v in self.links:
            v.vslice(*self.get_vslice_args())
        return self.get_vslice_args() 
[docs]    def hslice(self, ymin, ymax, std=False, draw=True, force=False, **kw):
        r"""
        Plot line collection of y slices.
        :param ymin: Lower bound of slices dispayed in line plot.
        :type ymin: float
        :param ymax: Upper bound of slices dispayed in line plot.
        :type ymax: float
        :param std: Display mean and standard deviation instead of individual slices.
        :type std: bool
        :param draw: Redraw the figure canvas
        :type draw: bool
        :param Force: Re-slice even if arguments are identical to last slice.
        :type draw: bool
        :param \**kw: Extra key words are passed to the 1D :func:`plot` function
        :return: Possibly modified (ymin,ymax,std)
        :rtype: tuple
        """
        # plot options
        self.plot_options.update(kw)
        # slices
        if isinstance(self.data[self.data.dims[0]].values[0], str):
            x0 = np.arange(len(self.data[self.data.dims[0]]))
        else:
            x0 = nominal_values(self.data[self.data.dims[0]])
        if isinstance(self.data[self.data.dims[1]].values[0], str):
            x1 = np.arange(len(self.data[self.data.dims[1]]))
        else:
            x1 = nominal_values(self.data[self.data.dims[1]])
        if isinstance(ymin, DataArray):
            ymin = ymin.values
        if isinstance(ymax, DataArray):
            ymax = ymax.values
        imin = np.abs(x1 - ymin).argmin()
        imax = np.abs(x1 - ymax).argmin()
        imax = max(imax, min(imin, len(x1) - 2))
        y = self.values[:, imin : imax + 1]
        x = self.data[self.data.dims[0]].values
        if isinstance(x[0], str):
            x = np.arange(len(x))
        # Prevent infinite loops
        if hasattr(self, 'get_hslice_args'):
            if not force and (x1[imin], x1[imax], std) == self.get_hslice_args():
                return self.get_hslice_args()
        self.get_hslice_args = lambda: (x1[imin], x1[imax], std)
        self.axx.containers = []
        if compare_version(matplotlib.__version__, '3.3.4') <= 0:
            self.axx.collections = []
            self.axx.lines = []
        else:
            for collection in self.axx.collections:
                collection.remove()
            for line in self.axx.lines:
                line.remove()
        ym, sd = np.mean(nominal_values(y), axis=1), np.std(nominal_values(y), axis=1)
        self._hslice_mean = ym
        self._hslice_std = sd
        self._hslice_xdata = x0
        if std:
            (l,) = self.axx.plot(x0, nominal_values(ym), color='b')
            self.axx.fill_between(nominal_values(x), nominal_values(ym - sd), nominal_values(ym + sd), alpha=0.3, color='b')
        else:
            lines = []
            for i, yi in enumerate(y.T):
                if self._use_uband:
                    (l,) = uband(x, yi.values, ax=self.axx, **self.plot_options)
                else:
                    (l,) = uerrorbar(x, yi.values, ax=self.axx, **self.plot_options)
                if i == 0 or i == imax - imin:
                    l.set_label('{:}={:6.2g}'.format(self.axm.get_ylabel(), x1[imin + i]))
                lines.append(l)
            sm = set_linearray(lines[:], x1[imin : imax + 1], cmap=self.plot_cmap)
        self.axx.legend(numpoints=1, prop={'size': 'xx-small'}).draggable(True)
        if hasattr(self.axm, '_hspan'):
            if self.axm._hspan in self.axm.patches:
                self.axm.patches.pop(self.axm.patches.index(self.axm._hspan))
        self.axm._hspan = self.axm.axhspan(x1[imin], x1[imax], facecolor='w', edgecolor='w', alpha=0.3)
        if draw:
            self.fig.canvas.draw()
        # force action in linked views
        for v in self.links:
            v.hslice(*self.get_hslice_args())
        return self.get_hslice_args() 
[docs]    def line_select_callback(self, eclick, erelease=None):
        """
        Call vslice and hslice for the range of x and y spanned
        by the rectangle between mouse press and release.
        Called by RectangleSelector.
        :param eclick: Matplotlib mouse click.
        :type eclick: matplotlib Event
        :param erelease: Matplotlib mouse release.
        :type erelease: matplotlib Event
        :return: None
        """
        # do nothing if click not in main axes
        if eclick.inaxes != self.axm:
            return
        # special treatment for single click (no dragged box)
        eclick.button == rightClickMPLindex  # right click
        if erelease is None:
            x, y = self.axm.transData.inverted().transform([eclick.x, eclick.y])
            self.hslice(y, y, std=rightClickMPLindex)
            self.vslice(x, x, std=rightClickMPLindex)
            self.fig.canvas.draw_idle()
            return
        # check if box is new (prevent infinite loops)
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        xmin, xmax = sorted([x1, x2])
        ymin, ymax = sorted([y1, y2])
        if xmin == self.old_xmin and xmax == self.old_xmax and ymin == self.old_ymin and ymax == self.old_ymax:
            return
        self.old_xmin, self.old_xmax, self.old_ymin, self.old_ymax = xmin, xmax, ymin, ymax
        # print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
        # print(" The button you used were: %s %s" % (eclick.button, erelease.button))
        self.hslice(ymin, ymax, std=rightClickMPLindex)
        self.vslice(xmin, xmax, std=rightClickMPLindex)
        self.fig.canvas.draw_idle() 
[docs]    def toggle_selector(self, event):
        """
        Connected to key press events to turn on (a) or off (q) selector.
        :param event: key press event.
        :type event: matplotlib event
        :return: None
        """
        if event is None or event.key is None:
            return
        if event.key.lower() in ['up', 'down', 'left', 'right']:
            self.key_navigate_cuts(event.key.lower())
        # if event.key.lower()=='f':
        #    self.scroll_event('','forward')
        if event.key in ['Q', 'q'] and self.RectangleSelector.active:
            print(' RectangleSelector deactivated.')
            self.RectangleSelector.set_active(False)
        if event.key in ['A', 'a'] and not self.RectangleSelector.active:
            print(' RectangleSelector activated.')
            self.RectangleSelector.set_active(True)
        if event.key == 'j':
            self.toggle_log() 
    def _radio_select(self, label):
        """
        Radio button effect.
        """
        self.radio.selected = label
        if label == 'data':
            self._set_values(self.data, **self.imag_options)
        elif label == 'd/d-' + self.data.dims[0]:
            self.der(0)
        elif label == 'd/d-' + self.data.dims[1]:
            self.der(1)
        elif label == 'int-' + self.data.dims[0]:
            self.int(0)
        elif label == 'int-' + self.data.dims[1]:
            self.int(1)
        for view in self.links:
            # have to manually change the fill
            for l, c in zip(view.radio.labels, view.radio.circles):
                if l.get_text() == label:
                    c.set_facecolor(view.radio.activecolor)
                else:
                    c.set_facecolor(view.radio.ax.get_axis_bgcolor())
                view.radio.ax.draw_artist(c)
            view.fig.canvas.blit(view.radio.ax.bbox)
[docs]    def link(self, view):
        """
        Link all actions in this view to the given view.
        """
        self.links.append(view)
        return 
[docs]    def unlink(self, view):
        """
        Unlink actions in this view from controlling the given view.
        """
        if view in self.links:
            self.links.remove(view)
        return  
[docs]@_available_to_user_plot
class View3d(View2d):
    """
    View 3D data by scrolling through 3rd dimension in a View2d plot.
    """
    def __init__(
        self,
        data,
        coords=None,
        dims=None,
        name=None,
        axes=None,
        use_uband=False,
        quiet=False,
        contour_levels=0,
        imag_options={},
        plot_options={'marker': '', 'ls': '-'},
        **indexers,
    ):
        r"""
        :param data: DataArray or array-like
            2D or 3D data values to be viewed.
        :param coords: dict-like
            Dictionary of Coordinate objects that label values along each dimension.
        :param dims: tuple
            Dimension names associated with this array.
        :param name: string
            Label used in legend. Empty or begining with '_' produces no legend label.
        :param dim: string, DataArray
            Dimension plotted on x-axis. If DataArray, must have same dims as data.
        :param axes: Axes instance
            The axes plotting is done in.
        :param quiet: bool
            Suppress printed messages.
        :param use_uband: bool
            Use uband for 1D slice plots instead of uerrorbar.
        :param contour_levels: int or np.ndarray
            Number of or specific levels used to draw black contour lines over the 2D image.
        :param imag_options: dict
            Key word arguments passed to the DataArray plot method (modified pcolormesh).
        :param plot_options: dict
            Key word arguments passed to plot or uerrorbar/uband. Color will be determined by cmap variable.
        :param \**indexers: dict
            Dictionary with keys given by dimension names and values given by
            arrays of coordinate index values.
        """
        # check args and make Dataset if needed
        if isinstance(data, DataArray):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a DataArray")
            if name:
                data.name = name
            elif not data.name:  # require name
                name = 'viewdata'
                data.name = name
            else:
                name = data.name
            data = data
        elif isinstance(data, Dataset):
            if coords or dims:
                raise ValueError("View cannot re-assign coords/dims to data that is already a Dataset")
            data = data[list(data.data_vars.keys())]
        else:
            if name is None:
                name = 'viewdata'
            # must be regular gridded data
            data = DataArray(data, coords=coords, dims=dims, name=name)
        # pass on key words
        self._use_uband = use_uband
        self.plot_options = plot_options
        self.plot_cmap = plot_options.get('cmap', None)
        self.imag_options = imag_options
        # store basic inputs
        self.data3d = data
        # adjust inputs
        self.step = max(1, int(len(data[data.dims[2]]) // 100))
        self._index = int(self.data3d.shape[-1] // 2)
        View2d.__init__(
            self,
            data[:, :, self._index],
            axes=axes,
            use_uband=use_uband,
            contour_levels=contour_levels,
            imag_options=imag_options,
            plot_options=plot_options,
            **indexers,
        )
        self.axm.set_title('')  # remove xarray label of 3rd dimension's slice (we have our own dynamically updated one)
        self._set_3d()
        self.alarm = None
        self.widget = AxesWidget(self.axm)
        self.widget.connect_event('scroll_event', self._scrollx)
        self.widget.active = True
        printi("Scroll over 2D plot to change " + data.dims[2])
        self.fig.canvas.draw()
    def _scrollx(self, event):
        """
        Update viewer2d axes to show new 3d slice.
        """
        key = self.data3d.dims[2]
        ilim = self.data3d[key].shape[0]
        # change index
        if event.button == 'up':
            # ignore if already at limit
            if self._index == ilim - 1:
                return
            self._index += self.step
            if self._index > ilim - 1:
                self._index = ilim - 1
        elif event.button == 'down':
            if self._index == 0:
                return
            self._index -= self.step
            if self._index < 0:
                self._index = 0
        else:
            return
        if self.alarm:
            OMFITaux['rootGUI'].after_cancel(self.alarm)
            self.alarm = None
        self.alarm = OMFITaux['rootGUI'].after(10, self._set_3d)
    def _set_3d(self):
        self.set_3d(self._index, index=True)
        for v in self.links:
            try:
                v._scrollx(event)
            except Exception:
                pass
[docs]    def set_3d(self, x2, index=False, draw=True):
        """
        Set third dimension of view to value nearest to slice.
        :param x2: Slice in third dimension.  Type depends on value of `index`.
        :type x2: float or int
        :param index: Set True if x2 is the integer index
        :type index: bool
        :param draw: If True, redraw the canvas
        :type draw: bool
        :return: None
        """
        # nearest index
        key = self.data3d.dims[2]
        if index:
            indx = x2
        else:
            indx = np.abs(self.data[key] - x2).argmin()
        # update the central 2D plot
        self.set_data(self.data3d[:, :, indx], draw=False)
        ax = self.axm
        if compare_version(matplotlib.__version__, '3.3.4') <= 0:
            ax.texts = []
        else:
            for aText in ax.texts:
                aText.remove()
        ax.text(0.02, 0.95, '{:}={:6.2g}'.format(key, float(self.data3d[key].values[self._index])), transform=ax.transAxes)
        if draw:
            self.fig.canvas.draw()  
[docs]@_available_to_user_plot
class DragPoints(object):
    r"""
    This class is used to define matplotlib draggable arrays
    :param yArray: location in the OMFIT tree for the y array
    :param xArray: location in the OMFIT tree for the x array
    :param eyArray: location in the OMFIT tree for the x error array
    :param exArray: location in the OMFIT tree for the y error array
    :param editY: allow dragging of points along Y axis
    :param editX: allow dragging of points along X axis
    :param exDelta: increments of exArray
    :param eyDelta: increments of eyArray
    :param editX: allow dragging of points along X axis
    :param sorted: keep points sorted in x
    :param continuous: draw continuously even while user is dragging (may be good to disable if takes long time)
    :param showOriginal: show original points
    :param func: a function with signature like `func(x,y,motion,fargs=[])`.
                 where:
                 * x : x coordinate of control points
                 * y : y coordinate of control points
                 * motion : True if the user has dragged a point
                 This function must return x\_, y\_, x, y where:
                 * x\_ : interpolating points between x coordinate of control points
                 * y\_ : interpolating points between y coordinate of control points
                 * x : x coordinate of control points
                 * y : y coordinate of control points
    :param fargs: arguments to the function
    :param cyArray: location in the OMFIT tree for the y interpolation array
    :param cxArray: location in the OMFIT tree for the x interpolation array
    :param onMove: function handle to call whenever control point is moved
    :param resize: boolean to whether update axes size when drag point gets to an edge for the figure
    :param show_hints: bool. Show a cornernote with tips for interaction.
    :param ax: Axes. Axes in which to plot.
    All other key word arguments passed to the matplotlib plot function.
    :return:
    """
    epsilon = 5
    def __init__(
        self,
        yArray,
        xArray=None,
        eyArray=None,
        exArray=None,
        editY=True,
        editX=True,
        exDelta=1,
        eyDelta=1,
        sorted=False,
        continuous=True,
        showOriginal=False,
        func=None,
        fargs=[],
        cyArray=None,
        cxArray=None,
        onMove=None,
        resize=False,
        show_hints=True,
        ax=None,
        **kwargs,
    ):
        self.xArray = xArray
        self.yArray = yArray
        self.exArray = exArray
        self.eyArray = eyArray
        self.cxArray = cxArray
        self.cyArray = cyArray
        if ax is not None:
            pyplot.sca(ax)
        # enable any key words for lines
        kw = copy.deepcopy(kwargs)
        for k in ['animated', 'label']:  # these are set by this class
            kw.pop(k, None)
        ckw = copy.deepcopy(kw)
        kw['linestyle'] = ''  # the interpolation line gets the linestyle
        kw.setdefault('marker', kw.pop('m', 'o'))
        if func == 'nearest':
            def func(x, y, motion, *fargs, **fkw):
                if self.cxArray is not None:
                    x_ = eval(self.cxArray)
                else:
                    x_ = np.linspace(min(x), max(x), 1001)
                try:
                    return x_, interp1d(x, y, kind='nearest')(x_), x, y
                except ValueError as _excp:
                    printe(repr(_excp))
        elif func == 'linear':
            def func(x, y, motion, *fargs, **fkw):
                if self.cxArray is not None:
                    x_ = eval(self.cxArray)
                else:
                    x_ = np.linspace(min(x), max(x), 1001)
                try:
                    return x_, interpolate.InterpolatedUnivariateSpline(x, y, k=1, *fargs, **fkw)(x_), x, y
                except ValueError as _excp:
                    printe(repr(_excp))
        elif func == 'spline':
            def func(x, y, motion, *fargs, **fkw):
                if self.cxArray is not None:
                    x_ = eval(self.cxArray)
                else:
                    x_ = np.linspace(min(x), max(x), 1001)
                try:
                    return x_, interpolate.InterpolatedUnivariateSpline(x, y, *fargs, **fkw)(x_), x, y
                except ValueError as _excp:
                    printe(repr(_excp))
        elif func == 'pchip':
            def func(x, y, motion, *fargs, **fkw):
                if self.cxArray is not None:
                    x_ = eval(self.cxArray)
                else:
                    x_ = np.linspace(min(x), max(x), 1001)
                try:
                    return x_, interpolate.PchipInterpolator(x, y, *fargs, **fkw)(x_), x, y
                except ValueError as _excp:
                    printe(repr(_excp))
        elif func == 'circular':
            sorted = False
            def func(x, y, motion, *fargs, **fkw):
                x[-1] = x[0]
                y[-1] = y[0]
                x0 = np.mean(x)
                y0 = np.mean(y)
                xm = max(x) - min(x)
                ym = max(y) - min(y)
                t = np.unwrap(np.arctan2((y - y0) / ym, (x - x0) / xm))
                if t[0] > t[1]:
                    t = -t
                if self.cxArray is not None:
                    x_ = eval(self.cxArray)
                else:
                    x_ = np.linspace(min(x), max(x), 1001)
                if self.cyArray is not None:
                    y_ = eval(self.cyArray)
                else:
                    y_ = np.linspace(min(y), max(y), 1001)
                t_ = np.linspace(min(t), max(t), len(y_))
                tckp = interpolate.splrep(t, x, k=3, per=True)
                xi = interpolate.splev(t_, tckp, ext=0)
                tckp = interpolate.splrep(t, y, k=3, per=True)
                yi = interpolate.splev(t_, tckp, ext=0)
                return xi, yi, x, y
        self.func = func
        self.onMove = onMove
        self.fargs = fargs
        self.sorted = sorted
        self.continuous = continuous
        self.motion = False
        self.resize = resize
        self.editX = editX
        self.editY = editY
        self.exDelta = exDelta
        self.eyDelta = eyDelta
        self.ax = pyplot.gca()
        self.ax.dp = self
        self._ind = None
        self.canvas = pyplot.gcf().canvas
        y = eval(self.yArray)
        if self.xArray is None:
            x = np.array(list(range(y.size)))
            self.editX = False
        else:
            x = eval(self.xArray)
        if self.exArray is None:
            ex = y * 0
        else:
            ex = eval(self.exArray)
        if self.eyArray is None:
            ey = y * 0
        else:
            ey = eval(self.eyArray)
        if self.cxArray is None:
            cx = y * 0
        else:
            cx = eval(self.cxArray)
            self.ax.set_xlabel(self.cxArray)
        if self.cyArray is None:
            cy = y * 0
        else:
            cy = eval(self.cyArray)
            self.ax.set_ylabel(self.cyArray)
        if self.func is not None:
            argspec = inspect.getfullargspec(self.func)
            if len(argspec.args) > 2 and argspec.args[2] == 'motion':
                x_, y_ = self.func(x, y, False, *self.fargs)[:2]
            else:
                # backward compatibility
                x_, y_ = self.func(x, y, *self.fargs)[:2]
            self.cline, *_ = self.ax.plot(x_, y_, animated=True, **ckw)
            if self.cxArray is not None:
                self.cline.set_marker('.')
                self.cline.set_markerfacecolor('k')
        if showOriginal:
            okw = copy.deepcopy(kw)
            okw.pop('color', None)
            okw.pop('c', None)
            self.line_orig, *_ = self.ax.plot(x, y, animated=True, label='Original', color='grey', **okw)
        self.line, *_ = self.ax.plot(x, y, animated=True, label='Control point', **kw)
        if self.onMove is not None:
            self.onMove(self.cline._xy[:, 0], self.cline._xy[:, 1], self.line._xy[:, 0], self.line._xy[:, 1])
        self.orig = {}
        self.orig['xArray'] = copy.deepcopy(x)
        self.orig['yArray'] = copy.deepcopy(y)
        self.orig['cxArray'] = copy.deepcopy(cx)
        self.orig['cyArray'] = copy.deepcopy(cy)
        self.orig['exArray'] = copy.deepcopy(ex)
        self.orig['eyArray'] = copy.deepcopy(ey)
        self.canvas.mpl_connect('draw_event', self.draw_callback)
        self.canvas.mpl_connect('button_press_event', self.button_press_callback)
        self.canvas.mpl_connect('key_press_event', self.key_press_callback)
        self.canvas.mpl_connect('button_release_event', self.button_release_callback)
        self.canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)
        txt = '<u>: undo\n<i>: insert node\n<d>: delete node\n'
        if self.eyArray is not None:
            txt += '<=>: increment ey\n'
            txt += '<->: decrement ey\n'
        if self.exArray is not None:
            txt += '<+>: increment ex\n'
            txt += '<_>: decrement ex\n'
        if show_hints:
            cornernote('', txt.strip())
        self.inhibitFunc = True
[docs]    def draw_callback(self, event=None):
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        if hasattr(self, 'line_orig'):
            self.ax.draw_artist(self.line_orig)
        if self.sorted:
            index = np.argsort(self.line._xy[:, 0])
            self.line._xy[:, 0] = self.line._xy[index, 0]
            self.line._xy[:, 1] = self.line._xy[index, 1]
        if self.func is not None:
            if not self.inhibitFunc or (event is not None and hasattr(event, 'key') and event.key == 'u' and self.cyArray is not None):
                x_ = self.cline._xy[:, 0]
                if self.cxArray is not None:
                    x_ = np.squeeze(eval(self.cxArray))
                y_ = self.cline._xy[:, 1]
                if self.cyArray is not None:
                    y_ = np.squeeze(eval(self.cyArray))
                x = self.line._xy[:, 0]
                y = self.line._xy[:, 1]
                self.ax.draw_artist(self.cline)
            else:
                argspec = inspect.getfullargspec(self.func)
                if len(argspec.args) > 2 and argspec.args[2] == 'motion':
                    try:
                        x_, y_, x, y = self.func(self.line._xy[:, 0], self.line._xy[:, 1], self.motion, *self.fargs)
                    except Exception as _excp:
                        return
                else:
                    # backward compatibility
                    try:
                        x_, y_, x, y = self.func(self.line._xy[:, 0], self.line._xy[:, 1], *self.fargs)
                    except Exception as _excp:
                        return
            self.motion = False
            self.cline.set_data(x_, y_)
            if self.cxArray is not None:
                exec('%s=np.reshape(x_,%s.shape)' % (self.cxArray, self.cxArray))
            if self.cyArray is not None:
                exec('%s=np.reshape(y_,%s.shape)' % (self.cyArray, self.cyArray))
            self.line._xy[:, 0] = np.squeeze(x)
            self.line._xy[:, 1] = np.squeeze(y)
            self.ax.draw_artist(self.cline)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)
        if self.onMove is not None:
            self.onMove(self.cline._xy[:, 0], self.cline._xy[:, 1], self.line._xy[:, 0], self.line._xy[:, 1]) 
[docs]    def get_ind_under_point(self, event):
        xy = self.line._xy
        xyt = self.line.get_transform().transform(xy)
        xt, yt = xyt[:, 0], xyt[:, 1]
        d = np.sqrt((xt - event.x) ** 2 + (yt - event.y) ** 2)
        indseq = np.nonzero(np.equal(d, np.amin(d)))[0]
        ind = indseq[0]
        if d[ind] >= self.epsilon:
            ind = None
        return ind 
[docs]    def key_press_callback(self, event=None):
        if event is not None and not event.inaxes:
            return
        if event is None or event.key == 'u':
            self.line.set_data(copy.deepcopy(self.orig['xArray']), copy.deepcopy(self.orig['yArray']))
            if self.xArray is not None:
                exec(self.xArray + "=copy.deepcopy(self.orig['xArray'])")
            if self.yArray is not None:
                exec(self.yArray + "=copy.deepcopy(self.orig['yArray'])")
            if self.cxArray is not None:
                exec(self.cxArray + "=copy.deepcopy(self.orig['cxArray'])")
            if self.cyArray is not None:
                exec(self.cyArray + "=copy.deepcopy(self.orig['cyArray'])")
            if self.exArray is not None:
                exec(self.exArray + "=copy.deepcopy(self.orig['exArray'])")
            if self.eyArray is not None:
                exec(self.eyArray + "=copy.deepcopy(self.orig['eyArray'])")
        elif event.key == 'd':
            ind = self.get_ind_under_point(event)
            if ind is not None:
                printd('Deleted point #' + str(ind), topic='figure')
                tmp = np.array([tup for i, tup in enumerate(self.line._xy) if i != ind])
                if not self.editX:
                    tmp.T[0] = np.array(list(range(tmp.T[0].size)))
                self.line.set_data(tmp.T[0], tmp.T[1])
                if self.xArray is not None:
                    exec(self.xArray + '=tmp.T[0]')
                if self.yArray is not None:
                    exec(self.yArray + '=tmp.T[1]')
                if self.exArray is not None:
                    tmp = np.squeeze(eval(self.exArray)).tolist()
                    del tmp[ind]
                    tmp = np.array(tmp)
                    exec(self.exArray + '=tmp')
                if self.eyArray is not None:
                    tmp = np.squeeze(eval(self.eyArray)).tolist()
                    del tmp[ind]
                    tmp = np.array(tmp)
                    exec(self.eyArray + '=tmp')
        elif event.key == 'i':
            xysl = self.line.get_transform().transform(self.line._xy)
            try:
                xys = self.cline.get_transform().transform(self.cline._xy)
            except Exception:
                xys = xysl
            p = event.x, event.y
            for i in range(len(xys) - 1):
                s0 = xys[i]
                s1 = xys[i + 1]
                d = point_to_line(p[0], p[1], s0[0], s0[1], s1[0], s1[1])
                new_point = (event.xdata, event.ydata)
                condition = True
                if d <= self.epsilon and condition:
                    tmp = np.array(list(self.line._xy[: i + 1]) + [new_point] + list(self.line._xy[i + 1 :]))
                    if not self.editX:
                        tmp.T[0] = np.array(list(range(tmp.T[0].size)))
                    self.line.set_data(tmp.T[0], tmp.T[1])
                    if self.xArray is not None:
                        exec(self.xArray + '=tmp.T[0]')
                    if self.yArray is not None:
                        exec(self.yArray + '=tmp.T[1]')
                    if self.exArray is not None:
                        ex = np.hstack(
                            (
                                np.squeeze(eval(self.exArray))[: i + 1],
                                (np.squeeze(eval(self.exArray))[i] + np.squeeze(eval(self.exArray))[i + 1]) * 0.5,
                                np.squeeze(eval(self.exArray))[i + 1 :],
                            )
                        )
                        exec(self.exArray + '=ex')
                    if self.eyArray is not None:
                        ey = np.hstack(
                            (
                                np.squeeze(eval(self.eyArray))[: i + 1],
                                (np.squeeze(eval(self.eyArray))[i] + np.squeeze(eval(self.eyArray))[i + 1]) * 0.5,
                                np.squeeze(eval(self.eyArray))[i + 1 :],
                            )
                        )
                        exec(self.eyArray + '=ey')
                    break
        elif event.key == '=' and self.eyArray is not None:
            ind = self.get_ind_under_point(event)
            if ind is not None:
                exec(self.eyArray + '[ind]+=self.eyDelta')
                print('figure: Increased ey point #%d to %3.3f' % (ind, np.squeeze(eval(self.eyArray))[ind]))
        elif event.key == '-' and self.eyArray is not None:
            ind = self.get_ind_under_point(event)
            if ind is not None:
                exec(self.eyArray + '[ind]-=self.eyDelta')
                print('figure: Decreased ey point #%d to %3.3f' % (ind, np.squeeze(eval(self.eyArray))[ind]))
        elif event.key == '+' and self.exArray is not None:
            ind = self.get_ind_under_point(event)
            if ind is not None:
                exec(self.exArray + '[ind]+=self.exDelta')
                print('figure: Increased ex point #%d to %3.3f' % (ind, np.squeeze(eval(self.eyArray))[ind]))
        elif event.key == '_' and self.exArray is not None:
            ind = self.get_ind_under_point(event)
            if ind is not None:
                exec(self.exArray + '[ind]-=self.exDelta')
                print('figure: Decreased ex point #%d to %3.3f' % (ind, np.squeeze(eval(self.eyArray))[ind]))
        elif event.key == 'x':
            self.editX = not self.editX
            printi('drag X:' + str(self.editX))
        elif event.key == 'y':
            self.editY = not self.editY
            printi('drag Y:' + str(self.editY))
        OMFITaux['rootGUI'].event_generate("<<update_treeGUI>>")
        if self.onMove is not None:
            self.onMove(self.cline._xy[:, 0], self.cline._xy[:, 1], self.line._xy[:, 0], self.line._xy[:, 1])
        self.canvas.draw_idle()
        self.canvas.restore_region(self.background)
        self.draw_callback(event) 
[docs]    def motion_notify_callback(self, event):
        if self._ind is None:
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return
        x, y = event.xdata, event.ydata
        tmp = self.line._xy
        if self.editX:
            tmp[self._ind][0] = x
        if self.editY:
            tmp[self._ind][1] = y
        self.line.set_data(tmp.T[0], tmp.T[1])
        if self.editY:
            pyplot.sca(self.ax)
            dy = np.diff(pyplot.ylim())[0] / 20.0
            if self.resize and y > pyplot.ylim()[1] - dy:
                pyplot.ylim(np.array(pyplot.ylim()) + np.array([0, dy]))
            if self.resize and y < pyplot.ylim()[0] + dy:
                pyplot.ylim(np.array(pyplot.ylim()) + np.array([-dy, 0]))
        if self.editX:
            pyplot.sca(self.ax)
            dx = np.diff(pyplot.xlim())[0] / 20.0
            if self.resize and x > pyplot.xlim()[1] - dx:
                pyplot.xlim(np.array(pyplot.xlim()) + np.array([0, dx]))
            if self.resize and x < pyplot.xlim()[0] + dx:
                pyplot.xlim(np.array(pyplot.xlim()) + np.array([-dx, 0]))
        if self.xArray is not None:
            exec(self.xArray + '=tmp.T[0]')
        if self.yArray is not None:
            exec(self.yArray + '=tmp.T[1]')
        self.motion = True
        self.canvas.restore_region(self.background)
        if self.continuous:
            self.draw_callback()
        else:
            self.ax.draw_artist(self.line)
            self.canvas.blit(self.ax.bbox)  
[docs]@_available_to_user_plot
def editProfile(yi, xi=None, n=None, showOriginal=True, func='spline', onMove=None):
    """
    This function opens an interactive figure for convenient editing of profiles via spline
    :param yi: string whose `eval` yields the y data
    :param xi: string whose `eval` yields the x data
    :param n: number of control points
    :param showOriginal: plot original data
    :param func: interpolation function used to interpolate between control points 'linear', 'spline', 'pchip', 'circular'
    :param onMove: function to call when moving of control points occurs
    :param resize: boolean to whether update axes size when drag point gets to an edge for the figure
    :return: DragPoints object
    """
    from omfit_classes.utils_fit import autoknot
    if n is None:
        n = int(len(np.squeeze(eval(yi))) // 5)
    n = max([4, n])
    n = min([n, int(len(np.squeeze(eval(yi))) // 2)])
    if xi is None:
        OMFIT['scratch']['x_' + str(id(yi))] = np.linspace(0, len(np.squeeze(eval(yi))) - 1, len(np.squeeze(eval(yi))))
        xi = "OMFIT['scratch']['x_" + str(id(yi)) + "']"
    if func == 'linear':
        x0, y0 = autoknot(np.squeeze(eval(xi)), np.squeeze(eval(yi)), n, s=1, allKnots=True)
    elif func == 'pchip':
        tmp = interpolate.PchipInterpolator(np.squeeze(eval(xi)), np.squeeze(eval(yi)))
        class pchip(object):
            def __init__(self, x, y):
                pass
            def __call__(self, x0):
                return tmp(x0)
        x0, y0 = autoknot(np.squeeze(eval(xi)), np.squeeze(eval(yi)), n, mindist=1e-6, allKnots=True, userFunc=pchip)
    elif func == 'nearest':
        tmp = interpolate.interp1d(np.squeeze(eval(xi)), np.squeeze(eval(yi)), kind='nearest', bounds_error=False)
        class nearest(object):
            def __init__(self, x, y):
                pass
            def __call__(self, x0):
                return tmp(x0)
        x0, y0 = autoknot(np.squeeze(eval(xi)), np.squeeze(eval(yi)), n, mindist=1e-6, allKnots=True, userFunc=nearest)
    elif func == 'spline':
        x0, y0 = autoknot(np.squeeze(eval(xi)), np.squeeze(eval(yi)), n, s=3, allKnots=True)
    elif func == 'circular':
        t = np.cumsum(np.sqrt(gradient(np.squeeze(eval(xi))) ** 2 + gradient(np.squeeze(eval(yi))) ** 2))
        te = np.linspace(min(t), max(t), 1001)
        xe = interp1e(t, np.squeeze(eval(xi)), kind=3)(te)
        ye = interp1e(t, np.squeeze(eval(yi)), kind=3)(te)
        xe0 = np.mean(xe)
        ye0 = np.mean(ye)
        xem = max(xe) - min(xe)
        yem = max(ye) - min(ye)
        tec = np.cumsum(np.sqrt(gradient((xe - xe0) / xem) ** 2 + gradient((ye - ye0) / yem) ** 2))
        t0, x0 = autoknot(tec, (xe - xe0) / xem, n + 1, mindist=1e-6, s=5, allKnots=True)
        t1, y0 = autoknot(tec, (ye - ye0) / yem, n + 1, mindist=1e-6, s=5, allKnots=True)
        t_ = (t0 + t1) / 2.0
        x0 = interp1e(tec, xe)(t_)
        y0 = interp1e(tec, ye)(t_)
    OMFIT['scratch']['x0_' + str(id(x0))] = x0
    OMFIT['scratch']['y0_' + str(id(y0))] = y0
    if showOriginal:
        pyplot.plot(np.squeeze(eval(xi)), np.squeeze(eval(yi)))
    tmp = DragPoints(
        "OMFIT['scratch']['y0_" + str(id(y0)) + "']",
        "OMFIT['scratch']['x0_" + str(id(x0)) + "']",
        func=func,
        sorted=True,
        cxArray=xi,
        cyArray=yi,
        onMove=onMove,
        resize=False,
    )
    if func == 'circular':
        tmp.ax.set_aspect('equal')
    return tmp 
[docs]@_available_to_user_plot
def cornernote(
    text='', root=None, device=None, shot=None, time=None, ax=None, fontsize='small', clean=True, remove=False, remove_specific=False
):
    """
    Write text at the bottom right corner of a figure
    :param text: text to appear in the bottom left corner
    :param root: * if '' append nothing
                 * if None append shot/time as from `OMFIT['MainSettings']['EXPERIMENT']`
                 * if OMFITmodule append shot/time as from `root['SETTINGS']['EXPERIMENT']`
    :param device: override device string (does not print device at all if empty string)
    :param shot: override shot string (does not print shot at all if empty string)
    :param time: override time string (does not print time at all if empty string)
    :param ax: axis to plot on
    :param fontsize: str or float. Sets font size of the Axes annotate method.
    :param clean: delete existing cornernote(s) from current axes before drawing a new cornernote
    :param remove: delete existing cornernote(s) and return before drawing any new ones
    :param remove_specific: delete existing cornernote(s) from current axes only if text matches the text that would be printed by the current call to cornernote() (such as identical shot, time, etc.)
    :return: Matplotlib annotate object
    """
    if ax is None:
        ax = pyplot.gca()
    if root is None:
        root = eval('OMFIT')
    if root and not isinstance(root, str):
        for location in ['MainSettings', 'SETTINGS']:
            if device is None:
                try:
                    device = root[location]['EXPERIMENT']['device']
                except Exception:
                    pass
            if shot is None:
                try:
                    shot = root[location]['EXPERIMENT']['shot']
                except Exception:
                    pass
            if time is None:
                try:
                    time = root[location]['EXPERIMENT']['time']
                except Exception:
                    pass
        text_ = []
        if device is not None and len(str(device)):
            text_.append(str(device))
        if shot is not None and len(str(shot)):
            text_.append('#' + str(shot))
        if time is not None and len(str(time)):
            text_.append(str(time) + ' ms')
        if len(text):
            text = ' '.join(text_) + ' : ' + text
        else:
            text = ' '.join(text_)
    elif isinstance(root, str) and not len(text):
        text = root
    # cornernote deletion options
    if remove_specific:
        # delete existing cornernote only if the text that would be written by this call matches exactly with existing text
        for textObj in ax.texts:
            if textObj.get_text() == text and textObj.get_label() == 'cornernote' and np.all(textObj._get_xy_display() == [0.98, 0.02]):
                textObj.remove()
        return  # for the specific version only, return without writing a new cornernote (assume you want to delete just one specific note and not recreate it exactly)
    else:
        # if not trying to remove a specific note, then scrub all notes but don't return
        if clean or remove:
            for textObj in ax.texts:
                if textObj.get_label() == 'cornernote':
                    textObj.remove()
        if remove:
            return  # the remove keyword should leave the corner blank (no cornernote visible)
        else:
            pass  # no return after clean because we want the default behavior to be: clear all notes and then write a fresh one
    return ax.annotate(text, xy=(0.98, 0.02), xycoords='figure fraction', fontsize=fontsize, ha='right', va='bottom', label='cornernote') 
# Diagonal line to go with axhline and axvline
[docs]class axdline(matplotlib.pyplot.Line2D):
    """
    Draw a line based on its slope and y-intercept. Additional arguments are
    passed to the <matplotlib.lines.Line2D> constructor.
    From stackoverflow anser by ali_m: http://stackoverflow.com/a/14348481/6605826
    Originally named ABLine2D
    """
    def __init__(self, slope=1, intercept=0, *args, **kwargs):
        # get current axes if user has not specified them
        if not 'axes' in kwargs:
            if 'ax' in kwargs:  # translate ax into axes
                kwargs.update({'axes': kwargs['ax']})
                del kwargs['ax']
            else:
                kwargs.update({'axes': pyplot.gca()})
        ax = kwargs['axes']
        # init the line, add it to the axes
        super().__init__([], [], *args, **kwargs)
        self._slope = slope
        self._intercept = intercept
        ax.add_line(self)
        # cache the renderer, draw the line for the first time
        ax.figure.canvas.draw()
        self._update_lim(None)
        # connect to axis callbacks
        self.axes.callbacks.connect('xlim_changed', self._update_lim)
        self.axes.callbacks.connect('ylim_changed', self._update_lim)
    def _update_lim(self, event):
        """called whenever axis x/y limits change"""
        x = np.array(self.axes.get_xbound())
        y = (self._slope * x) + self._intercept
        self.set_data(x, y)
        try:
            self.axes.draw_artist(self)
        except AttributeError:
            printw('WARNING: could not update axdline artist! Diagonal lines might not display correctly!') 
[docs]def square_subplots(nplots, ncol_max=np.inf, flip=False, sparse_column=True, just_numbers=False, identify=False, fig=None, **kw):
    r"""
    Creates a set of subplots in an approximate square, with a few empty subplots if needed
    :param nplots: int
        Number of subplots desired
    :param ncol_max: int
        Maximum number of columns to allow
    :param flip: bool
        True: Puts row 0 at the bottom, so every plot on the bottom row can accept an X axis label
        False: Normal plot numbering with row 0 at the top. The bottom row may be sparsely populated.
    :param sparse_column: bool
        Controls the arrangement of empty subplots.
        True: the last column is sparse. That is, all the empty plots will be in the last column. There will be at most
            one plot missing from the last row, and potentially several from the last column. The advantage is this
            provides plenty of X axes on the bottom row to accept labels. To get natural numbering of flattened
            subplots, transpose before flattening: axs.T.flatten(), or just use the 1D axsf array that's returned.
        False: the last row is sparse. All the empty plots will be in the last row. The last column will be missing at
            most one plot, but the last row may be missing several. This arrangement goes more smoothly with the
            numbering of axes after flattening.
    :param just_numbers: bool
        Don't create any axes, but instead just return the number of rows, columns, and empty subplots in the array.
    :param identify: bool
        For debugging: write the number (as flattened) and [row, col] coordinates of each subplot on the plot itself.
        These go in the center, in black. In the top left corner in red is the naive flattened count, which will appear
        on empty plots as well to show how wrong it is. In the bottom right corner in blue is the proper flattened count
        based on axsf.
    :param fig: Figure instance [optional]
    :param \**kw: keywords passed to pyplot.subplots when creating axes (like sharex, etc.)
    :return: (axs, axsf) or (nr, nc, on, empty)
        axs: 2d array of Axes instances. It is flipped vertically relative to normal axes output by pyplot.subplots,
            so the 0th row is the bottom. This is so the bottom row will be populated and can receive x axis labels.
        axsf: 1d array of Axes instances, leaving out the empty ones (they might not be in order nicely)
        empty: int: number of empty cells in axs.
            The first empty, if there is one, is [-1, -1] (top right), then [-1, -2] (top row, 2nd from the right), etc.
        nr: int: number of rows
        nc: int: number of columns
        on: 2d bool array: flags indicating which axes should be on (True) and which should be hidden/off (False)
    """
    kw.setdefault('sharex', True)
    kw['squeeze'] = False
    # Find number of columns, rows, and empty plots
    nc = int(nplots**0.5)
    nc = int(np.ceil(np.min([nc, ncol_max])))
    nr = int(np.ceil(nplots / float(nc)))
    empty = nr * nc - nplots
    # Make a rule for hiding the unused subplots
    on = np.ones((nr, nc), bool)
    for j in range(empty):
        if sparse_column:
            on[-(1 + j), -1] = False
        else:
            on[-1, -(1 + j)] = False
    # Map 1D indices to 2D indices (the hidden axes make this harder than it should be)
    offset = 0
    coords = [None] * nplots
    for j in range(nplots):
        jeff = j + offset
        ir = jeff // nc
        ic = np.mod(jeff, nc)
        if sparse_column:
            # Skip empties by incrementing offset
            while not on[ir, ic]:
                offset += 1
                jeff = j + offset
                ir = jeff // nc  # Find which row we're on.
                ic = np.mod(jeff, nc)  # Find which column we're on
        coords[j] = (ir, ic)
    if just_numbers:
        return nr, nc, on, empty
    # Make the plot grid
    if fig is None:
        fig, axs = pyplot.subplots(nr, nc, **kw)  # Default to new figure instead of pyplot.gcf()
        if flip:
            axs = axs[::-1, :]
        for axx in axs.flatten()[~on.flatten()]:
            axx.axis('off')
        axsf = axs.flatten()[on.flatten()]
    else:
        pyplot.figure(num=fig.number)
        axs = np.empty((nr, nc), object)
        axsf = np.empty(nplots, object)
        sharex = None
        sharey = None
        for j in range(nplots):
            ir, ic = coords[j]
            if flip:
                slot = (nr - 1 - ir) * nc + ic + 1
            else:
                slot = ir * nc + ic + 1
            axx = pyplot.subplot(nr, nc, slot, sharex=sharex, sharey=sharey)
            axs[ir, ic] = axx
            axsf[j] = axx
            if kw.get('sharex', False):
                if ((ir < (nr - 1)) and (not flip)) or ((ir > 0) and flip):
                    # This is not the last row, so tradition demands we hide the axis labels
                    for tick in axx.xaxis.get_ticklabels():
                        tick.set_visible(False)
                if sharex is None:
                    sharex = axx  # If sharex=True, catch the first axes as the ones to share with
            if kw.get('sharey', False):
                if ic > 0:
                    # This is not the first column, so tradition demands we hide the axis labels
                    for tick in axx.yaxis.get_ticklabels():
                        tick.set_visible(False)
                if sharey is None:
                    sharey = axx
    # Write numbers on the subplots to identify them
    if identify:
        for j in range(nplots):
            ir, ic = coords[j]
            axx = axs[ir, ic]  # Get a reference to the subplot we're working with
            label = '{} [{}, {}]'.format(j, ir, ic)
            axx.text(0.5, 0.5, label, transform=axx.transAxes, ha='center', va='center')
            axsf[j].text(0.95, 0.05, str(j), transform=axsf[j].transAxes, ha='right', va='bottom', color='blue')
            if on[ir, ic]:
                axx.text(0.75, 0.75, 'ON', transform=axx.transAxes, ha='center', va='center', color='blue')
            else:
                if axx is not None:
                    axx.text(0.75, 0.75, 'OFF', transform=axx.transAxes, ha='center', va='center', color='red')
        for j in range(nc * nr):
            ax1 = axs.flatten()[j]
            axt = axs.T.flatten()[j]
            if ax1 is not None:
                ax1.text(0.05, 0.95, str(j), transform=ax1.transAxes, ha='left', va='top', color='red')
            if axt is not None:
                axt.text(0.95, 0.95, 'T' + str(j), transform=axt.transAxes, ha='right', va='top', color='green')
    return axs, axsf 
try:
    from matplotlib.widgets import TextBox as _TextBox
except ImportError:
    class _TextBox(AxesWidget):
        """
        A GUI neutral text input box.
        For the text box to remain responsive you must keep a reference to it.
        The following attributes are accessible:
          *ax*
            The :class:`matplotlib.axes.Axes` the button renders into.
          *label*
            A :class:`matplotlib.text.Text` instance.
          *color*
            The color of the text box when not hovering.
          *hovercolor*
            The color of the text box when hovering.
        Call :meth:`on_text_change` to be updated whenever the text changes.
        Call :meth:`on_submit` to be updated whenever the user hits enter or
        leaves the text entry field.
        Taken from
        https://github.com/QuadmasterXLII/matplotlib/blob/2734051ec04b73280dad6a27f2003d1697d11195/lib/matplotlib/widgets.py
        which is in a pull request to matplotlib
        https://github.com/matplotlib/matplotlib/pull/5375/files
        smithsp added the set_val method to be merged into that pull request with
        https://github.com/QuadmasterXLII/matplotlib/pull/1
        """
        def __init__(self, ax, label, initial='', color='.95', hovercolor='1', label_pad=0.01):
            """
            Parameters:
            ax : matplotlib.axes.Axes
                The :class:`matplotlib.axes.Axes` instance the button
                will be placed into.
            label : str
                Label for this text box. Accepts string.
            initial : str
                Initial value in the text box
            color : color
                The color of the box
            hovercolor : color
                The color of the box when the mouse is over it
            label_pad : float
                the distance between the label and the right side of the textbox
            """
            from matplotlib.widgets import AxesWidget
            AxesWidget.__init__(self, ax)
            self.DIST_FROM_LEFT = 0.05
            self.params_to_disable = []
            for key in list(rcParams.keys()):
                if u'keymap' in key:
                    self.params_to_disable += [key]
            self.text = initial
            self.label = ax.text(-label_pad, 0.5, label, verticalalignment='center', horizontalalignment='right', transform=ax.transAxes)
            self.text_disp = self._make_text_disp(self.text)
            self.cnt = 0
            self.change_observers = {}
            self.submit_observers = {}
            # If these lines are removed, the cursor won't appear the first
            # time the box is clicked:
            self.ax.set_xlim(0, 1)
            self.ax.set_ylim(0, 1)
            self.cursor_index = 0
            # Because this is initialized, _render_cursor
            # can assume that cursor exists.
            self.cursor = self.ax.vlines(0, 0, 0)
            self.cursor.set_visible(False)
            self.connect_event('button_press_event', self._click)
            self.connect_event('button_release_event', self._release)
            self.connect_event('motion_notify_event', self._motion)
            self.connect_event('key_press_event', self._keypress)
            self.connect_event('resize_event', self._resize)
            ax.set_navigate(False)
            ax.set_axis_bgcolor(color)
            ax.set_xticks([])
            ax.set_yticks([])
            self.color = color
            self.hovercolor = hovercolor
            self._lastcolor = color
            self.capturekeystrokes = False
        def _make_text_disp(self, string):
            return self.ax.text(
                self.DIST_FROM_LEFT, 0.5, string, verticalalignment='center', horizontalalignment='left', transform=self.ax.transAxes
            )
        def _rendercursor(self):
            # this is a hack to figure out where the cursor should go.
            # we draw the text up to where the cursor should go, measure
            # and save its dimensions, draw the real text, then put the cursor
            # at the saved dimensions
            widthtext = self.text[: self.cursor_index]
            no_text = False
            if widthtext == "" or widthtext == " " or widthtext == "  ":
                no_text = widthtext == ""
                widthtext = ","
            wt_disp = self._make_text_disp(widthtext)
            self.ax.figure.canvas.draw()
            bb = wt_disp.get_window_extent()
            inv = self.ax.transData.inverted()
            bb = inv.transform(bb)
            wt_disp.set_visible(False)
            if no_text:
                bb[1, 0] = bb[0, 0]
            # hack done
            self.cursor.set_visible(False)
            self.cursor = self.ax.vlines(bb[1, 0], bb[0, 1], bb[1, 1])
            self.ax.figure.canvas.draw()
        def _notify_submit_observers(self):
            for cid, func in self.submit_observers.items():
                func(self.text)
        def _release(self, event):
            if self.ignore(event):
                return
            if event.canvas.mouse_grabber != self.ax:
                return
            event.canvas.release_mouse(self.ax)
        def _keypress(self, event):
            if self.ignore(event):
                return
            if self.capturekeystrokes:
                key = event.key
                if len(key) == 1:
                    self.text = self.text[: self.cursor_index] + key + self.text[self.cursor_index :]
                    self.cursor_index += 1
                elif key == "right":
                    if self.cursor_index != len(self.text):
                        self.cursor_index += 1
                elif key == "left":
                    if self.cursor_index != 0:
                        self.cursor_index -= 1
                elif key == "home":
                    self.cursor_index = 0
                elif key == "end":
                    self.cursor_index = len(self.text)
                elif key == "backspace":
                    if self.cursor_index != 0:
                        self.text = self.text[: self.cursor_index - 1] + self.text[self.cursor_index :]
                        self.cursor_index -= 1
                elif key == "delete":
                    if self.cursor_index != len(self.text):
                        self.text = self.text[: self.cursor_index] + self.text[self.cursor_index + 1 :]
                self.text_disp.set_text(self.text)
                self._rendercursor()
                self._notify_change_observers()
                if key == "enter":
                    self._notify_submit_observers()
        def set_val(self, val):
            newval = str(val)
            if self.text == newval:
                return
            self.text = newval
            self.text_disp.set_text(self.text)
            self._rendercursor()
            self._notify_change_observers()
            self._notify_submit_observers()
        def _notify_change_observers(self):
            for cid, func in self.change_observers.items():
                func(self.text)
        def begin_typing(self, x):
            self.capturekeystrokes = True
            # disable command keys so that the user can type without
            # command keys causing figure to be saved, etc
            self.reset_params = {}
            for key in self.params_to_disable:
                self.reset_params[key] = rcParams[key]
                rcParams[key] = []
        def stop_typing(self):
            notifysubmit = False
            # because _notify_submit_users might throw an error in the
            # user's code, we only want to call it once we've already done
            # our cleanup.
            if self.capturekeystrokes:
                # since the user is no longer typing,
                # reactivate the standard command keys
                for key in self.params_to_disable:
                    rcParams[key] = self.reset_params[key]
                notifysubmit = True
            self.capturekeystrokes = False
            self.cursor.set_visible(False)
            self.ax.figure.canvas.draw()
            if notifysubmit:
                self._notify_submit_observers()
        def position_cursor(self, x):
            # now, we have to figure out where the cursor goes.
            # approximate it based on assuming all characters the same length
            if len(self.text) == 0:
                self.cursor_index = 0
            else:
                bb = self.text_disp.get_window_extent()
                trans = self.ax.transData
                inv = self.ax.transData.inverted()
                bb = trans.transform(inv.transform(bb))
                text_start = bb[0, 0]
                text_end = bb[1, 0]
                ratio = (x - text_start) / (text_end - text_start)
                if ratio < 0:
                    ratio = 0
                if ratio > 1:
                    ratio = 1
                self.cursor_index = int(len(self.text) * ratio)
            self._rendercursor()
        def _click(self, event):
            if self.ignore(event):
                return
            if event.inaxes != self.ax:
                self.stop_typing()
                return
            if not self.eventson:
                return
            if event.canvas.mouse_grabber != self.ax:
                event.canvas.grab_mouse(self.ax)
            if not (self.capturekeystrokes):
                self.begin_typing(event.x)
            self.position_cursor(event.x)
        def _resize(self, event):
            self.stop_typing()
        def _motion(self, event):
            if self.ignore(event):
                return
            if event.inaxes == self.ax:
                c = self.hovercolor
            else:
                c = self.color
            if c != self._lastcolor:
                self.ax.set_axis_bgcolor(c)
                self._lastcolor = c
                if self.drawon:
                    self.ax.figure.canvas.draw()
        def on_text_change(self, func):
            """
            When the text changes, call this *func* with event.
            A connection id is returned which can be used to disconnect.
            """
            cid = self.cnt
            self.change_observers[cid] = func
            self.cnt += 1
            return cid
        def on_submit(self, func):
            """
            When the user hits enter or leaves the submision box, call this
            *func* with event.
            A connection id is returned which can be used to disconnect.
            """
            cid = self.cnt
            self.submit_observers[cid] = func
            self.cnt += 1
            return cid
        def disconnect(self, cid):
            """remove the observer with connection id *cid*"""
            try:
                del self.observers[cid]
            except KeyError:
                pass
    matplotlib.widgets.TextBox = _TextBox
if __name__ == '__main__':
    from pylab import figure, imshow
    import uncertainties
    x = np.linspace(0, 2 * np.pi, 20)
    y = np.linspace(0, 3 * np.pi, 20)
    s = np.sin(x)
    z = []
    for i in range(10):
        z.append(np.sin(x) + i)
    z = np.array(z)
    xx, yy = np.meshgrid(x, y)
    zz = np.sin(xx) * np.cos(yy)
    xerr = abs(0.1 * np.random.random(x.shape) * x)
    serr = abs(0.1 * np.random.random(x.shape) * s)
    X = uncertainties.unumpy.uarray(x, xerr)
    S = uncertainties.unumpy.uarray(s, serr)
    figure('sin')
    pyplot.plot(x, s)
    figure('multi-sin')
    plotc(x, z.T)
    ax = pyplot.gca()
    XKCDify(ax, xaxis_loc=0)
    draw()
    map_HBS_to_RGB(np.array([1]), np.array([2]), np.array([0]))
    paths = contourPaths(x, y, zz, [-0.75, 0, 0.75])
    fig = figure('pcolormesh')
    pyplot.figure(num=fig.number)
    ax = pyplot.subplot(121)
    pcolor2(xx, yy, zz)
    ax2 = pyplot.subplot(122, sharey=ax, sharex=ax)
    im = blur_image(zz, 3)
    ax2.imshow(im, extent=(min(x), max(x), min(y), max(y)), origin='lower')
    set_fontsize(fontsize='+10')
    plot_equality_lims(ax=ax2)
    autofmt_sharey()
    hardcopy('/tmp/%s/test_hardcopy.pdf' % os.environ['USER'])
    fig = figure('uerrorbar')
    pyplot.figure(num=fig.number)
    ax = pyplot.subplot(211)
    uerrorbar(X, S, ax=ax)
    ax2 = pyplot.subplot(212, sharex=ax)
    uerrorbar(X, S, ax=ax2)
    plot_equality(X, S, ax=ax2)
    autofmt_sharex()
    fig = figure('uband')
    uband(X, S)