# -*- coding: utf-8 -*-
"""
**Data visualisation** functions for NetSSE.
.. dropdown:: Copyright (C) 2023-2026 Technical University of Denmark, R.E.G. Mounet
:color: primary
:icon: law
*This code is part of the NetSSE software.*
NetSSE is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your
option) any later version.
NetSSE is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
for more details.
You should have received a copy of the GNU General Public License along
with this program. If not, see https://www.gnu.org/licenses/.
To credit the author, users are encouraged to use the reference below:
.. code-block:: text
Mounet, R. E. G., & Nielsen, U. D. NetSSE: An open-source Python package
for network-based sea state estimation from ships, buoys, and other
observation platforms (version 2.2). Technical University of Denmark,
GitLab. February 2026. https://doi.org/10.11583/DTU.26379811.
*Last updated on 17-03-2026 by R.E.G. Mounet*
"""
import os
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.colors import LogNorm, Normalize
import matplotlib.cm as cm
import cartopy.crs as ccrs
import numpy as np
from netsse.tools.misc_func import re_range
[docs]
def fig_settings(font="serif", fsz=12, lw=1, dpi=200, format="pdf", path="../Figures/"):
"""Sets the default figure settings for NetSSE visualisations.
The settings are applied globally to all figures created after calling this function.
Parameters
----------
font : {'serif', 'sans-serif'}, optional
Font family to use for the figures: ``'sans-serif'`` uses Helvetica font, while ``'serif'`` uses Palatino font. Default is ``'serif'``.
fsz : int, optional
Font size for text in the figures. Default is ``12``.
lw : float, optional
Line width for the figures. Default is ``1.5``.
dpi : int, optional
Dots per inch (dpi) for the figures. Default is ``100``.
format : str, optional
Format for saving the figures. Default is ``"pdf"``.
path : str, optional
Path for saving the figures. Default is ``"../Figures/"``.
Returns
-------
fig_path : str
Path for saving the figures.
colors : dict
Dictionary of colorblind-friendly colors for plotting.
props : dict
Dictionary of properties for text and legend boxes.
Example
-------
>>> fig_path, colors, props = fig_settings()
"""
plt.close("all")
# Settings for publication-quality figures:
match font:
case "sans-serif": # for Helvetica and other sans-serif fonts use
rc("font", **{"family": "sans-serif", "sans-serif": ["Helvetica"]})
case "serif": # for Palatino and other serif fonts use:
rc("font", **{"family": "serif", "serif": ["Palatino"]})
rc("text", usetex=True)
# Basic settings for all figures:
plt.rcParams["font.size"] = fsz
plt.rcParams["lines.linewidth"] = lw
plt.rcParams["figure.dpi"] = dpi
plt.rcParams["savefig.format"] = format
# Create the directory for saving the figures if it does not exist:
if not path.endswith("/"):
path += "/"
if not os.path.exists(path):
os.makedirs(path)
plt.rcParams["savefig.directory"] = path
fig_path = path
# Colorblind-friendly colors:
colors = {
"blue": "#377eb8",
"orange": "#ff7f00",
"green": "#4daf4a",
"pink": "#f781bf",
"brown": "#a65628",
"purple": "#984ea3",
"gray": "#999999",
"red": "#e41a1c",
"yellow": "#dede00",
}
# Settings for text and legend boxes:
props = dict(facecolor="whitesmoke", alpha=0.8)
plt.rcParams["legend.frameon"] = True
plt.rcParams["legend.framealpha"] = 0.8
plt.rcParams["legend.fancybox"] = False
plt.rcParams["legend.edgecolor"] = "black"
plt.rcParams["legend.facecolor"] = "whitesmoke"
plt.rcParams["legend.fontsize"] = fsz * 0.9
plt.rcParams["legend.loc"] = "upper right"
# Settings for ticks:
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.top"] = True
plt.rcParams["ytick.right"] = True
# Settings for grid:
plt.rcParams["grid.linestyle"] = ":"
plt.rcParams["grid.color"] = "black"
plt.rcParams["grid.alpha"] = 0.5
plt.rcParams["grid.linewidth"] = 0.5
return fig_path, colors, props
[docs]
def plot_dirwavespec(
spec,
dirs,
freqs,
unit_dirs="deg",
unit_freqs="Hz",
levels=None,
cmap="turbo",
vmin=None,
vmax=None,
freq_max=None,
show_dirlabels=True,
show_freqlabels=True,
rscale="linear",
ax=None,
):
"""Plots the input directional wave spectrum in a polar diagram.
The plot shows the power spectral density (PSD) distribution over wave frequencies and
directions.
Parameters
----------
spec : array_like
Directional wave spectrum to be plotted. This must be a 2-D array of shape either
`(Ndirs,Nfreqs)` or `(Nfreqs,Ndirs)`.
dirs : array_like of shape (`Ndirs`,)
Directions of the spectrum.
freqs : array_like of shape (`Nfreqs`,)
Frequencies of the spectrum.
unit_dirs : {'deg','rad'}, optional
Unit of the directions, either ``'deg'`` for degrees or ``'rad'`` for radians.
Default is ``'deg'``.
unit_freqs : {'Hz','rad/s'}, optional
Unit of the frequencies, either ``'Hz'`` for Hertz or ``'rad/s'`` for radians per second.
Default is ``'Hz'``.
levels : array_like, optional
Contour levels to use for the plot. If ``None``, levels are automatically determined
according to the formula below:
:math:`vmin*(vmax/vmin)^{i/29}` for :math:`i=0,1,...,29`.
Default is ``None``.
cmap : str, optional
Colormap to use for the plot. Default is ``'turbo'``.
.. note::
Other colormap options include ``'viridis'``, ``'plasma'``, ``'inferno'``, ``'magma'``.
Visit the `Matplotlib documentation
<https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_ for an overview of
the options.
vmin : float, optional
Minimum value for the colormap. Default is ``None``, which means to use the value
:math:`10^{-1/2}`.
vmax : float, optional
Maximum value for the colormap. Default is ``None``, which means to use the maximum value
in ``spec``.
.. warning::
``vmin`` and ``vmax`` are used to set the colormap limits. These options will override the lower and upper limits of ``levels``, if the latter option is used.
freq_max : float, optional
Maximum frequency to plot. Default is ``None``, which means to use the maximum frequency in
the input data.
show_dirlabels : {True,False,'compass'}, optional
Whether to show direction labels.
If ``True``, shows labels in degrees. If ``'compass'``, shows compass directions.
If ``False``, no labels are shown. Default is ``True``.
show_freqlabels : {True,False,'all'}, optional
Whether to show frequency labels. If ``True``, shows only the maximum frequency label.
If ``'all'``, shows all labels. If ``False``, no labels are shown. Default is ``True``.
rscale : {'linear', 'log'}, optional
Scale for the radial axis (frequencies). Default is ``'linear'``.
ax : matplotlib.axes, optional
Axes object to plot on. If ``None``, a new figure and axes are created. Default is ``None``.
Returns
-------
ax : matplotlib.axes._subplots.PolarAxesSubplot
The axes object with the plot.
Example
-------
>>> # Generate some example data
>>> freqs = np.linspace(0.05, 0.5, 50)
>>> dirs = np.linspace(0, 360, 36)
>>> spec = np.random.rand(36, 50)
>>> # Plot the directional wave spectrum
>>> plot_dirwavespec(spec, dirs, freqs)
"""
# Identify the dimensions of the spectra:
Nfreqs = np.size(freqs)
Ndirs = np.size(dirs)
ax_freqs = np.where(np.array(spec.shape) == Nfreqs)[0][0]
ax_dirs = np.where(np.array(spec.shape) == Ndirs)[0][0]
spec_plot = np.transpose(spec, (ax_dirs, ax_freqs))
# Headings in radians:
if unit_dirs == "deg":
dirs_rad = np.radians(dirs)
else:
dirs_rad = dirs
dirs_rad = re_range(dirs_rad.flatten(), 0, unit="rad")
# Ensure that the headings are cyclic:
if np.round(dirs_rad[-1,], 4) == np.round(dirs_rad[0,], 4):
dirs_rad[-1,] += 2 * np.pi
else:
dirs_rad = np.append(dirs_rad, dirs_rad[0,] + 2 * np.pi)
Ndirs += 1
spec_plot = np.append(spec_plot, spec_plot[0, :].reshape(1, Nfreqs), axis=0)
# Make a meshgrid in angles and frequencies:
Headings, Freq = np.meshgrid(dirs_rad, freqs)
if rscale == "log":
# Shift the bullseye up by 0.3 decades to avoid it being at the center of the plot:
bullseye = 0.3
min_freq = np.min(Freq)
if min_freq <= 0:
raise ValueError(
"Frequencies must be positive for logarithmic scale. Found minimum frequency: %s"
% min_freq
)
min10 = np.log10(np.min(Freq))
max10 = np.log10(np.max(Freq))
Freq = np.log10(Freq) - min10 + bullseye
# Select the levels for the contour plot:
spec_plot[np.isnan(spec_plot)] = -1
if vmin is None:
min_val = 10 ** (-0.5)
else:
min_val = vmin
if vmax is None:
max_val = np.amax(spec_plot)
else:
max_val = vmax
if levels is None:
levels = 10 ** np.linspace(np.log10(min_val), np.log10(max_val), 30)
else:
if vmin is not None:
levels[0] = vmin
if vmax is not None:
levels[-1] = vmax
norm = cm.colors.LogNorm()
spec_plot = np.ma.masked_where(spec_plot <= 0, spec_plot)
# Plot the 2-D wave spectrum:
if ax is None:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(4, 4))
else:
fig = ax.get_figure()
rows, cols, start, _ = ax.get_subplotspec().get_geometry()
ax.remove()
ax = fig.add_subplot(rows, cols, start + 1, projection="polar")
ax.set_theta_direction(-1)
ax.set_theta_zero_location("N")
ax.contour(
Headings, Freq, spec_plot.T, levels=levels, norm=norm, linewidths=1, cmap=cmap
)
# Set the labels and grid:
ax.set_xticks(np.arange(np.pi / 6, 13 * np.pi / 6, np.pi / 6))
match show_dirlabels:
case True:
ax.set_xticklabels(
[
"",
"",
r"90$^\circ$",
"",
"",
r"180$^\circ$",
"",
"",
r"270$^\circ$",
"",
"",
r"0$^\circ$",
]
)
case "compass":
ax.set_xticklabels(["", "", "E", "", "", "S", "", "", "W", "", "", "N"])
case False:
ax.set_xticklabels([""] * 12)
if freq_max is None:
freq_max = np.max(np.round(freqs, decimals=1))
match rscale:
case "linear":
dfreq = 0.1 * (unit_freqs == "Hz") + 0.5 * (unit_freqs == "rad/s")
rticks = np.arange(0, freq_max, dfreq) + dfreq
ax.set_rticks(rticks)
ax.set_rlim(0, freq_max)
case "log":
# dfreq = np.log10(dfreq)
freq_max10 = np.log10(freq_max)
max10 = np.min([freq_max10, max10])
rticks = np.arange(np.floor(min10), max10)
ax.set_rticks(rticks - min10 + bullseye)
ax.set_rlim(0, max10 - min10 + bullseye)
match show_freqlabels:
case True:
match rscale:
case "linear":
ax.set_yticklabels(
["" for _ in rticks[:-1]]
+ ["%.1f [%s]" % (rticks[-1], unit_freqs)]
)
case "log":
ax.set_yticklabels(
["" for _ in rticks[:-1]]
+ ["1e%d [%s]" % (rticks[-1], unit_freqs)]
)
case "all":
match rscale:
case "linear":
freq_ticklabels = ["%.1f" % x for x in rticks]
freq_ticklabels[-1] += " [%s]" % (unit_freqs)
ax.set_yticklabels(freq_ticklabels)
case "log":
freq_ticklabels = ["1e%d" % x for x in rticks]
freq_ticklabels[-1] += " [%s]" % (unit_freqs)
ax.set_yticklabels(freq_ticklabels)
case False:
ax.set_yticklabels(["" for _ in rticks])
ax.grid(linestyle="--", color="black", linewidth=0.5)
return ax
[docs]
def plot_bathymetry(depths, shp_dict, ax, cmap="Blues_r"):
    """Plots the bathymetry map for a given area.

    The function draws, for each element of ``depths`` and in that order, the geometries of
    the corresponding shapefile in ``shp_dict``. Each depth level is filled with one color
    from a discrete colormap containing one color per depth, assigned by increasing depth.

    Parameters
    ----------
    depths : array_like
        Depths in the shapefiles for the specified area (list, tuple or array). The values must
        be convertible to integers; each element is also used as a key into ``shp_dict``.
    shp_dict : dict
        Dictionary mapping each element of ``depths`` to a shapefile object exposing a
        ``geometries()`` method (presumably as returned by
        :func:`netsse.model.bathymetry.load_bathymetry`).
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes object to plot on; it must support ``add_geometries``.
    cmap : str, optional
        Colormap to use for the plot. Default is ``'Blues_r'``.

        .. note::
            Other colormap options include ``'plasma'``, ``'inferno'``, ``'magma'``, ``'viridis'``.
            Visit the `Matplotlib documentation
            <https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_ for an overview of
            the options.

    Returns
    -------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The axes object with the plot.
    colormap : matplotlib.colors.Colormap
        The colormap used for the plot, resampled to ``len(depths)`` colors.

    See Also
    --------
    netsse.model.bathymetry.load_bathymetry : Retrieve and read bathymetry shapefiles.

    Example
    -------
    .. code-block:: python

        import matplotlib.pyplot as plt
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        from netsse.model.bathymetry import load_bathymetry
        from netsse.tools.viz import plot_bathymetry
        depths, shp_dict = load_bathymetry(lonmin=-5, lonmax=15, latmin=35, latmax=45)
        fig, ax = plt.subplots(subplot_kw={'projection':ccrs.Mercator(central_longitude=5,min_latitude=35,max_latitude=45)},figsize=(6,6))
        ax.set_extent([-5, 15, 35, 45], crs=ccrs.PlateCarree())
        ax, colormap = plot_bathymetry(depths, shp_dict, ax)
        ax.add_feature(cfeature.LAND,edgecolor='black',facecolor='gainsboro',alpha=0.5)
    """
    # Construct a discrete colormap with colors corresponding to each depth.
    # np.asarray makes plain lists/tuples work as documented (they have no .astype).
    depths_int = np.asarray(depths).astype(int)
    N = len(depths_int)
    nudge = 0.01  # shift bin edge slightly to include data
    boundaries = [depths_int.min()] + sorted(depths_int + nudge)  # low to high
    norm = matplotlib.colors.BoundaryNorm(boundaries, N)
    colormap = matplotlib.colormaps[cmap].resampled(N)
    colors_depths = colormap(norm(depths_int))
    # Iterate and plot the bathymetry feature for each depth level (original keys are kept):
    for depth, color in zip(depths, colors_depths):
        ax.add_geometries(
            shp_dict[depth].geometries(), crs=ccrs.PlateCarree(), color=color
        )
    # NOTE: calling ax.set_rasterized(True) here would convert the vector bathymetries to a
    # raster (saving a lot of disk space) while leaving labels as vectors.
    return ax, colormap
[docs]
def plot_sea_state_scatter(
y_data,
x_data,
y_bin=0.25,
x_bin=0.20,
y_range=None,
x_range=None,
y_label=None,
x_label=None,
cmap="viridis",
norm="log", # "log" or "linear"
annotate=True,
min_count_annot=1, # minimum count to annotate a cell
ax=None,
colorbar=True,
grid=True,
):
"""
Create a sea state scatter diagram (2-D histogram heatmap) of counts for two wave parameters.
Parameters
----------
y_data : array-like
Values to be displayed on the vertical axis. Typically, this some wave height parameter [in metres].
x_data : array-like
Values to be displayed on the horizontal axis. Typically, this is some wave period parameter [in seconds]. Must be same length as ``y_data``.
y_bin : float, optional
Bin size for ``y_data``. Default is 0.25 [m].
x_bin : float, optional
Bin size for ``x_data``. Default is 0.20 [s].
y_range : tuple(float, float), optional
(min, max) explicit range for ``y_data``. If ``None``, inferred from data.
x_range : tuple(float, float), optional
(min, max) explicit range for ``x_data``. If ``None``, inferred from data.
y_label : str, optional.
Y-axis label for scatter diagram. If ``None``, the scatter diagram display a `y`-axis label.
x_label : str, optional.
X-axis label for scatter diagram. If ``None``, the scatter diagram display an `x`-axis label.
cmap : str, optional
Matplotlib colormap name.
norm : {"log", "linear"}, optional
Color normalization. ``"log"`` helps when counts vary widely.
annotate : bool, optional
If ``True``, annotate counts inside cells.
min_count_annot : int, optional
Minimum count to annotate a cell (avoids clutter).
ax : matplotlib.axes.Axes, optional
Axes to draw on. If ``None``, a new figure and axes are created.
colorbar : bool, optional
If ``True``, add a colorbar.
grid : bool, optional
If ``True``, show grid lines.
Returns
-------
fig : matplotlib.figure.Figure
ax : matplotlib.axes.Axes
H : 2-D ndarray
Counts matrix of shape ``(len(y_edges)-1, len(x_edges)-1)``,
where rows correspond to ``y_data`` bins and columns to ``x_data`` bins.
y_edges : 1-D ndarray
Bin edges for ``x_data``.
x_edges : 1-D ndarray
Bin edges for ``y_data``.
Notes
-----
NaNs and non-finite values are ignored.
Example
-------
>>> fig, ax, counts, hs_edges, tp_edges = plot_sea_state_scatter(
... Hs, Tp,
... y_bin=0.25,
... x_bin=0.20,
... norm="log", # try "linear" if you prefer
... annotate=False, # set True to print counts in cells
... )
"""
y = np.asarray(y_data, dtype=float)
x = np.asarray(x_data, dtype=float)
# Remove NaNs and non-finite values
mask = np.isfinite(y) & np.isfinite(x)
y = y[mask]
x = x[mask]
if y.size == 0 or x.size == 0:
raise ValueError("No finite data in y_data/x_data after filtering.")
# Define bin edges
if y_range is None:
y_min = np.floor(np.nanmin(y) / y_bin) * y_bin
y_max = np.ceil(np.nanmax(y) / y_bin) * y_bin
else:
y_min, y_max = y_range
if x_range is None:
x_min = np.floor(np.nanmin(x) / x_bin) * x_bin
x_max = np.ceil(np.nanmax(x) / x_bin) * x_bin
else:
x_min, x_max = x_range
# Ensure at least one bin
if y_max <= y_min:
y_max = y_min + y_bin
if x_max <= x_min:
x_max = x_min + x_bin
y_edges = np.arange(y_min, y_max + y_bin, y_bin)
x_edges = np.arange(x_min, x_max + x_bin, x_bin)
# 2D histogram: rows = y bins, cols = x bins
H, y_edges_h2d, x_edges_h2d = np.histogram2d(y, x, bins=[y_edges, x_edges])
# The returned edges are the same arrays we gave; keep originals for clarity
H = H.astype(int)
# Prepare axes
created_fig = False
if ax is None:
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
created_fig = True
else:
fig = ax.figure
# Choose normalization
match norm:
case "log":
# Use LogNorm but avoid zero issues by setting vmin to 1 if any positive counts exist
positive = H[H > 0]
if positive.size > 0:
vmin = max(1, positive.min())
norm_obj = LogNorm(vmin=vmin, vmax=positive.max())
else:
norm_obj = Normalize(vmin=0, vmax=1)
case "linear":
norm_obj = Normalize(vmin=0, vmax=H.max() if H.max() > 0 else 1)
case _:
raise ValueError('norm must be "log" or "linear".')
# pcolormesh expects bin edges; note that X=Tp (cols), Y=Hs (rows)
mesh = ax.pcolormesh(x_edges, y_edges, H, cmap=cmap, norm=norm_obj, shading="auto")
# Labels and formatting
if x_label is not None:
ax.set_xlabel(x_label)
if y_label is not None:
ax.set_ylabel(y_label)
if grid:
ax.grid(True, which="both", linestyle=":", alpha=0.5)
ax.set_ylim(y_edges[0], y_edges[-1])
ax.set_xlim(x_edges[0], x_edges[-1])
# Colorbar
if colorbar:
cbar = fig.colorbar(mesh, ax=ax, pad=0.02)
cbar.set_label("Count")
# Optional annotations (center each cell)
if annotate:
# Compute cell centers
y_centers = 0.5 * (y_edges[:-1] + y_edges[1:])
x_centers = 0.5 * (x_edges[:-1] + x_edges[1:])
for i, yc in enumerate(y_centers):
for j, xc in enumerate(x_centers):
count = H[i, j]
if count >= min_count_annot:
ax.text(
xc,
yc,
str(int(count)),
ha="center",
va="center",
fontsize=8,
color="white",
path_effects=[],
)
if created_fig:
plt.show()
return fig, ax, H, y_edges, x_edges