"""Plotting routines for classified spectra and classification maps (``mcalf.visualisation.classifications``)."""

import glob

import numpy as np
from matplotlib import pyplot as plt, colors, cm
from matplotlib.gridspec import GridSpec


# Public plotting API exported by ``from ... import *``.
__all__ = [
    'plot_classifications',
    'plot_class_map',
    'plot_averaged_class_map',
]


def plot_classifications(class_map, spectra, labels, extent=(0, 200, 0, 200), xticks=(0, 15, 3),
                         yticks=(0, 15, 3), xscale=0.725 * 0.097, yscale=0.725 * 0.097, output=None,
                         figsize=None, dpi=600, fontfamily=None):
    """Plot the spectra separated into their classifications along with an example classified map.

    Exactly 5 classifications (labelled 0 to 4) are supported: the first five panels each show up
    to 20 spectra of one classification, and the sixth panel shows `class_map`.

    Parameters
    ----------
    class_map : ndarray, ndim=2
        Two-dimensional array of classifications. Points equal to -1 are drawn as white
        background. The array passed in is not modified.
    spectra : ndarray, ndim=2
        Two-dimensional array with dimensions [spectra, wavelengths].
    labels : ndarray, ndim=1, length of `spectra`
        List of classifications for each spectrum in `spectra`.
    extent : 4-tuple, optional, default = (0, 200, 0, 200)
        Axis limits (x_min, x_max, y_min, y_max) of the classification map panel in data
        coordinates. The x and y ticks are offset by `extent[0]` and `extent[2]` respectively.
    xticks : 3-tuple, optional, default = (0, 15, 3)
        The start, stop and step for the x-axis ticks in Mm.
    yticks : 3-tuple, optional, default = (0, 15, 3)
        The start, stop and step for the y-axis ticks in Mm.
    xscale : float, optional, default = 0.725 * 0.097
        Scaling factor between x-axis data coordinate steps and 1 Mm. Mm = data * xscale.
    yscale : float, optional, default = 0.725 * 0.097
        Scaling factor between y-axis data coordinate steps and 1 Mm. Mm = data * yscale.
    output : str, optional, default = None
        If present, the filename to save the plot as.
    figsize : 2-tuple, optional, default = None
        Size of the figure.
    dpi : int, optional, default = 600
        The number of dots per inch. For controlling the quality of the outputted figure.
    fontfamily : str, optional, default = None
        If provided, this family string will be added to the 'font' rc params group.
    """
    if fontfamily is not None:
        plt.rc('font', family=fontfamily)

    fig = plt.figure(constrained_layout=True, figsize=figsize, dpi=dpi)
    gs = GridSpec(2, 3, figure=fig, wspace=0)
    # Row-major order: five spectra panels, then the classification map in the bottom-right slot.
    axes = [fig.add_subplot(gs[row, col]) for row in range(2) for col in range(3)]

    # Optimised for readers with color blindness
    cmap = colors.ListedColormap(['#0072b2', '#56b4e9', '#009e73', '#e69f00', '#d55e00'])

    labels = np.asarray(labels)
    for classification in range(5):
        ax = axes[classification]
        class_colour = cmap(classification)
        # Only plot the first 20 spectra of each classification
        for j in np.flatnonzero(labels == classification)[:20]:
            ax.plot(spectra[j], linewidth=0.5, color=class_colour)
        ax.set_xticks([])  # No wavelengths plotted
        ax.set_yticks([0, 1])  # Only show that intensity is scaled [0, 1]
        ax.margins(0)

    ax = axes[-1]  # Classification map is placed in the last axis
    cmap.set_bad(color='white')  # Background for masked points
    # np.array (not np.asarray) so that a float input is copied rather than modified in place.
    class_map_float = np.array(class_map, dtype=float)
    class_map_float[class_map_float == -1] = np.nan
    classif_img = ax.imshow(class_map_float[::-1], cmap=cmap, vmin=-0.5, vmax=4.5,
                            interpolation='nearest')
    ax.set_xlim(*extent[:2])
    ax.set_ylim(*extent[2:])

    # Tick positions in data coordinates are Mm / scale, shifted by the extent origin.
    xticks_Mm = np.arange(*xticks)
    ax.set_xticks(xticks_Mm / xscale + extent[0])
    ax.set_xticklabels(xticks_Mm)
    ax.set_xlabel('Distance (Mm)')

    yticks_Mm = np.arange(*yticks)
    ax.set_yticks(yticks_Mm / yscale + extent[2])
    ax.set_yticklabels(yticks_Mm)
    ax.set_ylabel('Distance (Mm)')

    cbar = fig.colorbar(classif_img, ax=axes, ticks=[0, 1, 2, 3, 4], orientation='horizontal',
                        shrink=1, pad=0)
    cbar.ax.set_xticklabels(['0\nabsorption', '1', '2', '3', '4\nemission'])

    # Save before showing: show() may block until the window is closed, or close the figure.
    if isinstance(output, str):
        fig.savefig(output, bbox_inches='tight', dpi=dpi)
    plt.show()
def _class_percentages(values, classes):
    """Return the percentage of entries in `values` equal to each entry of `classes`, in order.

    The denominator is the total number of entries in `values`, so entries matching none of
    `classes` (e.g. -1 for unclassified points) still count towards the total.
    """
    values = np.asarray(values).ravel()
    counts = np.array([np.count_nonzero(values == cls) for cls in classes], dtype=float)
    return counts / values.size * 100  # Convert to percentage


def plot_class_map(class_map, overall_classes=None, classes=None, time_index=None, cadence=None,
                   xticks=(0, 15, 2), yticks=(0, 15, 2), xscale=0.725 * 0.097, yscale=0.725 * 0.097,
                   output=None, file_prefix='classmap_plot_', file_ext='png',
                   figsize=(5 * 3 / 2.5, 3 * 3 / 2.5), dpi=600, fontfamily=None, cache=False):
    """Plot an image of the classifications at a particular time along with bar charts of the classifications.

    Parameters
    ----------
    class_map : ndarray, ndim=2 or 3
        Two-dimensional array of classifications. If three dimensions are given, the first
        dimension is assumed to represent the time. Points equal to -1 (when -1 is not one of
        `classes`) are drawn with the grey "bad" colour; they still count towards the total
        when computing percentages.
    overall_classes : ndarray or bool, optional
        The percentage of spectra that belong to each classification in the overall dataset.
        If omitted, these will be calculated using all of the classifications given in
        `class_map`. If true is given, these will also be calculated in the same way and
        returned without any plotting done. (This returned array can then be used to speed up
        later calls of this function.)
    classes : ndarray, optional, default = ndarray of [0, 1, 2, 3, 4]
        Array of all the possible classifications in `class_map`. At most 5 colours are
        available. The image colour scale spans min(classes) - 0.5 to max(classes) + 0.5 with
        one colour per class, so the classes should be consecutive integers.
    time_index : int, optional, default = 0
        The index of the time dimension of `class_map`, required if `class_map` is 3D.
        Also used for plotting the time.
    cadence : float, units = seconds, optional, default = None
        If given, the time index will be multiplied by this value and converted into a time in
        minutes on the plot. Otherwise, the `time_index` will be plotted without units.
    xticks : 3-tuple, optional, default = (0, 15, 2)
        The start, stop and step for the x-axis ticks in Mm.
    yticks : 3-tuple, optional, default = (0, 15, 2)
        The start, stop and step for the y-axis ticks in Mm.
    xscale : float, optional, default = 0.725 * 0.097
        Scaling factor between x-axis data coordinate steps and 1 Mm. Mm = data * xscale.
    yscale : float, optional, default = 0.725 * 0.097
        Scaling factor between y-axis data coordinate steps and 1 Mm. Mm = data * yscale.
    output : str or bool, optional, default = None
        If present, the filename to save the plot as. If omitted, the plot will not be saved.
        If true, the filename will be generated using the `time_index` along with the
        `file_prefix` and `file_ext`.
    file_prefix : str, optional, default = 'classmap_plot_'
        The prefix to use in the filename when `output` is true.
    file_ext : str, optional, default = 'png'
        The file extension (without the dot) to use when `output` is true.
    figsize : 2-tuple, optional, default = (5 * 3 / 2.5, 3 * 3 / 2.5)
        Size of the figure.
    dpi : int, optional, default = 600
        The number of dots per inch. For controlling the quality of the outputted figure.
    fontfamily : str, optional, default = None
        If provided, this family string will be added to the 'font' rc params group.
    cache : bool, optional, default = False
        If true, the plot will not be regenerated if the output filename already exists.

    Returns
    -------
    overall_classes : ndarray
        If `overall_classes` is initially true, their calculated values will be returned.
        If `cache` is true and the output file already exists, 0 is returned; otherwise None.

    Raises
    ------
    ValueError
        If `class_map` is 3D and no `time_index` is given, or if it is neither 2D nor 3D.
    """
    import os  # Only needed for the cache check

    class_map = np.asarray(class_map)
    if classes is None:
        classes = np.arange(5, dtype=int)

    if overall_classes is None or isinstance(overall_classes, bool):
        just_return_overall_classes = bool(overall_classes)
        overall_classes = _class_percentages(class_map, classes)
        if just_return_overall_classes:
            return overall_classes

    if class_map.ndim == 3:
        if time_index is None:
            raise ValueError('A `time_index` must be specified as multiple time dimensions are in `class_map`.')
        class_map = class_map[time_index]
    elif class_map.ndim != 2:
        raise ValueError('`class_map` must have either 2 or 3 dimensions, got %s' % class_map.ndim)

    if time_index is None:
        time_index = 0

    if isinstance(output, bool) and output:
        output = '{}{:05d}.{}'.format(file_prefix, time_index, file_ext)

    # A plain existence check: glob would treat characters such as '[' or '*' as a pattern.
    if cache and isinstance(output, str) and os.path.exists(output):
        return 0

    if fontfamily is not None:
        plt.rc('font', family=fontfamily)

    time = time_index if cadence is None else time_index * cadence / 60
    time_unit = '' if cadence is None else ' min'
    time_prefix = 't = ' if cadence is None else ''

    n_classes = len(classes)
    cmap_colors = np.array(['#0072b2', '#56b4e9', '#009e73', '#e69f00', '#d55e00'])[:n_classes]
    cmap = colors.ListedColormap(cmap_colors)
    cmap.set_bad(color='#999999', alpha=1)
    # Colour the bars by position in `classes`: calling cmap with the class values themselves only
    # picks the intended colours when the classes are exactly 0..N-1.
    bar_colors = cmap(np.arange(n_classes))

    extent = (0, class_map.shape[1], 0, class_map.shape[0])

    # Draw unclassified points (-1) with the "bad" colour rather than the lowest class colour.
    # A float copy is used so that the caller's array is not modified.
    class_map_image = np.array(class_map, dtype=float)
    if -1 not in classes:
        class_map_image[class_map_image == -1] = np.nan

    fig = plt.figure(figsize=figsize, dpi=dpi, constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[:, :2])
    ax2 = fig.add_subplot(gs[0, 2])
    ax3 = fig.add_subplot(gs[1, 2])

    im = ax1.imshow(class_map_image, cmap=cmap, vmin=min(classes) - 0.5, vmax=max(classes) + 0.5,
                    extent=extent, interpolation='nearest')
    fig.colorbar(im, ax=ax1, ticks=classes, orientation='vertical',
                 label='absorption' + ' ' * 41 + 'emission')
    ax1.set_title('Classifications at {}{:.2f}{}'.format(time_prefix, time, time_unit))

    # Tick positions in data coordinates are Mm / scale, shifted by the extent origin.
    xticks_Mm = np.arange(*xticks)
    ax1.set_xticks(xticks_Mm / xscale + extent[0])
    ax1.set_xticklabels(xticks_Mm)
    ax1.set_xlabel('Distance (Mm)')

    yticks_Mm = np.arange(*yticks)
    ax1.set_yticks(yticks_Mm / yscale + extent[2])
    ax1.set_yticklabels(yticks_Mm)
    ax1.set_ylabel('Distance (Mm)')

    current_classes = _class_percentages(class_map, classes)

    ax2.bar(classes, current_classes, color=bar_colors)
    ax2.set_title('Current Classes (%)')

    ax3.bar(classes, overall_classes, color=bar_colors)
    ax3.set_title('Overall Classes (%)')

    for ax in (ax2, ax3):
        ax.set_xlabel(None)
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax.set_yscale('log')
        ax.set_ylim(0.01, 100)

    # Save before showing: show() may block until the window is closed, or close the figure.
    if isinstance(output, str):
        fig.savefig(output, bbox_inches='tight', dpi=dpi)
    plt.show()
def plot_averaged_class_map(class_map, classes=None, continuous=False, xticks=(0, 15, 2),
                            yticks=(0, 15, 2), xscale=0.725 * 0.097, yscale=0.725 * 0.097,
                            output=None, figsize=None, dpi=600, fontfamily=None):
    """Plot an image of the time averaged classifications.

    Parameters
    ----------
    class_map : ndarray, ndim=3
        Three-dimensional array of classifications, with the times given in the first dimension.
        Points equal to -1 (when -1 is not one of `classes`) are treated as unclassified and
        excluded from the time average. Pixels that are never classified are drawn with the grey
        "bad" colour.
    classes : ndarray, optional, default = ndarray of [0, 1, 2, 3, 4]
        Array of all the possible classifications in `class_map`.
    continuous : bool, optional, default = False
        Whether to plot with a continuous color scale or round to the nearest classification.
    xticks : 3-tuple, optional, default = (0, 15, 2)
        The start, stop and step for the x-axis ticks in Mm.
    yticks : 3-tuple, optional, default = (0, 15, 2)
        The start, stop and step for the y-axis ticks in Mm.
    xscale : float, optional, default = 0.725 * 0.097
        Scaling factor between x-axis data coordinate steps and 1 Mm. Mm = data * xscale.
    yscale : float, optional, default = 0.725 * 0.097
        Scaling factor between y-axis data coordinate steps and 1 Mm. Mm = data * yscale.
    output : str, optional, default = None
        If present, the filename to save the plot as. If omitted, the plot will not be saved.
    figsize : 2-tuple, optional, default = None
        Size of the figure.
    dpi : int, optional, default = 600
        The number of dots per inch. For controlling the quality of the outputted figure.
    fontfamily : str, optional, default = None
        If provided, this family string will be added to the 'font' rc params group.

    Raises
    ------
    ValueError
        If `class_map` does not have exactly 3 dimensions.
    """
    import copy
    import warnings

    class_map = np.asarray(class_map)
    if classes is None:
        classes = np.arange(5, dtype=int)
    if class_map.ndim != 3:
        raise ValueError('`class_map` must have 3 dimensions, got %s' % class_map.ndim)

    # Exclude unclassified points (-1) from the time average instead of averaging them in as -1.
    class_map = np.array(class_map, dtype=float)
    if -1 not in classes:
        class_map[class_map == -1] = np.nan
    with warnings.catch_warnings():
        # Pixels with no classification at any time give "Mean of empty slice" -> NaN (bad colour).
        warnings.simplefilter('ignore', category=RuntimeWarning)
        class_map = np.nanmean(class_map, axis=0)

    if fontfamily is not None:
        plt.rc('font', family=fontfamily)

    if continuous:
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; copy so that set_bad below does
        # not modify the globally registered colormap.
        cmap = copy.copy(plt.get_cmap('binary_r'))
        vmin = min(classes)
        vmax = max(classes)
    else:
        cmap_colors = np.array(['#0072b2', '#56b4e9', '#009e73', '#e69f00', '#d55e00'])[:len(classes)]
        cmap = colors.ListedColormap(cmap_colors)
        vmin = min(classes) - 0.5
        vmax = max(classes) + 0.5
    cmap.set_bad(color='#999999', alpha=1)

    extent = (0, class_map.shape[1], 0, class_map.shape[0])

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True)
    im = ax.imshow(class_map, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent,
                   interpolation='nearest')
    fig.colorbar(im, ax=ax, ticks=classes, orientation='vertical',
                 label='absorption' + ' ' * 47 + 'emission')

    # Tick positions in data coordinates are Mm / scale, shifted by the extent origin.
    xticks_Mm = np.arange(*xticks)
    ax.set_xticks(xticks_Mm / xscale + extent[0])
    ax.set_xticklabels(xticks_Mm)
    ax.set_xlabel('Distance (Mm)')

    yticks_Mm = np.arange(*yticks)
    ax.set_yticks(yticks_Mm / yscale + extent[2])
    ax.set_yticklabels(yticks_Mm)
    ax.set_ylabel('Distance (Mm)')

    # Save before showing: show() may block until the window is closed, or close the figure.
    if isinstance(output, str):
        fig.savefig(output, bbox_inches='tight', dpi=dpi)
    plt.show()