Source code for pypianoroll.visualization

"""Visualization tools.

Functions
---------

- plot_multitrack
- plot_pianoroll
- plot_track

"""
from typing import TYPE_CHECKING, List, Optional, Sequence

import matplotlib
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.patches import Patch
from numpy import ndarray
from pretty_midi import (
    note_number_to_drum_name,
    note_number_to_name,
    program_to_instrument_class,
    program_to_instrument_name,
)

if TYPE_CHECKING:
    from .multitrack import Multitrack
    from .track import Track

# Public plotting API of this module (the names `from ... import *` exports);
# underscore-prefixed helpers below are private.
__all__ = ["plot_multitrack", "plot_pianoroll", "plot_track"]


def plot_pianoroll(
    ax: Axes,
    pianoroll: ndarray,
    is_drum: bool = False,
    resolution: Optional[int] = None,
    downbeats: Optional[Sequence[int]] = None,
    preset: str = "full",
    cmap: str = "Blues",
    xtick: str = "auto",
    ytick: str = "octave",
    xticklabel: bool = True,
    yticklabel: str = "auto",
    tick_loc: Sequence[str] = ("bottom", "left"),
    tick_direction: str = "in",
    label: str = "both",
    grid_axis: str = "both",
    grid_linestyle: str = ":",
    grid_linewidth: float = 0.5,
    **kwargs,
):
    """
    Plot a piano roll.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to plot the piano roll on.
    pianoroll : ndarray, shape=(?, 128), (?, 128, 3) or (?, 128, 4)
        Piano roll to plot. For a 3D piano-roll array, the last axis can
        be either RGB or RGBA.
    is_drum : bool
        Whether it is a percussion track. Defaults to False.
    resolution : int, optional
        Time steps per quarter note. Required if `xtick` is 'beat'.
    downbeats : sequence of int, optional
        Time steps at which a vertical downbeat line is drawn (i.e., the
        first time step of each bar).
    preset : {'full', 'frame', 'plain'}
        Preset theme. For 'full' preset, ticks, grid and labels are on.
        For 'frame' preset, tick marks and tick labels are off. For
        'plain' preset, the x- and y-axis are both off. Defaults to
        'full'.
    cmap : str or :class:`matplotlib.colors.Colormap`
        Colormap. Will be passed to :meth:`matplotlib.axes.Axes.imshow`.
        Only effective when `pianoroll` is 2D. Defaults to 'Blues'.
    xtick : {'auto', 'beat', 'step', 'off'}
        Tick format for the x-axis. For 'auto' mode, set to 'beat' if
        `resolution` is given, otherwise set to 'step'. Defaults to
        'auto'.
    ytick : {'octave', 'pitch', 'off'}
        Tick format for the y-axis. For 'octave' mode, put a tick on every
        C. For 'pitch' mode, put a tick on every pitch ('step' is accepted
        as a legacy alias of 'pitch'). Defaults to 'octave'.
    xticklabel : bool
        Whether to add tick labels along the x-axis. Defaults to True.
    yticklabel : {'auto', 'name', 'number', 'off'}
        Tick label format for the y-axis. For 'name' mode, use pitch name
        as tick labels (drum names for percussion tracks in 'pitch'
        mode). For 'number' mode, use pitch number. For 'auto' mode, set
        to 'name' if `ytick` is 'octave' and 'number' otherwise. Defaults
        to 'auto'.
    tick_loc : sequence of {'bottom', 'top', 'left', 'right'}
        Tick locations. Defaults to `('bottom', 'left')`.
    tick_direction : {'in', 'out', 'inout'}
        Tick direction. Defaults to 'in'.
    label : {'x', 'y', 'both', 'off'}
        Whether to add labels to x- and y-axes. Defaults to 'both'.
    grid_axis : {'x', 'y', 'both', 'off'}
        Whether to add grids to the x- and y-axes. Defaults to 'both'.
    grid_linestyle : str
        Grid line style. Will be passed to
        :meth:`matplotlib.axes.Axes.grid`.
    grid_linewidth : float
        Grid line width. Will be passed to
        :meth:`matplotlib.axes.Axes.grid`.
    **kwargs
        Keyword arguments to be passed to
        :meth:`matplotlib.axes.Axes.imshow`.

    Returns
    -------
    :class:`matplotlib.image.AxesImage`
        The image created by :meth:`matplotlib.axes.Axes.imshow`.

    Raises
    ------
    ValueError
        If `pianoroll` is not 2D or 3D, if an option is not one of its
        allowed values, or if `xtick` is 'beat' while `resolution` is None
        (unless `preset` is 'frame').

    """
    # Validate every option up front so that a bad argument raises before
    # anything has been drawn on `ax`.
    if pianoroll.ndim not in (2, 3):
        raise ValueError("`pianoroll` must be a 2D or 3D numpy array")

    if xtick == "auto":
        xtick = "beat" if resolution is not None else "step"
    elif xtick not in ("beat", "step", "off"):
        raise ValueError(
            "`xtick` must be one of 'auto', 'beat', 'step' or 'off', not "
            f"{xtick}."
        )

    if ytick == "step":
        # Legacy spelling: earlier versions only recognized 'step' here.
        ytick = "pitch"
    elif ytick not in ("octave", "pitch", "off"):
        raise ValueError(
            f"`ytick` must be one of 'octave', 'pitch' or 'off', not {ytick}."
        )

    if yticklabel == "auto":
        yticklabel = "name" if ytick == "octave" else "number"
    elif yticklabel not in ("name", "number", "off"):
        raise ValueError(
            "`yticklabel` must be one of 'auto', 'name', 'number' or 'off', "
            f"not {yticklabel}."
        )

    if preset not in ("full", "frame", "plain"):
        raise ValueError(
            f"`preset` must be one of 'full', 'frame' or 'plain', not {preset}"
        )

    if label not in ("x", "y", "both", "off"):
        raise ValueError(
            f"`label` must be one of 'x', 'y', 'both' or 'off', not {label}."
        )

    if grid_axis not in ("x", "y", "both", "off"):
        raise ValueError(
            "`grid_axis` must be one of 'x', 'y', 'both' or 'off', not "
            f"{grid_axis}."
        )

    if xtick == "beat" and preset != "frame" and resolution is None:
        raise ValueError(
            "`resolution` must not be None when `xtick` is 'beat'."
        )

    # Plot the piano roll with time on the x-axis and pitch on the y-axis
    if pianoroll.ndim == 2:
        transposed = pianoroll.T
    else:
        transposed = pianoroll.transpose(1, 0, 2)
    img = ax.imshow(
        transposed,
        cmap=cmap,
        aspect="auto",
        vmin=0,
        vmax=1 if pianoroll.dtype == np.bool_ else 127,
        origin="lower",
        interpolation="none",
        **kwargs,
    )

    # Apply the preset theme
    if preset == "full":
        ax.tick_params(
            direction=tick_direction,
            bottom=("bottom" in tick_loc),
            top=("top" in tick_loc),
            left=("left" in tick_loc),
            right=("right" in tick_loc),
            labelbottom=xticklabel,
            labelleft=(yticklabel != "off"),
            labeltop=False,
            labelright=False,
        )
    elif preset == "frame":
        ax.tick_params(
            direction=tick_direction,
            bottom=False,
            top=False,
            left=False,
            right=False,
            labelbottom=False,
            labeltop=False,
            labelleft=False,
            labelright=False,
        )
    else:  # preset == "plain"
        ax.axis("off")

    # Format x-axis: unlabeled major ticks sit on beat boundaries (pixel
    # edges are at step - 0.5); the beat numbers are drawn as invisible
    # minor ticks centred in each beat.
    if xtick == "beat" and preset != "frame":
        n_beats = pianoroll.shape[0] // resolution
        ax.set_xticks(resolution * np.arange(n_beats) - 0.5)
        ax.set_xticklabels("")
        ax.set_xticks(
            resolution * (np.arange(n_beats) + 0.5) - 0.5, minor=True
        )
        ax.set_xticklabels(np.arange(1, n_beats + 1), minor=True)
        ax.tick_params(axis="x", which="minor", width=0)

    # Format y-axis. Octave labels use the same naming convention as the
    # per-pitch labels (pretty_midi's, pitch 0 -> 'C-1'), so a given pitch
    # is labeled the same way in both modes.
    if ytick == "octave":
        ax.set_yticks(np.arange(0, 128, 12))
        if yticklabel == "name":
            ax.set_yticklabels(
                [note_number_to_name(pitch) for pitch in range(0, 128, 12)]
            )
    elif ytick == "pitch":
        ax.set_yticks(np.arange(0, 128))
        if yticklabel == "name":
            to_name = (
                note_number_to_drum_name if is_drum else note_number_to_name
            )
            ax.set_yticklabels([to_name(pitch) for pitch in range(128)])

    # Format axis labels
    if label in ("x", "both"):
        if xtick == "step" or not xticklabel:
            ax.set_xlabel("time (step)")
        else:
            ax.set_xlabel("time (beat)")
    if label in ("y", "both"):
        ax.set_ylabel("key name" if is_drum else "pitch")

    # Plot the grid
    if grid_axis != "off":
        ax.grid(
            axis=grid_axis,
            color="k",
            linestyle=grid_linestyle,
            linewidth=grid_linewidth,
        )

    # Plot downbeat boundaries
    if downbeats is not None:
        for downbeat in downbeats:
            ax.axvline(x=downbeat, color="k", linewidth=1)

    return img
def plot_track(track: "Track", ax: Optional[Axes] = None, **kwargs) -> Axes:
    """
    Plot a track.

    Parameters
    ----------
    track : :class:`pypianoroll.Track`
        Track to plot. Its `pianoroll` and `is_drum` attributes are
        forwarded to :func:`pypianoroll.plot_pianoroll`.
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes to plot the piano roll on. Defaults to ``plt.gca()``.
    **kwargs
        Keyword arguments to pass to :func:`pypianoroll.plot_pianoroll`.

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        The (possibly newly obtained) axes the track was drawn on.

    """
    target = plt.gca() if ax is None else ax
    plot_pianoroll(target, track.pianoroll, track.is_drum, **kwargs)
    return target
def _get_track_label(track_label, track=None): """Return corresponding track labels.""" if track_label == "name": return track.name if track_label == "program": return program_to_instrument_name(track.program) if track_label == "family": return program_to_instrument_class(track.program) return track_label def _add_tracklabel(ax, track_label, track=None): """Add a track label to an axis.""" if not ax.get_ylabel(): return ax.set_ylabel( f"{_get_track_label(track_label, track)}\n\n{ax.get_ylabel()}" )
def plot_multitrack(
    multitrack: "Multitrack",
    axs: Optional[Sequence[Axes]] = None,
    mode: str = "separate",
    track_label: str = "name",
    preset: str = "full",
    cmaps: Optional[Sequence[str]] = None,
    xtick: str = "auto",
    ytick: str = "octave",
    xticklabel: bool = True,
    yticklabel: str = "auto",
    tick_loc: Sequence[str] = ("bottom", "left"),
    tick_direction: str = "in",
    label: str = "both",
    grid_axis: str = "both",
    grid_linestyle: str = ":",
    grid_linewidth: float = 0.5,
    **kwargs,
) -> List[Axes]:
    """
    Plot the multitrack.

    Parameters
    ----------
    multitrack : :class:`pypianoroll.Multitrack`
        Multitrack to plot.
    axs : sequence of :class:`matplotlib.axes.Axes`, optional
        Axes to plot the tracks on. If None, a new figure is created with
        one axes per track ('separate'), a single axes ('blended') or two
        axes ('hybrid').
    mode : {'separate', 'blended', 'hybrid'}
        Plotting strategy for visualizing multiple tracks. For 'separate'
        mode, plot each track separately. For 'blended', blend and plot
        the pianoroll as a colored image. For 'hybrid' mode, drum tracks
        are blended into a 'Drums' track and all other tracks are blended
        into an 'Others' track. Defaults to 'separate'.
    track_label : {'name', 'program', 'family', 'off'}
        Track label format. For 'name', use the track name. For
        'program', use the instrument name of the track's program. For
        'family', use the instrument class of the track's program. For
        'off', add no track label. When `mode` is 'hybrid', all options
        other than 'off' label the two tracks with 'Drums' and 'Others'.
        In 'separate' and 'hybrid' modes the label is prepended to the
        y-axis label (and is only added when the axes has one); in
        'blended' mode it is shown in a legend. Defaults to 'name'.
    preset : {'full', 'frame', 'plain'}
        Preset theme to use. See :func:`pypianoroll.plot_pianoroll`.
        Defaults to 'full'.
    cmaps : sequence of str or :class:`matplotlib.colors.Colormap`
        Colormaps. In 'separate' mode, track i uses
        ``cmaps[i % len(cmaps)]``; defaults to `('Blues', 'Oranges',
        'Greens', 'Reds', 'Purples', 'Greys')`. In 'blended' mode, the
        track colors are sampled evenly from ``cmaps[0]``; defaults to
        `('hsv',)`. In 'hybrid' mode, the drum and other tracks use the
        first and second colormaps; defaults to `('Blues', 'Greens')`.
    xtick, ytick, xticklabel, yticklabel, tick_loc, tick_direction, \
label, grid_axis, grid_linestyle, grid_linewidth
        Passed unchanged to :func:`pypianoroll.plot_pianoroll` for every
        axes.
    **kwargs
        Keyword arguments to pass to :func:`pypianoroll.plot_pianoroll`.

    Returns
    -------
    list of :class:`matplotlib.axes.Axes`
        (Created) list of Axes objects.

    Raises
    ------
    RuntimeError
        If the multitrack has no tracks.
    ValueError
        If `mode` or `track_label` is not one of its allowed values.

    """
    if not multitrack.tracks:
        raise RuntimeError("There is no track to plot.")
    if track_label not in ("name", "program", "family", "off"):
        raise ValueError(
            "`track_label` must be one of 'name', 'program', 'family' or "
            f"'off', not {track_label}."
        )
    if mode not in ("separate", "blended", "hybrid"):
        raise ValueError(
            "`mode` must be one of 'separate', 'blended' or 'hybrid', not "
            f"{mode}."
        )

    if axs is not None and not isinstance(axs, list):
        axs = list(axs)

    # Set default color maps
    if cmaps is None:
        if mode == "separate":
            cmaps = ("Blues", "Oranges", "Greens", "Reds", "Purples", "Greys")
        elif mode == "blended":
            cmaps = ("hsv",)
        else:
            cmaps = ("Blues", "Greens")

    n_tracks = len(multitrack.tracks)
    downbeats = multitrack.get_downbeat_steps()

    # Options forwarded unchanged to every plot_pianoroll call
    shared_options = dict(
        resolution=multitrack.resolution,
        downbeats=downbeats,
        preset=preset,
        xtick=xtick,
        ytick=ytick,
        xticklabel=xticklabel,
        yticklabel=yticklabel,
        tick_loc=tick_loc,
        tick_direction=tick_direction,
        label=label,
        grid_axis=grid_axis,
        grid_linestyle=grid_linestyle,
        grid_linewidth=grid_linewidth,
    )

    if mode == "separate":
        if axs is None:
            if n_tracks > 1:
                fig, axs_ = plt.subplots(n_tracks, sharex=True)
                fig.subplots_adjust(hspace=0)
                axs = axs_.tolist()
            else:
                _, ax = plt.subplots()
                axs = [ax]
        for idx, track in enumerate(multitrack.tracks):
            plot_pianoroll(
                ax=axs[idx],
                pianoroll=track.pianoroll,
                is_drum=track.is_drum,
                cmap=cmaps[idx % len(cmaps)],
                **shared_options,
                **kwargs,
            )
            if track_label != "off":
                _add_tracklabel(axs[idx], track_label, track)

    elif mode == "blended":
        if axs is None:
            _, ax = plt.subplots()
            axs = [ax]
        is_all_drum = all(track.is_drum for track in multitrack.tracks)

        # `stacked` has the track axis first: its trailing dimensions
        # (stacked.shape[1:]) are the per-track piano-roll shape.
        stacked = multitrack.stack()
        colormap = plt.get_cmap(cmaps[0])
        # n evenly spaced samples in [0, 1); np.arange(0, 1, 1 / n) can
        # yield n + 1 samples through float rounding.
        colormatrix = colormap(np.arange(n_tracks) / n_tracks)[:, :3]
        # Move the track axis last so each row holds one pixel's values
        # across all tracks, then mix the track colors per pixel.
        per_pixel = np.moveaxis(stacked, 0, -1).reshape(-1, n_tracks)
        # NOTE(review): values are clipped to [0, 1], so non-binarized
        # velocities mostly saturate; presumably meant for binary rolls.
        recolored = np.clip(np.matmul(per_pixel, colormatrix), 0, 1)
        blended = recolored.reshape(stacked.shape[1:] + (3,))
        plot_pianoroll(
            ax=axs[0],
            pianoroll=blended,
            is_drum=is_all_drum,
            **shared_options,
            **kwargs,
        )
        if track_label != "off":
            patches = [
                Patch(
                    color=colormatrix[idx],
                    label=_get_track_label(track_label, track),
                )
                for idx, track in enumerate(multitrack.tracks)
            ]
            # Attach the legend to the plotted axes, not whatever axes
            # pyplot currently considers active.
            axs[0].legend(handles=patches)

    else:  # mode == "hybrid"
        # NOTE(review): if either group below is empty, the result depends
        # on how Multitrack.blend handles zero tracks -- confirm.
        drums = multitrack.copy()
        drums.tracks = [track for track in multitrack.tracks if track.is_drum]
        merged_drums = drums.blend()
        others = multitrack.copy()
        others.tracks = [
            track for track in multitrack.tracks if not track.is_drum
        ]
        merged_others = others.blend()
        if axs is None:
            fig, axs_ = plt.subplots(2, sharex=True, sharey=True)
            # Only restyle a figure created here; `fig` does not exist
            # when the caller supplies `axs`.
            fig.subplots_adjust(hspace=0)
            axs = axs_.tolist()
        plot_pianoroll(
            ax=axs[0],
            pianoroll=merged_drums,
            is_drum=True,
            cmap=cmaps[0],
            **shared_options,
            **kwargs,
        )
        plot_pianoroll(
            ax=axs[1],
            pianoroll=merged_others,
            is_drum=False,
            cmap=cmaps[1 % len(cmaps)],
            **shared_options,
            **kwargs,
        )
        if track_label != "off":
            _add_tracklabel(axs[0], "Drums")
            _add_tracklabel(axs[1], "Others")

    return axs  # type: ignore