Last updated on Jun 03, 2026.

Source code for netsse.tools.viz

# -*- coding: utf-8 -*-
"""
**Data visualisation** functions for NetSSE.

.. dropdown:: Copyright (C) 2023-2026 Technical University of Denmark, R.E.G. Mounet
    :color: primary
    :icon: law

    *This code is part of the NetSSE software.*

    NetSSE is free software: you can redistribute it and/or modify it under
    the terms of the GNU General Public License as published by the Free
    Software Foundation, either version 3 of the License, or (at your
    option) any later version.

    NetSSE is distributed in the hope that it will be useful, but WITHOUT
    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    for more details.

    You should have received a copy of the GNU General Public License along
    with this program.  If not, see https://www.gnu.org/licenses/.

    To credit the author, users are encouraged to use the reference below:

    .. code-block:: text

        Mounet, R. E. G., & Nielsen, U. D. NetSSE: An open-source Python package
        for network-based sea state estimation from ships, buoys, and other
        observation platforms (version 2.2). Technical University of Denmark,
        GitLab. February 2026. https://doi.org/10.11583/DTU.26379811.

*Last updated on 17-03-2026 by R.E.G. Mounet*

"""

import os
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.colors import LogNorm, Normalize
import matplotlib.cm as cm
import cartopy.crs as ccrs
import numpy as np
from netsse.tools.misc_func import re_range


[docs] def fig_settings(font="serif", fsz=12, lw=1, dpi=200, format="pdf", path="../Figures/"): """Sets the default figure settings for NetSSE visualisations. The settings are applied globally to all figures created after calling this function. Parameters ---------- font : {'serif', 'sans-serif'}, optional Font family to use for the figures: ``'sans-serif'`` uses Helvetica font, while ``'serif'`` uses Palatino font. Default is ``'serif'``. fsz : int, optional Font size for text in the figures. Default is ``12``. lw : float, optional Line width for the figures. Default is ``1.5``. dpi : int, optional Dots per inch (dpi) for the figures. Default is ``100``. format : str, optional Format for saving the figures. Default is ``"pdf"``. path : str, optional Path for saving the figures. Default is ``"../Figures/"``. Returns ------- fig_path : str Path for saving the figures. colors : dict Dictionary of colorblind-friendly colors for plotting. props : dict Dictionary of properties for text and legend boxes. 
Example ------- >>> fig_path, colors, props = fig_settings() """ plt.close("all") # Settings for publication-quality figures: match font: case "sans-serif": # for Helvetica and other sans-serif fonts use rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"]}) case "serif": # for Palatino and other serif fonts use: rc("font", **{"family": "serif", "serif": ["Palatino"]}) rc("text", usetex=True) # Basic settings for all figures: plt.rcParams["font.size"] = fsz plt.rcParams["lines.linewidth"] = lw plt.rcParams["figure.dpi"] = dpi plt.rcParams["savefig.format"] = format # Create the directory for saving the figures if it does not exist: if not path.endswith("/"): path += "/" if not os.path.exists(path): os.makedirs(path) plt.rcParams["savefig.directory"] = path fig_path = path # Colorblind-friendly colors: colors = { "blue": "#377eb8", "orange": "#ff7f00", "green": "#4daf4a", "pink": "#f781bf", "brown": "#a65628", "purple": "#984ea3", "gray": "#999999", "red": "#e41a1c", "yellow": "#dede00", } # Settings for text and legend boxes: props = dict(facecolor="whitesmoke", alpha=0.8) plt.rcParams["legend.frameon"] = True plt.rcParams["legend.framealpha"] = 0.8 plt.rcParams["legend.fancybox"] = False plt.rcParams["legend.edgecolor"] = "black" plt.rcParams["legend.facecolor"] = "whitesmoke" plt.rcParams["legend.fontsize"] = fsz * 0.9 plt.rcParams["legend.loc"] = "upper right" # Settings for ticks: plt.rcParams["xtick.direction"] = "in" plt.rcParams["ytick.direction"] = "in" plt.rcParams["xtick.top"] = True plt.rcParams["ytick.right"] = True # Settings for grid: plt.rcParams["grid.linestyle"] = ":" plt.rcParams["grid.color"] = "black" plt.rcParams["grid.alpha"] = 0.5 plt.rcParams["grid.linewidth"] = 0.5 return fig_path, colors, props
[docs] def plot_dirwavespec( spec, dirs, freqs, unit_dirs="deg", unit_freqs="Hz", levels=None, cmap="turbo", vmin=None, vmax=None, freq_max=None, show_dirlabels=True, show_freqlabels=True, rscale="linear", ax=None, ): """Plots the input directional wave spectrum in a polar diagram. The plot shows the power spectral density (PSD) distribution over wave frequencies and directions. Parameters ---------- spec : array_like Directional wave spectrum to be plotted. This must be a 2-D array of shape either `(Ndirs,Nfreqs)` or `(Nfreqs,Ndirs)`. dirs : array_like of shape (`Ndirs`,) Directions of the spectrum. freqs : array_like of shape (`Nfreqs`,) Frequencies of the spectrum. unit_dirs : {'deg','rad'}, optional Unit of the directions, either ``'deg'`` for degrees or ``'rad'`` for radians. Default is ``'deg'``. unit_freqs : {'Hz','rad/s'}, optional Unit of the frequencies, either ``'Hz'`` for Hertz or ``'rad/s'`` for radians per second. Default is ``'Hz'``. levels : array_like, optional Contour levels to use for the plot. If ``None``, levels are automatically determined according to the formula below: :math:`vmin*(vmax/vmin)^{i/29}` for :math:`i=0,1,...,29`. Default is ``None``. cmap : str, optional Colormap to use for the plot. Default is ``'turbo'``. .. note:: Other colormap options include ``'viridis'``, ``'plasma'``, ``'inferno'``, ``'magma'``. Visit the `Matplotlib documentation <https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_ for an overview of the options. vmin : float, optional Minimum value for the colormap. Default is ``None``, which means to use the value :math:`10^{-1/2}`. vmax : float, optional Maximum value for the colormap. Default is ``None``, which means to use the maximum value in ``spec``. .. warning:: ``vmin`` and ``vmax`` are used to set the colormap limits. These options will override the lower and upper limits of ``levels``, if the latter option is used. freq_max : float, optional Maximum frequency to plot. 
Default is ``None``, which means to use the maximum frequency in the input data. show_dirlabels : {True,False,'compass'}, optional Whether to show direction labels. If ``True``, shows labels in degrees. If ``'compass'``, shows compass directions. If ``False``, no labels are shown. Default is ``True``. show_freqlabels : {True,False,'all'}, optional Whether to show frequency labels. If ``True``, shows only the maximum frequency label. If ``'all'``, shows all labels. If ``False``, no labels are shown. Default is ``True``. rscale : {'linear', 'log'}, optional Scale for the radial axis (frequencies). Default is ``'linear'``. ax : matplotlib.axes, optional Axes object to plot on. If ``None``, a new figure and axes are created. Default is ``None``. Returns ------- ax : matplotlib.axes._subplots.PolarAxesSubplot The axes object with the plot. Example ------- >>> # Generate some example data >>> freqs = np.linspace(0.05, 0.5, 50) >>> dirs = np.linspace(0, 360, 36) >>> spec = np.random.rand(36, 50) >>> # Plot the directional wave spectrum >>> plot_dirwavespec(spec, dirs, freqs) """ # Identify the dimensions of the spectra: Nfreqs = np.size(freqs) Ndirs = np.size(dirs) ax_freqs = np.where(np.array(spec.shape) == Nfreqs)[0][0] ax_dirs = np.where(np.array(spec.shape) == Ndirs)[0][0] spec_plot = np.transpose(spec, (ax_dirs, ax_freqs)) # Headings in radians: if unit_dirs == "deg": dirs_rad = np.radians(dirs) else: dirs_rad = dirs dirs_rad = re_range(dirs_rad.flatten(), 0, unit="rad") # Ensure that the headings are cyclic: if np.round(dirs_rad[-1,], 4) == np.round(dirs_rad[0,], 4): dirs_rad[-1,] += 2 * np.pi else: dirs_rad = np.append(dirs_rad, dirs_rad[0,] + 2 * np.pi) Ndirs += 1 spec_plot = np.append(spec_plot, spec_plot[0, :].reshape(1, Nfreqs), axis=0) # Make a meshgrid in angles and frequencies: Headings, Freq = np.meshgrid(dirs_rad, freqs) if rscale == "log": # Shift the bullseye up by 0.3 decades to avoid it being at the center of the plot: bullseye = 0.3 min_freq = 
np.min(Freq) if min_freq <= 0: raise ValueError( "Frequencies must be positive for logarithmic scale. Found minimum frequency: %s" % min_freq ) min10 = np.log10(np.min(Freq)) max10 = np.log10(np.max(Freq)) Freq = np.log10(Freq) - min10 + bullseye # Select the levels for the contour plot: spec_plot[np.isnan(spec_plot)] = -1 if vmin is None: min_val = 10 ** (-0.5) else: min_val = vmin if vmax is None: max_val = np.amax(spec_plot) else: max_val = vmax if levels is None: levels = 10 ** np.linspace(np.log10(min_val), np.log10(max_val), 30) else: if vmin is not None: levels[0] = vmin if vmax is not None: levels[-1] = vmax norm = cm.colors.LogNorm() spec_plot = np.ma.masked_where(spec_plot <= 0, spec_plot) # Plot the 2-D wave spectrum: if ax is None: fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(4, 4)) else: fig = ax.get_figure() rows, cols, start, _ = ax.get_subplotspec().get_geometry() ax.remove() ax = fig.add_subplot(rows, cols, start + 1, projection="polar") ax.set_theta_direction(-1) ax.set_theta_zero_location("N") ax.contour( Headings, Freq, spec_plot.T, levels=levels, norm=norm, linewidths=1, cmap=cmap ) # Set the labels and grid: ax.set_xticks(np.arange(np.pi / 6, 13 * np.pi / 6, np.pi / 6)) match show_dirlabels: case True: ax.set_xticklabels( [ "", "", r"90$^\circ$", "", "", r"180$^\circ$", "", "", r"270$^\circ$", "", "", r"0$^\circ$", ] ) case "compass": ax.set_xticklabels(["", "", "E", "", "", "S", "", "", "W", "", "", "N"]) case False: ax.set_xticklabels([""] * 12) if freq_max is None: freq_max = np.max(np.round(freqs, decimals=1)) match rscale: case "linear": dfreq = 0.1 * (unit_freqs == "Hz") + 0.5 * (unit_freqs == "rad/s") rticks = np.arange(0, freq_max, dfreq) + dfreq ax.set_rticks(rticks) ax.set_rlim(0, freq_max) case "log": # dfreq = np.log10(dfreq) freq_max10 = np.log10(freq_max) max10 = np.min([freq_max10, max10]) rticks = np.arange(np.floor(min10), max10) ax.set_rticks(rticks - min10 + bullseye) ax.set_rlim(0, max10 - min10 + 
bullseye) match show_freqlabels: case True: match rscale: case "linear": ax.set_yticklabels( ["" for _ in rticks[:-1]] + ["%.1f [%s]" % (rticks[-1], unit_freqs)] ) case "log": ax.set_yticklabels( ["" for _ in rticks[:-1]] + ["1e%d [%s]" % (rticks[-1], unit_freqs)] ) case "all": match rscale: case "linear": freq_ticklabels = ["%.1f" % x for x in rticks] freq_ticklabels[-1] += " [%s]" % (unit_freqs) ax.set_yticklabels(freq_ticklabels) case "log": freq_ticklabels = ["1e%d" % x for x in rticks] freq_ticklabels[-1] += " [%s]" % (unit_freqs) ax.set_yticklabels(freq_ticklabels) case False: ax.set_yticklabels(["" for _ in rticks]) ax.grid(linestyle="--", color="black", linewidth=0.5) return ax
def plot_bathymetry(depths, shp_dict, ax, cmap="Blues_r"):
    """Plots the bathymetry map for a given area.

    The function plots the bathymetry map for a given area using the shapefiles
    contained in the dictionary `shp_dict`. The shapefiles are sorted by depth,
    from the surface to the bottom.

    Parameters
    ----------
    depths : array_like
        Depths in the shapefiles for the specified area (list or array). The
        values must be convertible to integers; each value is also used as the
        key into ``shp_dict``.
    shp_dict : dict
        Dictionary containing the shapefiles, keyed by the entries of
        ``depths``.
    ax : matplotlib.axes
        Axes object to plot on.
    cmap : str, optional
        Colormap to use for the plot. Default is ``'Blues_r'``.

        .. note::
           Other colormap options include ``'plasma'``, ``'inferno'``,
           ``'magma'``, ``'viridis'``. Visit the `Matplotlib documentation
           <https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_
           for an overview of the options.

    Returns
    -------
    ax : matplotlib.axes._subplots.AxesSubplot
        The axes object with the plot.
    colormap : matplotlib.colors.ListedColormap
        The colormap used for the plot.

    Raises
    ------
    ValueError
        If ``depths`` is empty.

    See Also
    --------
    netsse.model.bathymetry.load_bathymetry : Retrieve and read bathymetry shapefiles.

    Example
    -------
    .. code-block:: python

        import matplotlib.pyplot as plt
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        from netsse.model.bathymetry import load_bathymetry
        from netsse.tools.viz import plot_bathymetry

        depths, shp_dict = load_bathymetry(lonmin=-5, lonmax=15, latmin=35, latmax=45)

        fig, ax = plt.subplots(subplot_kw={'projection':ccrs.Mercator(central_longitude=5,min_latitude=35,max_latitude=45)},figsize=(6,6))
        ax.set_extent([-5, 15, 35, 45], crs=ccrs.PlateCarree())
        ax, colormap = plot_bathymetry(depths, shp_dict, ax)
        ax.add_feature(cfeature.LAND,edgecolor='black',facecolor='gainsboro',alpha=0.5)
    """
    # Accept lists as well as arrays (a list has no .astype):
    depths_int = np.asarray(depths).astype(int)
    if depths_int.size == 0:
        raise ValueError("depths must contain at least one depth level.")

    # Construct a discrete colormap with colors corresponding to each depth
    N = len(depths_int)
    nudge = 0.01  # shift bin edge slightly to include data
    boundaries = [min(depths_int)] + sorted(depths_int + nudge)  # low to high
    norm = matplotlib.colors.BoundaryNorm(boundaries, N)
    colormap = matplotlib.colormaps[cmap].resampled(N)
    colors_depths = colormap(norm(depths_int))

    # Iterate and plot bathymetry feature for each depth level. The original
    # `depths` entries (not the integer copies) are the shp_dict keys.
    for depth, color in zip(depths, colors_depths):
        ax.add_geometries(
            shp_dict[depth].geometries(), crs=ccrs.PlateCarree(), color=color
        )

    # NOTE: calling ax.set_rasterized(True) here would convert the vector
    # bathymetry to raster (saves a lot of disk space) while leaving labels as
    # vectors.
    return ax, colormap


def plot_sea_state_scatter(
    y_data,
    x_data,
    y_bin=0.25,
    x_bin=0.20,
    y_range=None,
    x_range=None,
    y_label=None,
    x_label=None,
    cmap="viridis",
    norm="log",  # "log" or "linear"
    annotate=True,
    min_count_annot=1,  # minimum count to annotate a cell
    ax=None,
    colorbar=True,
    grid=True,
):
    """
    Create a sea state scatter diagram (2-D histogram heatmap) of counts for two wave parameters.

    Parameters
    ----------
    y_data : array-like
        Values to be displayed on the vertical axis. Typically, this is some
        wave height parameter [in metres].
    x_data : array-like
        Values to be displayed on the horizontal axis. Typically, this is some
        wave period parameter [in seconds]. Must have the same shape as
        ``y_data``.
    y_bin : float, optional
        Bin size for ``y_data`` (must be positive). Default is 0.25 [m].
    x_bin : float, optional
        Bin size for ``x_data`` (must be positive). Default is 0.20 [s].
    y_range : tuple(float, float), optional
        (min, max) explicit range for ``y_data``. If ``None``, inferred from
        data and aligned to multiples of ``y_bin``.
    x_range : tuple(float, float), optional
        (min, max) explicit range for ``x_data``. If ``None``, inferred from
        data and aligned to multiples of ``x_bin``.
    y_label : str, optional
        Y-axis label for the scatter diagram. If ``None``, no `y`-axis label is
        set.
    x_label : str, optional
        X-axis label for the scatter diagram. If ``None``, no `x`-axis label is
        set.
    cmap : str, optional
        Matplotlib colormap name.
    norm : {"log", "linear"}, optional
        Color normalization. ``"log"`` helps when counts vary widely.
    annotate : bool, optional
        If ``True``, annotate counts inside cells.
    min_count_annot : int, optional
        Minimum count to annotate a cell (avoids clutter).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If ``None``, a new figure and axes are created and
        ``plt.show()`` is called before returning.
    colorbar : bool, optional
        If ``True``, add a colorbar.
    grid : bool, optional
        If ``True``, show grid lines.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    H : 2-D ndarray
        Counts matrix of shape ``(len(y_edges)-1, len(x_edges)-1)``, where rows
        correspond to ``y_data`` bins and columns to ``x_data`` bins.
    y_edges : 1-D ndarray
        Bin edges for ``y_data``.
    x_edges : 1-D ndarray
        Bin edges for ``x_data``.

    Raises
    ------
    ValueError
        If ``norm`` is not ``"log"`` or ``"linear"``, if ``y_data`` and
        ``x_data`` differ in shape, if a bin size is not positive, or if no
        finite data pair remains after filtering.

    Notes
    -----
    Pairs where either value is NaN or non-finite are ignored.

    Example
    -------
    >>> fig, ax, counts, hs_edges, tp_edges = plot_sea_state_scatter(
    ...     Hs, Tp,
    ...     y_bin=0.25,
    ...     x_bin=0.20,
    ...     norm="log",         # try "linear" if you prefer
    ...     annotate=False,     # set True to print counts in cells
    ... )
    """
    # Validate options before any figure is created:
    if norm not in ("log", "linear"):
        raise ValueError('norm must be "log" or "linear".')
    # Written as "not (... > 0)" so that NaN bin sizes are rejected too.
    if not (y_bin > 0 and x_bin > 0):
        raise ValueError("y_bin and x_bin must be positive.")

    y = np.asarray(y_data, dtype=float)
    x = np.asarray(x_data, dtype=float)
    if y.shape != x.shape:
        raise ValueError(
            "y_data and x_data must have the same shape, got %s and %s."
            % (y.shape, x.shape)
        )

    # Remove NaNs and non-finite values
    mask = np.isfinite(y) & np.isfinite(x)
    y = y[mask]
    x = x[mask]
    if y.size == 0:
        raise ValueError("No finite data in y_data/x_data after filtering.")

    # Define bin edges (inferred ranges are aligned to bin multiples)
    if y_range is None:
        y_min = np.floor(np.min(y) / y_bin) * y_bin
        y_max = np.ceil(np.max(y) / y_bin) * y_bin
    else:
        y_min, y_max = y_range
    if x_range is None:
        x_min = np.floor(np.min(x) / x_bin) * x_bin
        x_max = np.ceil(np.max(x) / x_bin) * x_bin
    else:
        x_min, x_max = x_range

    # Ensure at least one bin
    if y_max <= y_min:
        y_max = y_min + y_bin
    if x_max <= x_min:
        x_max = x_min + x_bin

    y_edges = np.arange(y_min, y_max + y_bin, y_bin)
    x_edges = np.arange(x_min, x_max + x_bin, x_bin)

    # 2D histogram: rows = y bins, cols = x bins (returned edges equal ours)
    H, _, _ = np.histogram2d(y, x, bins=[y_edges, x_edges])
    H = H.astype(int)

    # Prepare axes
    created_fig = False
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
        created_fig = True
    else:
        fig = ax.figure

    # Choose normalization
    if norm == "log":
        # Use LogNorm but avoid zero issues by setting vmin to 1 if any positive counts exist
        positive = H[H > 0]
        if positive.size > 0:
            norm_obj = LogNorm(vmin=max(1, positive.min()), vmax=positive.max())
        else:
            norm_obj = Normalize(vmin=0, vmax=1)
    else:
        h_max = H.max()
        norm_obj = Normalize(vmin=0, vmax=h_max if h_max > 0 else 1)

    # pcolormesh expects bin edges; X = x_data bins (cols), Y = y_data bins (rows)
    mesh = ax.pcolormesh(x_edges, y_edges, H, cmap=cmap, norm=norm_obj, shading="auto")

    # Labels and formatting
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if grid:
        ax.grid(True, which="both", linestyle=":", alpha=0.5)
    ax.set_ylim(y_edges[0], y_edges[-1])
    ax.set_xlim(x_edges[0], x_edges[-1])

    # Colorbar
    if colorbar:
        cbar = fig.colorbar(mesh, ax=ax, pad=0.02)
        cbar.set_label("Count")

    # Optional annotations (center each cell), visited in row-major order
    if annotate:
        y_centers = 0.5 * (y_edges[:-1] + y_edges[1:])
        x_centers = 0.5 * (x_edges[:-1] + x_edges[1:])
        for i, j in zip(*np.nonzero(H >= min_count_annot)):
            ax.text(
                x_centers[j],
                y_centers[i],
                str(int(H[i, j])),
                ha="center",
                va="center",
                fontsize=8,
                color="white",
                # explicit empty list overrides any rcParams["path.effects"]
                path_effects=[],
            )

    if created_fig:
        plt.show()

    return fig, ax, H, y_edges, x_edges