# -*- coding: utf-8 -*-
"""
**Data visualisation** functions for NetSSE.
.. dropdown:: Copyright (C) 2023-2026 Technical University of Denmark, R.E.G. Mounet
:color: primary
:icon: law
*This code is part of the NetSSE software.*
NetSSE is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your
option) any later version.
NetSSE is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
for more details.
You should have received a copy of the GNU General Public License along
with this program. If not, see https://www.gnu.org/licenses/.
To credit the author, users are encouraged to use the reference below:
.. code-block:: text
Mounet, R. E. G., & Nielsen, U. D. NetSSE: An open-source Python package
for network-based sea state estimation from ships, buoys, and other
observation platforms (version 2.2). Technical University of Denmark,
GitLab. February 2026. https://doi.org/10.11583/DTU.26379811.
*Last updated on 09-02-2026 by R.E.G. Mounet*
"""
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, Normalize
import matplotlib.cm as cm
import cartopy.crs as ccrs
import numpy as np
from netsse.tools.misc_func import re_range
[docs]
def plot_dirwavespec(
spec,
dirs,
freqs,
unit_dirs="deg",
unit_freqs="Hz",
levels=None,
cmap="turbo",
vmin=None,
vmax=None,
freq_max=None,
show_dirlabels=True,
show_freqlabels=True,
rscale="linear",
ax=None,
):
"""Plots the input directional wave spectrum in a polar diagram.
The plot shows the power spectral density (PSD) distribution over wave frequencies and
directions.
Parameters
----------
spec : array_like
Directional wave spectrum to be plotted. This must be a 2-D array of shape either
`(Ndirs,Nfreqs)` or `(Nfreqs,Ndirs)`.
dirs : array_like of shape (`Ndirs`,)
Directions of the spectrum.
freqs : array_like of shape (`Nfreqs`,)
Frequencies of the spectrum.
unit_dirs : {'deg','rad'}, optional
Unit of the directions, either ``'deg'`` for degrees or ``'rad'`` for radians.
Default is ``'deg'``.
unit_freqs : {'Hz','rad/s'}, optional
Unit of the frequencies, either ``'Hz'`` for Hertz or ``'rad/s'`` for radians per second.
Default is ``'Hz'``.
levels : array_like, optional
Contour levels to use for the plot. If ``None``, levels are automatically determined
according to the formula below:
:math:`vmin*(vmax/vmin)^{i/29}` for :math:`i=0,1,...,29`.
Default is ``None``.
cmap : str, optional
Colormap to use for the plot. Default is ``'turbo'``.
.. note::
Other colormap options include ``'viridis'``, ``'plasma'``, ``'inferno'``, ``'magma'``.
Visit the `Matplotlib documentation
<https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_ for an overview of
the options.
vmin : float, optional
Minimum value for the colormap. Default is ``None``, which means to use the value
:math:`10^{-1/2}`.
vmax : float, optional
Maximum value for the colormap. Default is ``None``, which means to use the maximum value
in ``spec``.
.. warning::
``vmin`` and ``vmax`` are used to set the colormap limits. These options will override the lower and upper limits of ``levels``, if the latter option is used.
freq_max : float, optional
Maximum frequency to plot. Default is ``None``, which means to use the maximum frequency in
the input data.
show_dirlabels : {True,False,'compass'}, optional
Whether to show direction labels.
If ``True``, shows labels in degrees. If ``'compass'``, shows compass directions.
If ``False``, no labels are shown. Default is ``True``.
show_freqlabels : {True,False,'all'}, optional
Whether to show frequency labels. If ``True``, shows only the maximum frequency label.
If ``'all'``, shows all labels. If ``False``, no labels are shown. Default is ``True``.
rscale : {'linear', 'log'}, optional
Scale for the radial axis (frequencies). Default is ``'linear'``.
ax : matplotlib.axes, optional
Axes object to plot on. If ``None``, a new figure and axes are created. Default is ``None``.
Returns
-------
ax : matplotlib.axes._subplots.PolarAxesSubplot
The axes object with the plot.
Example
-------
>>> # Generate some example data
>>> freqs = np.linspace(0.05, 0.5, 50)
>>> dirs = np.linspace(0, 360, 36)
>>> spec = np.random.rand(36, 50)
>>> # Plot the directional wave spectrum
>>> plot_dirwavespec(spec, dirs, freqs)
"""
# Identify the dimensions of the spectra:
Nfreqs = np.size(freqs)
Ndirs = np.size(dirs)
ax_freqs = np.where(np.array(spec.shape) == Nfreqs)[0][0]
ax_dirs = np.where(np.array(spec.shape) == Ndirs)[0][0]
spec_plot = np.transpose(spec, (ax_dirs, ax_freqs))
# Headings in radians:
if unit_dirs == "deg":
dirs_rad = np.radians(dirs)
else:
dirs_rad = dirs
dirs_rad = re_range(dirs_rad.flatten(), 0, unit="rad")
# Ensure that the headings are cyclic:
if np.round(dirs_rad[-1,], 4) == np.round(dirs_rad[0,], 4):
dirs_rad[-1,] += 2 * np.pi
else:
dirs_rad = np.append(dirs_rad, dirs_rad[0,] + 2 * np.pi)
Ndirs += 1
spec_plot = np.append(spec_plot, spec_plot[0, :].reshape(1, Nfreqs), axis=0)
# Make a meshgrid in angles and frequencies:
Headings, Freq = np.meshgrid(dirs_rad, freqs)
if rscale == "log":
# Shift the bullseye up by 0.3 decades to avoid it being at the center of the plot:
bullseye = 0.3
min_freq = np.min(Freq)
if min_freq <= 0:
raise ValueError(
"Frequencies must be positive for logarithmic scale. Found minimum frequency: %s"
% min_freq
)
min10 = np.log10(np.min(Freq))
max10 = np.log10(np.max(Freq))
Freq = np.log10(Freq) - min10 + bullseye
# Select the levels for the contour plot:
spec_plot[np.isnan(spec_plot)] = -1
if vmin is None:
min_val = 10 ** (-0.5)
else:
min_val = vmin
if vmax is None:
max_val = np.amax(spec_plot)
else:
max_val = vmax
if levels is None:
levels = 10 ** np.linspace(np.log10(min_val), np.log10(max_val), 30)
else:
if vmin is not None:
levels[0] = vmin
if vmax is not None:
levels[-1] = vmax
norm = cm.colors.LogNorm()
spec_plot = np.ma.masked_where(spec_plot <= 0, spec_plot)
# Plot the 2-D wave spectrum:
if ax is None:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(4, 4))
else:
fig = ax.get_figure()
rows, cols, start, _ = ax.get_subplotspec().get_geometry()
ax.remove()
ax = fig.add_subplot(rows, cols, start + 1, projection="polar")
ax.set_theta_direction(-1)
ax.set_theta_zero_location("N")
ax.contour(
Headings, Freq, spec_plot.T, levels=levels, norm=norm, linewidths=1, cmap=cmap
)
# Set the labels and grid:
ax.set_xticks(np.arange(np.pi / 6, 13 * np.pi / 6, np.pi / 6))
match show_dirlabels:
case True:
ax.set_xticklabels(
[
"",
"",
r"90$^\circ$",
"",
"",
r"180$^\circ$",
"",
"",
r"270$^\circ$",
"",
"",
r"0$^\circ$",
]
)
case "compass":
ax.set_xticklabels(["", "", "E", "", "", "S", "", "", "W", "", "", "N"])
case False:
ax.set_xticklabels([""] * 12)
if freq_max is None:
freq_max = np.max(np.round(freqs, decimals=1))
match rscale:
case "linear":
dfreq = 0.1 * (unit_freqs == "Hz") + 0.5 * (unit_freqs == "rad/s")
rticks = np.arange(0, freq_max, dfreq) + dfreq
ax.set_rticks(rticks)
ax.set_rlim(0, freq_max)
case "log":
# dfreq = np.log10(dfreq)
freq_max10 = np.log10(freq_max)
max10 = np.min([freq_max10, max10])
rticks = np.arange(np.floor(min10), max10)
ax.set_rticks(rticks - min10 + bullseye)
ax.set_rlim(0, max10 - min10 + bullseye)
match show_freqlabels:
case True:
match rscale:
case "linear":
ax.set_yticklabels(
["" for _ in rticks[:-1]]
+ ["%.1f [%s]" % (rticks[-1], unit_freqs)]
)
case "log":
ax.set_yticklabels(
["" for _ in rticks[:-1]]
+ ["1e%d [%s]" % (rticks[-1], unit_freqs)]
)
case "all":
match rscale:
case "linear":
freq_ticklabels = ["%.1f" % x for x in rticks]
freq_ticklabels[-1] += " [%s]" % (unit_freqs)
ax.set_yticklabels(freq_ticklabels)
case "log":
freq_ticklabels = ["1e%d" % x for x in rticks]
freq_ticklabels[-1] += " [%s]" % (unit_freqs)
ax.set_yticklabels(freq_ticklabels)
case False:
ax.set_yticklabels(["" for _ in rticks])
ax.grid(linestyle="--", color="black", linewidth=0.5)
return ax
[docs]
def plot_bathymetry(depths, shp_dict, ax, cmap="Blues_r"):
    """Plots the bathymetry map for a given area.

    The function plots the bathymetry map for a given area using the shapefiles contained in the
    dictionary `shp_dict`. One colour of a discrete colormap is assigned to each depth level,
    and the geometries of each level are added to `ax` in the order given by `depths`.

    Parameters
    ----------
    depths : array_like
        Depths in the shapefiles for the specified area (list or NumPy array). Each entry
        must be a key of `shp_dict`.
    shp_dict : dict
        Dictionary containing the shapefiles, keyed by depth. Each value must provide a
        ``geometries()`` method.
    ax : matplotlib.axes
        Axes object to plot on. It must provide ``add_geometries``
        (presumably a cartopy ``GeoAxes``).
    cmap : str, optional
        Colormap to use for the plot. Default is ``'Blues_r'``.

        .. note::
            Other colormap options include ``'plasma'``, ``'inferno'``, ``'magma'``, ``'viridis'``.
            Visit the `Matplotlib documentation
            <https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_ for an overview of
            the options.

    Returns
    -------
    ax : matplotlib.axes._subplots.AxesSubplot
        The axes object with the plot.
    colormap : matplotlib.colors.ListedColormap
        The colormap used for the plot, resampled to one colour per depth level.

    Raises
    ------
    ValueError
        If `depths` is empty.

    See Also
    --------
    netsse.model.bathymetry.load_bathymetry : Retrieve and read bathymetry shapefiles.

    Example
    -------
    .. code-block:: python

        import matplotlib.pyplot as plt
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        from netsse.model.bathymetry import load_bathymetry
        from netsse.tools.viz import plot_bathymetry

        depths, shp_dict = load_bathymetry(lonmin=-5, lonmax=15, latmin=35, latmax=45)
        fig, ax = plt.subplots(
            subplot_kw={
                'projection': ccrs.Mercator(
                    central_longitude=5, min_latitude=35, max_latitude=45
                )
            },
            figsize=(6, 6),
        )
        ax.set_extent([-5, 15, 35, 45], crs=ccrs.PlateCarree())
        ax, colormap = plot_bathymetry(depths, shp_dict, ax)
        ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='gainsboro', alpha=0.5)
    """
    # Accept plain lists as well as arrays (lists have no ``astype``).
    depths_int = np.asarray(depths).astype(int)
    if depths_int.size == 0:
        raise ValueError("depths must contain at least one depth level.")

    # Construct a discrete colormap with colors corresponding to each depth
    N = len(depths_int)
    nudge = 0.01  # shift bin edge slightly to include data
    boundaries = [depths_int.min()] + sorted(depths_int + nudge)  # low to high
    norm = matplotlib.colors.BoundaryNorm(boundaries, N)
    colormap = matplotlib.colormaps[cmap].resampled(N)
    colors_depths = colormap(norm(depths_int))

    # Iterate and plot bathymetry feature for each depth level. The original entries of
    # ``depths`` are used as dictionary keys, exactly as the caller supplied them.
    for depth, color in zip(depths, colors_depths):
        ax.add_geometries(
            shp_dict[depth].geometries(), crs=ccrs.PlateCarree(), color=color
        )

    # NOTE: calling ``ax.set_rasterized(True)`` afterwards converts the vector bathymetries
    # to raster (saves a lot of disk space) while leaving labels as vectors.
    return ax, colormap
[docs]
def _bin_edges(lower, upper, width):
    """Return evenly spaced bin edges of size ``width`` that cover ``[lower, upper]``."""
    # Count bins with integers rather than a float-step ``np.arange``: rounding the
    # quotient first keeps floating-point noise from adding a spurious trailing bin.
    n_bins = max(1, int(np.ceil(np.round((upper - lower) / width, 9))))
    edges = float(lower) + width * np.arange(n_bins + 1)
    # Guard against the last edge landing one rounding error below ``upper``.
    edges[-1] = max(edges[-1], upper)
    return edges


def plot_sea_state_scatter(
    y_data,
    x_data,
    y_bin=0.25,
    x_bin=0.20,
    y_range=None,
    x_range=None,
    y_label=None,
    x_label=None,
    cmap="viridis",
    norm="log",  # "log" or "linear"
    annotate=True,
    min_count_annot=1,  # minimum count to annotate a cell
    ax=None,
    colorbar=True,
    grid=True,
):
    """
    Create a sea state scatter diagram (2-D histogram heatmap) of counts for two wave parameters.

    Parameters
    ----------
    y_data : array-like
        Values to be displayed on the vertical axis. Typically, this is some wave height
        parameter [in metres].
    x_data : array-like
        Values to be displayed on the horizontal axis. Typically, this is some wave period
        parameter [in seconds]. Must be same length as ``y_data``.
    y_bin : float, optional
        Bin size for ``y_data``. Must be positive. Default is 0.25 [m].
    x_bin : float, optional
        Bin size for ``x_data``. Must be positive. Default is 0.20 [s].
    y_range : tuple(float, float), optional
        (min, max) explicit range for ``y_data``. If ``None``, inferred from data by rounding
        the data extremes outwards to multiples of ``y_bin``.
    x_range : tuple(float, float), optional
        (min, max) explicit range for ``x_data``. If ``None``, inferred from data by rounding
        the data extremes outwards to multiples of ``x_bin``.
    y_label : str, optional
        Y-axis label for scatter diagram. If ``None``, no `y`-axis label is set.
    x_label : str, optional
        X-axis label for scatter diagram. If ``None``, no `x`-axis label is set.
    cmap : str, optional
        Matplotlib colormap name.
    norm : {"log", "linear"}, optional
        Color normalization. ``"log"`` helps when counts vary widely.
    annotate : bool, optional
        If ``True``, annotate counts inside cells.
    min_count_annot : int, optional
        Minimum count to annotate a cell (avoids clutter).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If ``None``, a new figure and axes are created, and the figure is
        shown with ``plt.show()`` before returning.
    colorbar : bool, optional
        If ``True``, add a colorbar.
    grid : bool, optional
        If ``True``, show grid lines.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    H : 2-D ndarray
        Counts matrix of shape ``(len(y_edges)-1, len(x_edges)-1)``,
        where rows correspond to ``y_data`` bins and columns to ``x_data`` bins.
    y_edges : 1-D ndarray
        Bin edges for ``y_data``.
    x_edges : 1-D ndarray
        Bin edges for ``x_data``.

    Raises
    ------
    ValueError
        If a bin size is not positive, if ``norm`` is neither ``"log"`` nor ``"linear"``,
        if ``y_data`` and ``x_data`` differ in size, or if no finite data pair remains.

    Notes
    -----
    NaNs and non-finite values are ignored.

    Example
    -------
    >>> fig, ax, counts, hs_edges, tp_edges = plot_sea_state_scatter(
    ...     Hs, Tp,
    ...     y_bin=0.25,
    ...     x_bin=0.20,
    ...     norm="log",      # try "linear" if you prefer
    ...     annotate=False,  # set True to print counts in cells
    ... )
    """
    # Validate the options first so that no figure is created for an invalid call
    if not (y_bin > 0 and x_bin > 0):
        raise ValueError("y_bin and x_bin must be positive.")
    if norm not in ("log", "linear"):
        raise ValueError('norm must be "log" or "linear".')

    y = np.asarray(y_data, dtype=float).ravel()
    x = np.asarray(x_data, dtype=float).ravel()
    if y.size != x.size:
        raise ValueError("y_data and x_data must have the same length.")

    # Remove NaNs and non-finite values (a pair is dropped if either value is bad)
    finite = np.isfinite(y) & np.isfinite(x)
    y = y[finite]
    x = x[finite]
    if y.size == 0:
        raise ValueError("No finite data in y_data/x_data after filtering.")

    # Define the binning ranges
    if y_range is None:
        y_min = np.floor(y.min() / y_bin) * y_bin
        y_max = np.ceil(y.max() / y_bin) * y_bin
    else:
        y_min, y_max = y_range
    if x_range is None:
        x_min = np.floor(x.min() / x_bin) * x_bin
        x_max = np.ceil(x.max() / x_bin) * x_bin
    else:
        x_min, x_max = x_range

    # Ensure at least one bin
    if y_max <= y_min:
        y_max = y_min + y_bin
    if x_max <= x_min:
        x_max = x_min + x_bin
    y_edges = _bin_edges(y_min, y_max, y_bin)
    x_edges = _bin_edges(x_min, x_max, x_bin)

    # 2D histogram: rows = y bins, cols = x bins
    H, _, _ = np.histogram2d(y, x, bins=[y_edges, x_edges])
    H = H.astype(int)

    # Prepare axes
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    else:
        fig = ax.figure

    # Choose normalization
    if norm == "log":
        # Empty cells cannot be placed on a log scale: span the positive counts only
        positive = H[H > 0]
        if positive.size > 0:
            norm_obj = LogNorm(vmin=max(1, positive.min()), vmax=positive.max())
        else:
            norm_obj = Normalize(vmin=0, vmax=1)
    else:
        peak = H.max()
        norm_obj = Normalize(vmin=0, vmax=peak if peak > 0 else 1)

    # pcolormesh expects bin edges; x runs along the columns of H and y along its rows
    mesh = ax.pcolormesh(x_edges, y_edges, H, cmap=cmap, norm=norm_obj, shading="auto")

    # Labels and formatting
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if grid:
        ax.grid(True, which="both", linestyle=":", alpha=0.5)
    ax.set_ylim(y_edges[0], y_edges[-1])
    ax.set_xlim(x_edges[0], x_edges[-1])

    # Colorbar
    if colorbar:
        cbar = fig.colorbar(mesh, ax=ax, pad=0.02)
        cbar.set_label("Count")

    # Optional annotations (centered in each cell that is populated enough)
    if annotate:
        y_centers = 0.5 * (y_edges[:-1] + y_edges[1:])
        x_centers = 0.5 * (x_edges[:-1] + x_edges[1:])
        for i, j in np.argwhere(H >= min_count_annot):
            ax.text(
                x_centers[j],
                y_centers[i],
                str(int(H[i, j])),
                ha="center",
                va="center",
                fontsize=8,
                color="white",
            )

    if created_fig:
        plt.show()
    return fig, ax, H, y_edges, x_edges