# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh, Kishanthan Kingston, Pierre Chapel, Rosie Eade
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
import os
# Default the cartopy data directory to a fixed path unless the environment
# already defines CARTOPY_DATA_DIR (setdefault never overrides an existing value).
# This statement deliberately sits before the `import cartopy` lines below.
# NOTE(review): presumably cartopy reads CARTOPY_DATA_DIR at import time and the
# target machine cannot download map data on demand — confirm.
os.environ.setdefault(
    "CARTOPY_DATA_DIR",
    "/leonardo_work/EUHPC_D27_095/cartopy_data",
)
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from matplotlib.patches import ConnectionPatch
import matplotlib as mpl
from scipy import stats
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import mpltex
from sklearn.metrics import r2_score
import seaborn as sns
import pandas as pd
# ---------------------------------------------
# COMPLETE MATPLOTLIB STYLE CONFIGURATION
# ---------------------------------------------
# Module-wide matplotlib style. The dictionary is pushed into `mpl.rcParams`
# as soon as this module is imported, so it affects every figure created
# afterwards in the same process.
params = {
    # DPI & figure settings (currently disabled; matplotlib defaults apply)
    # "figure.dpi": 150,
    # "savefig.dpi": 300,
    # Fonts
    "font.family": "DejaVu Sans",
    # NOTE(review): matplotlib documents `mathtext.rm` as effective only when
    # `mathtext.fontset` is "custom" — confirm this entry has the intended effect.
    "mathtext.rm": "arial",
    "font.size": 12,  # General font size (affects ax.text())
    "font.style": "normal",  # 'normal', 'italic', 'oblique'
    "font.weight": "normal",  # 'normal', 'bold', 'heavy', 'light', 'ultrabold', 'ultralight'
    "font.stretch": "normal",  # Font stretch
    # Line properties (patterns are on/off lengths)
    "lines.linewidth": 2,
    "lines.dashed_pattern": [4, 2],
    "lines.dashdot_pattern": [6, 3, 2, 3],
    "lines.dotted_pattern": [2, 3],
    # Axis labels and titles
    "axes.labelsize": 15,
    "axes.titlesize": 15,
    # Tick settings
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "xtick.major.size": 6,
    "ytick.major.size": 6,
    "xtick.direction": "out",
    "ytick.direction": "out",
    # Legend
    "legend.fontsize": 10,
    "legend.loc": "best",
    "legend.frameon": False,
    # Text properties
    "text.color": "black",  # Default text color
    "text.usetex": False,  # LaTeX rendering
    "text.hinting": "auto",  # Text hinting
    "text.antialiased": True,  # Text anti-aliasing
    "text.latex.preamble": "",  # LaTeX preamble
}
# Apply the style globally (import-time side effect).
mpl.rcParams.update(params)
# ============================================================================
# PLOTTING CONFIGURATION
# ============================================================================
class PlotConfig:
    """Central configuration for all plotting functions.

    Holds only class-level constants (default paths, colormaps, fixed colour
    ranges, map-feature line widths) plus small lookup helpers. It is never
    instantiated by the code in this module; every helper is a class or
    static method.
    """

    # General settings
    DEFAULT_SAVE_DIR = "./results"
    DEFAULT_FIGSIZE_MULTIPLIER = 4

    # Color schemes. Keys are matched as *substrings* of the lower-cased
    # variable name, in insertion order (see `get_colormap`).
    COLORMAPS = {
        "temperature": "rainbow",
        "temp": "rainbow",
        "2t": "rainbow",
        "zonal": "BrBG_r",
        "10u": "BrBG_r",
        "meridional": "BrBG_r",
        "10v": "BrBG_r",
        "tp": "Blues",
        "TP": "Blues",
        "precipitation": "Blues",
        "dewpoint": "rainbow",
        "d2m": "rainbow",
        "surface temperature": "rainbow",
        "st": "rainbow",
        "pressure": "viridis",
        "pres": "viridis",
        "humidity": "Greens",
        "humid": "Greens",
        "wind": "coolwarm",
        "speed": "coolwarm",
        "mae": "Reds",
        "error": "Reds",
        "divergence": "seismic",
        "curl": "seismic",
        "ssr": "seismic",
        "default": "viridis",
    }

    # Fixed visualization ranges for error diagnostics.
    # Keys are looked up exactly (case-sensitive) by the get_fixed_* helpers.
    FIXED_DIFF_RANGES = {
        "T2M": (-5.0, 5.0),  # K
        "temperature": (-5.0, 5.0),
        "2t": (-5.0, 5.0),
        "VAR_2T": (-5.0, 5.0),
        "U10": (-5.0, 5.0),  # m/s
        "10u": (-5.0, 5.0),
        "meridional": (-5.0, 5.0),
        "VAR_10U": (-5.0, 5.0),
        "V10": (-5.0, 5.0),  # m/s
        "10v": (-5.0, 5.0),
        "VAR_10V": (-5.0, 5.0),
        "TP": (-0.5, 0.5),  # mm/h
        "tp": (-0.5, 0.5),
        "VAR_TP": (-0.5, 0.5),
        "VAR_D2M": (-5.0, 5.0),  # K
        "VAR_ST": (-5.0, 5.0),  # K
    }
    FIXED_DIFF_RANGES_ERRORS = {
        "VAR_2T": (0, 0.01),  # K
        "VAR_10U": (0, 3.0),  # m/s
        "VAR_10V": (0, 3.0),  # m/s
        "VAR_TP": (0, 0.5),  # mm/h
        "VAR_D2M": (0, 0.01),  # K
        "VAR_ST": (0, 0.01),  # K
        "Temp": (0, 3.0),
        "Press": (0, 3.0),
        "Humid": (0, 3.0),
        "Wind": (0, 3.0),
    }
    FIXED_MAE_RANGES = {
        "T2M": (0.0, 3.0),
        "temperature": (0.0, 3.0),
        "2t": (0.0, 3.0),
        "VAR_2T": (0.0, 3.0),
        "U10": (0.0, 3.0),
        "10u": (0.0, 3.0),
        "meridional": (0.0, 3.0),
        "VAR_10U": (0.0, 3.0),
        "V10": (0.0, 3.0),
        "10v": (0.0, 3.0),
        "VAR_10V": (0.0, 3.0),
        "TP": (0.0, 1.0),
        "tp": (0.0, 1.0),
        "VAR_TP": (0.0, 1.0),
        "VAR_D2M": (0.0, 3.0),
        "VAR_ST": (0.0, 3.0),
    }
    # NOTE(review): "TP"/"tp" use (0, 3) here while "VAR_TP" uses (0, 1) —
    # confirm whether the difference is intentional.
    FIXED_SSR_RANGES = {
        "T2M": (0.0, 3.0),
        "temperature": (0.0, 3.0),
        "2t": (0.0, 3.0),
        "VAR_2T": (0.0, 3.0),
        "U10": (0.0, 3.0),
        "10u": (0.0, 3.0),
        "meridional": (0.0, 3.0),
        "VAR_10U": (0.0, 3.0),
        "V10": (0.0, 3.0),
        "10v": (0.0, 3.0),
        "VAR_10V": (0.0, 3.0),
        "TP": (0.0, 3.0),
        "tp": (0.0, 3.0),
        "VAR_TP": (0.0, 1.0),
        "VAR_D2M": (0.0, 3.0),
        "VAR_ST": (0.0, 3.0),
    }

    # Geographic features
    COASTLINE_w = 0.5
    BORDER_w = 0.5
    LAKE_w = 0.5
    BORDER_STYLE = "--"

    # Colorbar settings
    COLORBAR_h = 0.02
    COLORBAR_PAD = 0.05

    # Readable titles for known variable identifiers. Keys are upper-case with
    # the "VAR_" prefix already removed; `get_plot_name` upper-cases its input
    # before the lookup so "2t" and "2T" resolve to the same title.
    _PLOT_NAMES = {
        "2T": "Temperature [K]",
        "10U": "Zonal Wind [m/s]",
        "10V": "Meridional Wind [m/s]",
        "MSLP": "Sea Level Pressure",
        "T2M": "2m Temperature [K]",
        "U10": "10m Zonal Wind [m/s]",
        "V10": "10m Meridional Wind [m/s]",
        "TP": "Precipitation [mm/h]",
        "D2M": "Dewpoint [K]",
        "ST": "Surface Temperature [K]",
    }

    @classmethod
    def get_colormap(cls, variable_name):
        """Return the colormap name for a variable.

        The variable name is lower-cased and the first `COLORMAPS` key (in
        insertion order) that occurs as a substring of it wins. If no key
        matches, the "default" entry is returned.

        Parameters
        ----------
        variable_name : str
            Variable identifier, e.g. "VAR_2T".

        Returns
        -------
        str
            A matplotlib colormap name.
        """
        var_lower = variable_name.lower()
        for key, cmap in cls.COLORMAPS.items():
            # Compare in lower case on both sides so that upper-case keys are
            # not silently unreachable.
            if key.lower() in var_lower:
                return cmap
        return cls.COLORMAPS["default"]

    @classmethod
    def get_plot_name(cls, variable_name):
        """Convert a variable identifier into a readable plot title.

        "VAR_" / "var_" is removed, then the remainder is looked up
        case-insensitively in the table of known variables. Unknown names are
        returned with underscores replaced by spaces and title-cased.

        Parameters
        ----------
        variable_name : str
            Variable identifier, e.g. "VAR_10U" or "2t".

        Returns
        -------
        str
            Human-readable title, including units for known variables.
        """
        # Remove common prefixes
        name = variable_name.replace("VAR_", "").replace("var_", "")
        known = cls._PLOT_NAMES.get(name.upper())
        if known is not None:
            return known
        # General conversion
        return name.replace("_", " ").title()

    @classmethod
    def convert_units(cls, variable_name, data):
        """Return `data` converted to plotting units when required.

        Only precipitation ("tp", "var_tp", "precipitation", compared
        case-insensitively) is converted: it is multiplied by 1000 (m to mm).
        The input object is never modified in place; for every other variable
        it is returned as-is.

        Parameters
        ----------
        variable_name : str
            Variable identifier.
        data : scalar or array-like supporting multiplication
            Values to convert.

        Returns
        -------
        Same type as ``data * 1000.0`` for precipitation, otherwise `data`.
        """
        if variable_name.lower() in ("tp", "var_tp", "precipitation"):
            return data * 1000.0  # m to mm
        return data

    @staticmethod
    def get_fixed_diff_range(var_name):
        """Return the fixed (vmin, vmax) for signed differences (Prediction − Truth), or None."""
        return PlotConfig.FIXED_DIFF_RANGES.get(var_name, None)

    @staticmethod
    def get_fixed_diff_range_errors(var_name):
        """Return the fixed (vmin, vmax) for error maps, or None if not configured."""
        return PlotConfig.FIXED_DIFF_RANGES_ERRORS.get(var_name, None)

    @staticmethod
    def get_fixed_mae_range(var_name):
        """Return the fixed (vmin, vmax) for Mean Absolute Error (MAE) maps, or None."""
        return PlotConfig.FIXED_MAE_RANGES.get(var_name, None)

    @staticmethod
    def get_fixed_ssr_range(var_name):
        """Return the fixed (vmin, vmax) for Spread Skill Ratio (SSR) maps, or None."""
        return PlotConfig.FIXED_SSR_RANGES.get(var_name, None)
def plot_validation_hexbin(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_hexbin.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Create hexbin plots comparing model predictions vs ground truth for all variables.

    One square panel per variable is laid out in a single row. Each panel shows
    a log-count hexbin of prediction against truth, the identity line, and the
    R², MAE and RMSE of the flattened data. All panels share the same colour
    limits so that the single colorbar (attached to the last panel) is valid
    for every panel.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Accepted for signature compatibility; not used by this plot.
    variable_names : list of str, optional
        Names of the variables for subplot titles. Defaults to VAR_0, VAR_1, ...
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot (created if missing)
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    num_vars = predictions.shape[1]
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]

    # Single row, one column per variable. squeeze=False always yields a 2D
    # array of axes, so no special-casing of the one-variable layout is needed.
    fig, axes = plt.subplots(
        1,
        num_vars,
        figsize=(num_vars * figsize_multiplier, figsize_multiplier),
        squeeze=False,
    )
    axes_flat = axes.ravel()
    for ax in axes_flat:
        ax.set_box_aspect(1)

    hexbins = []
    for i, (var_name, ax) in enumerate(zip(variable_names, axes_flat)):
        pred_flat = PlotConfig.convert_units(var_name, predictions[:, i]).reshape(-1)
        target_flat = PlotConfig.convert_units(var_name, targets[:, i]).reshape(-1)

        # Create hexbin plot
        hb = ax.hexbin(
            target_flat, pred_flat, gridsize=100, cmap="jet", bins="log", mincnt=1
        )
        hexbins.append(hb)

        # Add identity line
        min_val = min(target_flat.min(), pred_flat.min())
        max_val = max(target_flat.max(), pred_flat.max())
        ax.plot([min_val, max_val], [min_val, max_val], "r--", alpha=0.7)

        # Calculate metrics
        r2 = r2_score(target_flat, pred_flat)
        mae = np.mean(np.abs(pred_flat - target_flat))
        rmse = np.sqrt(np.mean((pred_flat - target_flat) ** 2))

        # Add metrics to plot
        textstr = f"$R^2$: {r2:.3f}\nMAE: {mae:.3f}\nRMSE: {rmse:.3f}"
        ax.text(
            0.05,
            0.95,
            textstr,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
        )

        ax.set_title(PlotConfig.get_plot_name(var_name))

        # Format ticks
        ax.xaxis.set_major_locator(ticker.MaxNLocator(5))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(5))

        # Only the leftmost panel carries the y-label; with a single row every
        # panel is in the bottom row and therefore gets the x-label.
        ax.set_ylabel("Predicted Values" if i == 0 else "")
        ax.set_xlabel("True Values")

    # Hide panels that received no data (fewer names than variables).
    for ax in axes_flat[len(hexbins):]:
        ax.set_visible(False)

    if hexbins:
        # Share colour limits across panels: the figure has one colorbar, so
        # every panel must map counts to colours the same way.
        count_arrays = [hb.get_array() for hb in hexbins]
        count_arrays = [c for c in count_arrays if c is not None and c.size]
        if count_arrays:
            shared_min = min(c.min() for c in count_arrays)
            shared_max = max(c.max() for c in count_arrays)
            for hb in hexbins:
                hb.set_clim(shared_min, shared_max)

        ax_last = axes_flat[len(hexbins) - 1]
        cax = ax_last.inset_axes([1.05, 0.0, 0.04, 1.0])  # [x, y, width, height]
        cbar = fig.colorbar(hexbins[-1], cax=cax)
        cbar.set_label(r"$\log_{10}[\mathrm{Count}]$")

    fig.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def _draw_hexbin_panel(ax, x, y, limits, title, ylabel, show_xlabel, **line_kwargs):
    """Draw one log-count hexbin panel with identity line and R²/MAE text; return the hexbin."""
    hb = ax.hexbin(x, y, gridsize=100, cmap="jet", bins="log", mincnt=1)
    lo, hi = limits
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    # identity line
    ax.plot([lo, hi], [lo, hi], "r--", alpha=0.7, **line_kwargs)
    r2 = r2_score(x, y)
    mae = np.mean(np.abs(y - x))
    ax.text(
        0.05,
        0.95,
        f"$R^2$: {r2:.3f}\nMAE: {mae:.3f}",
        transform=ax.transAxes,
        va="top",
    )
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("True Values" if show_xlabel else "")
    return hb


def plot_comparison_hexbin(
    predictions,
    targets,
    coarse_inputs,
    variable_names=None,
    filename="comparison_hexbin.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Create hexbin comparison plots between model predictions, ground truth, and coarse inputs.

    For each variable (one row), creates two side-by-side hexbin plots:

    1. Model predictions vs ground truth
    2. Coarse inputs vs ground truth

    Each plot includes an identity line and R²/MAE metrics. Both panels of a
    row share axis limits derived from that variable's data (with a 5% margin),
    and all panels share one colour scale and one colorbar.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles. If None, uses VAR_0, VAR_1, etc.
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot (created if missing)
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Convert tensors → numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    num_vars = predictions.shape[1]
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]

    # squeeze=False keeps `axes` two-dimensional even for a single variable.
    fig, axes = plt.subplots(
        num_vars,
        2,
        figsize=(2 * figsize_multiplier, num_vars * figsize_multiplier * 0.8),
        squeeze=False,
    )
    fig.subplots_adjust(
        hspace=0.3, wspace=0.4, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    hexbins = []
    for i, var_name in enumerate(variable_names):
        pred_flat = PlotConfig.convert_units(var_name, predictions[:, i]).reshape(-1)
        target_flat = PlotConfig.convert_units(var_name, targets[:, i]).reshape(-1)
        coarse_flat = PlotConfig.convert_units(var_name, coarse_inputs[:, i]).reshape(-1)

        # Per-variable axis limits with a small margin
        var_min = min(target_flat.min(), pred_flat.min(), coarse_flat.min())
        var_max = max(target_flat.max(), pred_flat.max(), coarse_flat.max())
        margin = 0.05 * (var_max - var_min)
        limits = (var_min - margin, var_max + margin)
        is_last_row = i == num_vars - 1

        # Left: Model vs True
        hexbins.append(
            _draw_hexbin_panel(
                axes[i, 0],
                target_flat,
                pred_flat,
                limits,
                f"{plot_variable_names[i]} – Model vs True",
                "Model Values",
                is_last_row,
            )
        )
        # Right: Coarse vs True (identity line drawn thinner, as before)
        hexbins.append(
            _draw_hexbin_panel(
                axes[i, 1],
                target_flat,
                coarse_flat,
                limits,
                f"{plot_variable_names[i]} – Coarse vs True",
                "Coarse Values",
                is_last_row,
                linewidth=1,
            )
        )

    # Global colour limits: the binning does not depend on the colour limits,
    # so they can be applied after plotting instead of binning everything
    # twice in a throw-away pre-pass.
    all_counts = np.concatenate([hb.get_array() for hb in hexbins])
    global_vmin = np.min(all_counts)
    global_vmax = np.max(all_counts)
    for hb in hexbins:
        hb.set_clim(global_vmin, global_vmax)

    # Single shared colorbar
    cbar_ax = fig.add_axes([0.98, 0.1, 0.02, 0.8])
    fig.colorbar(hexbins[-1], cax=cbar_ax, label=r"$\log_{10}[\mathrm{Count}]$")

    # Save
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_metric_histories(
    valid_metrics_history,
    variable_names,
    metric_names,
    filename="validation_metrics",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot per-variable metric histories in one figure.

    The figure is a grid with one row per variable and one column per metric,
    sharing the x-axis. Each panel shows, on a logarithmic y-axis, the series
    stored under ``"<var>_pred_vs_fine_<metric>"`` and
    ``"<var>_coarse_vs_fine_<metric>"``. A panel whose keys are absent shows the
    text "Missing Data" instead.

    Parameters
    ----------
    valid_metrics_history : dict
        Dict from training loop storing metric histories.
    variable_names : list of str
        Names of variables.
    metric_names : list of str
        List of metric names (e.g. ["MAE"]).
    filename : str
        Name of the saved figure without extension; ".png" is appended.
    save_dir : str
        Directory where images are saved (created if missing).
    figsize_multiplier : int, optional
        Height of each row; each column is 6 units wide.

    Returns
    -------
    str or None
        Path to the saved figure, or None when there is nothing to plot
        (no variables or no metrics).
    """
    # An empty grid cannot be created; bail out before touching the filesystem.
    if len(variable_names) == 0 or len(metric_names) == 0:
        print("No variable or metric names provided")
        return None

    os.makedirs(save_dir, exist_ok=True)
    num_vars = len(variable_names)
    num_metrics = len(metric_names)

    # Rows = variables, columns = metrics, shared x-axis
    fig, axes = plt.subplots(
        nrows=num_vars,
        ncols=num_metrics,
        figsize=(6 * num_metrics, figsize_multiplier * num_vars),
        squeeze=False,
        sharex=True,
    )
    fig.subplots_adjust(
        hspace=0.2, wspace=0.4, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    # enumerate() gives the column directly; looking it up with list.index()
    # would pick the first column for every duplicate metric name.
    for col, metric in enumerate(metric_names):
        for row, var in enumerate(variable_names):
            ax = axes[row, col]
            key_pred = f"{var}_pred_vs_fine_{metric}"
            key_coarse = f"{var}_coarse_vs_fine_{metric}"
            has_data = (
                key_pred in valid_metrics_history
                and key_coarse in valid_metrics_history
            )

            if has_data:
                linestyles = mpltex.linestyle_generator(markers=[])
                ax.plot(
                    valid_metrics_history[key_pred],
                    label="Pred vs Fine",
                    **next(linestyles),
                )
                ax.plot(
                    valid_metrics_history[key_coarse],
                    label="Coarse vs Fine",
                    **next(linestyles),
                )
                ax.set_yscale("log")
                ax.set_ylabel(f"{metric} ({var})")
                ax.grid(True, alpha=0.3)
                ax.legend()
            else:
                ax.text(0.5, 0.5, "Missing Data", ha="center", va="center")
                ax.set_yscale("log")

            # Only bottom row shows x-axis label (also for missing-data panels)
            if row == num_vars - 1:
                ax.set_xlabel("Epoch")
            else:
                ax.tick_params(labelbottom=False)

    save_path = os.path.join(save_dir, f"{filename}.png")
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_metrics_heatmap(
    valid_metrics_history,
    variable_names,
    metric_names,
    filename="validation_metrics_heatmap",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Draw an annotated heatmap of mean validation metrics (variables × metrics).

    For every (variable, metric) pair the entry
    ``valid_metrics_history["<var>_pred_vs_fine_<metric>"]`` is read. The entry
    must expose a ``count`` attribute and a ``getmean()`` method; its mean is
    used when ``count > 0``. Missing keys or empty trackers yield NaN. Means of
    MAE, RMSE and CRPS (case-insensitive) go through
    `PlotConfig.convert_units`; other metrics are left untouched.

    Parameters
    ----------
    valid_metrics_history : dict
        Dict from validation loop storing metric trackers.
    variable_names : list of str
        Names of variables (heatmap rows).
    metric_names : list of str
        List of metric names (heatmap columns), e.g. ["MAE", "NMAE", "RMSE", "R²"].
    filename : str
        Name of the saved figure without extension; ".png" is appended.
    save_dir : str
        Directory where images are saved (created if missing).
    figsize_multiplier : float
        Controls overall figure size

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    os.makedirs(save_dir, exist_ok=True)

    def _mean_for(var, metric):
        # Mean of one tracker in plotting units, or NaN when unavailable.
        key = f"{var}_pred_vs_fine_{metric}"
        if key not in valid_metrics_history:
            return np.nan
        tracker = valid_metrics_history[key]
        if not tracker.count > 0:
            return np.nan
        mean_value = tracker.getmean()
        # Convert only dimensional metrics
        if metric.lower() in ["mae", "rmse", "crps"]:
            mean_value = PlotConfig.convert_units(var, mean_value)
        return mean_value

    # One column per metric, one row per variable
    columns = {}
    for metric in metric_names:
        columns[metric] = [_mean_for(var, metric) for var in variable_names]
    table = pd.DataFrame(columns, index=variable_names)

    width = figsize_multiplier + len(metric_names)
    height = 0.6 * len(variable_names) + figsize_multiplier / 2
    fig, ax = plt.subplots(figsize=(width, height))
    sns.heatmap(
        table,
        ax=ax,
        cmap="viridis",
        annot=True,
        fmt=".3f",
        linewidths=0.8,
        cbar=True,
    )
    ax.set_title("Validation metrics")
    ax.set_xlabel("Metric")
    ax.set_ylabel("Variable")
    plt.tight_layout()

    save_path = os.path.join(save_dir, f"{filename}.png")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_loss_histories(
    train_loss_history,
    valid_loss_history,
    filename="training_validation_loss.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot training and validation loss in a single panel with a log y-axis.

    The validation curve is drawn only when the validation history is
    non-empty and contains at least one truthy (non-zero) value. Each curve is
    plotted against its own epoch index, so the two histories may differ in
    length.

    Parameters
    ----------
    train_loss_history : list or array
        History of training loss values.
    valid_loss_history : list or array or None
        History of validation loss values.
    filename : str
        Output image file name for the plot.
    save_dir : str
        Directory to save the plot (created if missing).
    figsize_multiplier : int, optional
        Figure height; the width is fixed at 6.

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Work on plain lists so truthiness checks are well defined for arrays too.
    train_loss_history = list(train_loss_history)
    valid_loss_history = [] if valid_loss_history is None else list(valid_loss_history)

    fig = plt.figure(figsize=(6, figsize_multiplier))
    ax = fig.add_subplot(111)

    # Plot losses
    linestyles = mpltex.linestyle_generator(markers=[])
    ax.plot(
        range(len(train_loss_history)),
        train_loss_history,
        label="Training Loss",
        **next(linestyles),
    )
    if valid_loss_history and any(valid_loss_history):
        # Use the validation history's own length: reusing the training epoch
        # range raises in matplotlib when the two lengths differ.
        ax.plot(
            range(len(valid_loss_history)),
            valid_loss_history,
            label="Validation Loss",
            **next(linestyles),
        )

    ax.set_yscale("log")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss Value")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_average_metrics(
    valid_metrics_history,
    metric_names,  # List of metrics to plot
    filename="average_metrics.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot variable-averaged metric histories, one row per metric, shared x-axis.

    Each row shows, on a logarithmic y-axis, whichever of these series exist
    in `valid_metrics_history`:

    - ``average_pred_vs_fine_<metric>``
    - ``average_coarse_vs_fine_<metric>``

    Parameters
    ----------
    valid_metrics_history : dict
        Dictionary containing validation metrics history.
    metric_names : list of str
        Names of metrics to plot.
    filename : str
        Output image file name for the plot.
    save_dir : str
        Directory to save the plot (created if missing).
    figsize_multiplier : int, optional
        Height of each row; the width is fixed at 6.

    Returns
    -------
    str or None
        Path to the saved figure, or None when `metric_names` is empty.
    """
    if not metric_names:
        print("No metric names provided")
        return

    row_count = len(metric_names)
    # One column, one row per metric
    fig, axes = plt.subplots(
        nrows=row_count,
        ncols=1,
        figsize=(6, figsize_multiplier * row_count),
        squeeze=False,
        sharex=True,
    )
    plt.subplots_adjust(hspace=0.1, left=0.15, right=0.95, top=0.95, bottom=0.1)

    for row, metric in enumerate(metric_names):
        panel = axes[row, 0]
        styles = mpltex.linestyle_generator(markers=[])

        # (history key, legend label), drawn in this order
        curves = (
            (f"average_pred_vs_fine_{metric}", "Pred vs Fine"),
            (f"average_coarse_vs_fine_{metric}", "Coarse vs Fine"),
        )
        for key, label in curves:
            if key not in valid_metrics_history:
                continue
            series = valid_metrics_history[key]
            if not isinstance(series, list):
                series = list(series)
            # A line style is consumed only for curves that are actually drawn.
            panel.plot(series, label=label, **next(styles))

        panel.set_yscale("log")
        panel.set_ylabel(metric.replace("_", " ").title())
        panel.grid(True, alpha=0.3)
        panel.legend()

        # Only bottom row gets x-label
        is_bottom = row == row_count - 1
        panel.set_xlabel("Epoch" if is_bottom else "")
        if not is_bottom:
            panel.tick_params(labelbottom=False)

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_spatiotemporal_histograms(
    steps,
    tindex_lim,
    centers,
    tindices,
    mode="train",
    filename="average_metrics.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot two 2D hexagonal-bin histograms of sample coverage:
    latitude center vs temporal index and longitude center vs temporal index.

    Parameters
    ----------
    steps : object
        Object exposing the maximum spatial indices as attributes 'latitude'
        (or 'lat') and 'longitude' (or 'lon'), e.g. an EasyDict. When an
        attribute is not found, the maximum of the corresponding sample
        centers is used as the axis limit instead.
    tindex_lim : tuple
        Tuple of (min_time, max_time) specifying the temporal index limits.
    centers : list of tuples or array-like of shape (n, 2)
        (lat_center, lon_center) for each data sample.
    tindices : list or array-like
        Temporal index of each data sample; same length as 'centers'.
    mode : str
        Dataset mode identifier, typically "train" or "validation".
        Used in the printed message and in the output filename.
    filename : str, optional
        Prefix prepended verbatim to the output filename
        ``"spatiotemporal_<mode>_hexbin.png"``.
        NOTE(review): the default is "average_metrics.png", which produces
        "average_metrics.pngspatiotemporal_<mode>_hexbin.png" — this looks
        like a copy-paste default; callers probably want to pass "".
    save_dir : str
        Directory where the plot is saved (created if missing).
    figsize_multiplier : int, optional
        Height of the figure; the width is twice this value.

    Returns
    -------
    str or None
        Path to the saved figure, or None when there is no data to plot.

    Notes
    -----
    - Two side-by-side panels share the y-axis (temporal index).
    - Hexagons with fewer than one sample (``mincnt=1``) are not drawn.
    - Counts are shown on a linear colour scale from 0 to the largest count
      found in either panel; one colorbar serves both panels.

    Examples
    --------
    >>> steps = EasyDict({'latitude': 180, 'longitude': 360})
    >>> centers = [(10, 20), (15, 25), (10, 20)]  # list of (lat, lon)
    >>> tindices = [0, 5, 10]  # corresponding temporal indices
    >>> plot_spatiotemporal_histograms(steps, (0, 1000), centers, tindices,
    ...                                mode="train", filename="",
    ...                                save_dir="./plots")
    './plots/spatiotemporal_train_hexbin.png'
    """
    # len() works for lists and arrays alike; `not array` would raise for
    # NumPy input with more than one element.
    if (
        centers is None
        or tindices is None
        or len(centers) == 0
        or len(tindices) == 0
    ):
        print(f"No data to plot for {mode} mode")
        return None

    # Convert to numpy arrays for column slicing
    centers = np.asarray(centers)
    lat_centers = centers[:, 0]
    lon_centers = centers[:, 1]
    tindices = np.asarray(tindices)

    # Spatial limits from `steps`, falling back to the observed maximum so a
    # missing attribute does not pass None into hexbin's extent.
    max_lat = getattr(steps, "latitude", getattr(steps, "lat", None))
    if max_lat is None:
        max_lat = lat_centers.max()
    max_lon = getattr(steps, "longitude", getattr(steps, "lon", None))
    if max_lon is None:
        max_lon = lon_centers.max()
    min_time, max_time = tindex_lim

    # Two side-by-side subplots sharing the y-axis
    fig, (ax1, ax2) = plt.subplots(
        1, 2, figsize=(2 * figsize_multiplier, figsize_multiplier), sharey=True
    )
    fig.subplots_adjust(
        hspace=0.1, wspace=0.1, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    panels = (
        (ax1, lat_centers, max_lat, "Latitude Center Index"),
        (ax2, lon_centers, max_lon, "Longitude Center Index"),
    )
    hexbins = []
    for ax, values, upper, xlabel in panels:
        hb = ax.hexbin(
            values,
            tindices,
            gridsize=100,  # Number of hexagons in x-direction
            extent=[0, upper, min_time, max_time],  # Data limits
            cmap="jet",
            mincnt=1,  # Only show hexagons with at least 1 count
            edgecolors="none",  # No borders on hexagons
        )
        hexbins.append(hb)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_xlim(0, upper)
        ax.set_ylim(min_time, max_time)
        ax.grid(True, alpha=0.3, linestyle="--")
    ax1.set_ylabel("Temporal Index", fontsize=12)

    # Normalize color scale to the maximum count across both plots
    max_count = 1
    for hb in hexbins:
        counts = hb.get_array()
        if counts is not None and len(counts) > 0:
            max_count = max(max_count, counts.max())
    for hb in hexbins:
        hb.set_clim(0, max_count)

    # Single colorbar for both plots. Counts are linear here (no bins="log"),
    # so the label must not claim a log10 scale.
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.8])
    fig.colorbar(hexbins[0], cax=cbar_ax, label="Count")

    # Save figure to disk
    os.makedirs(save_dir, exist_ok=True)
    out_name = f"{filename}spatiotemporal_{mode}_hexbin.png"
    save_path = os.path.join(save_dir, out_name)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_surface(
    predictions,
    targets,
    coarse_inputs,
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="forecast_plot.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot coarse input, truth, prediction and signed-difference maps.

    The figure has four rows (coarse, truth, prediction, prediction minus
    truth) and one column per variable. Every field is passed through
    ``PlotConfig.convert_units`` before plotting.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth data with shape [1, n_vars, H, W].
    coarse_inputs : torch.Tensor or np.ndarray
        Coarse-resolution input data with shape [1, n_vars, H, W].
    lat_1d : array-like or torch.Tensor
        1D latitude coordinates. Only the minimum and maximum are used;
        the plotting grid is rebuilt as a regular grid between them.
    lon_1d : array-like or torch.Tensor
        1D longitude coordinates. Only the minimum and maximum are used.
    timestamp : datetime.datetime, optional
        If given, a "Forecast for ..." line is printed to stdout (the
        figure title is currently disabled).
    variable_names : list of str, optional
        Variable identifiers; defaults to ``VAR_0``, ``VAR_1``, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to
        ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Accepted for backward compatibility only; the panel size is fixed
        and this value is not used.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If ``targets`` or ``predictions`` do not have the same number of
        variables as ``coarse_inputs``.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    def _as_numpy(value):
        # Tensors expose ``detach``; anything else is passed through as is.
        return value.detach().cpu().numpy() if hasattr(value, "detach") else value

    coarse_inputs = _as_numpy(coarse_inputs)
    targets = _as_numpy(targets)
    predictions = _as_numpy(predictions)
    # np.asarray lets plain lists/tuples be used as coordinates.
    lat_1d = np.asarray(_as_numpy(lat_1d))
    lon_1d = np.asarray(_as_numpy(lon_1d))

    # Rebuild a regular 2D grid from the coordinate extremes.
    # Latitude runs from max to min, i.e. row 0 is presumed to be the
    # northernmost row of the data -- TODO confirm against the data loader.
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    h, w = coarse_inputs[0, 0].shape
    lat, lon = np.meshgrid(
        np.linspace(lat_max, lat_min, h),
        np.linspace(lon_min, lon_max, w),
        indexing="ij",
    )
    # Projection center
    lon_center = float((lon_min + lon_max) / 2)

    # Check data dimensions
    n_vars = coarse_inputs.shape[1]
    if targets.shape[1] != n_vars:
        raise ValueError(
            f"targets data has {targets.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    if predictions.shape[1] != n_vars:
        raise ValueError(
            f"predictions data has {predictions.shape[1]} variables but coarse_inputs has {n_vars}"
        )

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = [PlotConfig.get_colormap(var) for var in variable_names]

    # Convert every field once and derive the colour ranges from the
    # converted data, so the ranges always match what is drawn.
    fields = []  # per variable: (coarse, truth, prediction, difference)
    value_ranges = []  # per variable: (vmin, vmax) shared by rows 0-2
    diff_ranges = []  # per variable: (vmin, vmax) of the difference row
    for i in range(n_vars):
        var_name = variable_names[i]
        coarse_i = PlotConfig.convert_units(var_name, coarse_inputs[0, i])
        target_i = PlotConfig.convert_units(var_name, targets[0, i])
        pred_i = PlotConfig.convert_units(var_name, predictions[0, i])
        diff_i = pred_i - target_i  # Signed difference (pred - truth)
        fields.append((coarse_i, target_i, pred_i, diff_i))

        pooled = np.concatenate(
            [coarse_i.flatten(), target_i.flatten(), pred_i.flatten()]
        )
        finite = pooled[~np.isnan(pooled)]
        if finite.size > 0:
            # 2%-98% quantiles keep outliers from washing out the colours.
            q_low, q_high = np.quantile(finite, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
            if vmin >= vmax:
                # Degenerate quantiles: fall back to the full data range.
                vmin, vmax = float(finite.min()), float(finite.max())
        else:
            vmin, vmax = -1, 1
        value_ranges.append((vmin, vmax))

        fixed_range = PlotConfig.get_fixed_diff_range(var_name)
        finite_diff = diff_i.flatten()
        finite_diff = finite_diff[~np.isnan(finite_diff)]
        if fixed_range is not None:
            diff_vmin, diff_vmax = fixed_range
        elif finite_diff.size > 0:
            # Symmetric range around 0 with 10% padding.
            max_abs_diff = np.max(np.abs(finite_diff))
            diff_vmin = -max_abs_diff * 1.1
            diff_vmax = max_abs_diff * 1.1
            # If all differences are zero or very small
            if diff_vmax <= 0.001:
                diff_vmin, diff_vmax = -0.1, 0.1
        else:
            diff_vmin, diff_vmax = -1, 1
        diff_ranges.append((diff_vmin, diff_vmax))

    # Fixed panel size keeps panels rectangular regardless of location.
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    n_rows = 4
    fig, axes = plt.subplots(
        n_rows,
        n_vars,
        figsize=(base_width_per_panel * n_vars, base_height_per_panel * n_rows),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1, "hspace": 0.1},
        squeeze=False,
    )

    if timestamp is not None:
        print(f"Forecast for {timestamp.strftime('%Y-%m-%d %H:%M')}")

    for col_idx in range(n_vars):
        coarse_i, target_i, pred_i, diff_i = fields[col_idx]
        vmin, vmax = value_ranges[col_idx]
        diff_vmin, diff_vmax = diff_ranges[col_idx]
        cmap = cmaps[col_idx]
        panels = [
            (coarse_i, vmin, vmax, cmap),
            (target_i, vmin, vmax, cmap),
            (pred_i, vmin, vmax, cmap),
            # Diverging colormap for the signed difference.
            (diff_i, diff_vmin, diff_vmax, "RdBu_r"),
        ]
        images = []
        for row_idx, (data, lo, hi, panel_cmap) in enumerate(panels):
            ax = axes[row_idx, col_idx]
            images.append(
                ax.pcolormesh(
                    lon,
                    lat,
                    data,
                    vmin=lo,
                    vmax=hi,
                    cmap=panel_cmap,
                    transform=ccrs.PlateCarree(),
                    shading="auto",
                )
            )
            # Set extent and features
            ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            ax.coastlines(linewidth=0.6)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])

        # Value colorbar above the COARSE row (row 0); rows 0-2 share its
        # range. y > 1.0 in inset_axes places it above the panel.
        cax_top = axes[0, col_idx].inset_axes([0.1, 1.05, 0.8, 0.05])
        cbar = fig.colorbar(images[0], cax=cax_top, orientation="horizontal")
        cbar.set_label(f"{plot_variable_names[col_idx]}")
        cax_top.xaxis.set_ticks_position("top")
        cax_top.xaxis.set_label_position("top")

        # Difference colorbar below the DIFFERENCE row (row 3).
        cax_diff = axes[3, col_idx].inset_axes([0.1, -0.12, 0.8, 0.05])
        fig.colorbar(
            images[3],
            cax=cax_diff,
            orientation="horizontal",
            label=f"Δ {plot_variable_names[col_idx]} (Pred - Truth)",
        )

    # Row labels on the left side of the first column
    row_labels = ["Coarse", "Truth", "Prediction", "Pred - Truth"]
    for row_idx, label in enumerate(row_labels):
        axes[row_idx, 0].text(
            -0.12,
            0.5,
            label,
            transform=axes[row_idx, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )

    # Leave room at the bottom for the difference colorbars
    fig.subplots_adjust(
        top=0.90, bottom=0.25, left=0.10, right=0.95, wspace=0.1, hspace=0.15
    )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_ensemble_surface(
    predictions_ens,
    lat_1d,
    lon_1d,
    variable_names,
    timestamp=None,
    filename="ensemble_surface.png",
    save_dir="./results",
):
    """
    Plot three ensemble members, the ensemble mean and the ensemble spread.

    The figure has five rows (members 0, 1, 2, ensemble mean, ensemble
    standard deviation) and one column per variable.

    Parameters
    ----------
    predictions_ens : torch.Tensor or np.ndarray
        Ensemble predictions of shape [n_ensemble_members, n_vars, H, W].
    lat_1d : array-like or torch.Tensor
        1D latitude coordinates. Only the minimum and maximum are used;
        the plotting grid is rebuilt as a regular grid between them.
    lon_1d : array-like or torch.Tensor
        1D longitude coordinates. Only the minimum and maximum are used.
    variable_names : list of str
        Variable identifiers, one per channel of ``predictions_ens``.
    timestamp : datetime.datetime, optional
        If given, an "Ensemble predictions" line is printed to stdout.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If fewer than 3 ensemble members are given, or if
        ``variable_names`` has fewer entries than there are variables.
    """
    if torch.is_tensor(predictions_ens):
        predictions_ens = predictions_ens.detach().cpu().numpy()
    N_ens, C, H, W = predictions_ens.shape
    if N_ens < 3:
        raise ValueError("Need at least 3 ensemble members")
    if len(variable_names) < C:
        raise ValueError(
            f"variable_names has {len(variable_names)} entries but predictions_ens has {C} variables"
        )

    # Ensemble mean on the raw values (converted per variable below).
    ensemble_mean = np.mean(predictions_ens, axis=0)

    # Coordinates may arrive as tensors (possibly on GPU / with grad).
    if torch.is_tensor(lat_1d):
        lat_1d = lat_1d.detach().cpu().numpy()
    if torch.is_tensor(lon_1d):
        lon_1d = lon_1d.detach().cpu().numpy()
    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    # Latitude runs from max to min, i.e. row 0 is presumed to be the
    # northernmost row of the data -- TODO confirm against the data loader.
    lat, lon = np.meshgrid(
        np.linspace(lat_max, lat_min, H),
        np.linspace(lon_min, lon_max, W),
        indexing="ij",
    )
    lon_center = float((lon_min + lon_max) / 2)

    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = [PlotConfig.get_colormap(v) for v in variable_names]

    # Per variable: the five fields to draw and the shared value range.
    rows_per_var = []
    value_ranges = []
    for i in range(C):
        var = variable_names[i]
        members = np.stack(
            [PlotConfig.convert_units(var, predictions_ens[k, i]) for k in range(N_ens)]
        )
        mean_i = PlotConfig.convert_units(var, ensemble_mean[i])
        # The spread is computed on the converted members. Applying
        # convert_units to a standard deviation would be wrong for any
        # conversion that adds an offset (the offset would shift the std).
        std_i = np.std(members, axis=0)
        rows_per_var.append([members[0], members[1], members[2], mean_i, std_i])

        pooled = np.concatenate([members.ravel(), np.ravel(mean_i)])
        pooled = pooled[~np.isnan(pooled)]
        if pooled.size > 0:
            # 2%-98% quantiles keep outliers from washing out the colours.
            q_low, q_high = np.quantile(pooled, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
            if vmin >= vmax:
                # Degenerate quantiles: fall back to the full data range.
                vmin, vmax = float(pooled.min()), float(pooled.max())
        else:
            vmin, vmax = -1, 1
        value_ranges.append((vmin, vmax))

    n_rows = 5
    n_cols = C
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(base_width_per_panel * n_cols, base_height_per_panel * n_rows),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1, "hspace": 0.15},
        squeeze=False,
    )
    row_labels = [
        "Prediction 1",
        "Prediction 2",
        "Prediction 3",
        "Ensemble Mean",
        "Ensemble std (σ)",
    ]
    spread_row = n_rows - 1

    for col in range(n_cols):
        rows_data = rows_per_var[col]
        images = []
        for row in range(n_rows):
            ax = axes[row, col]
            if row == spread_row:
                # Spread is non-negative: sequential colormap starting at 0.
                cmap = "Reds"
                vmin = 0
                vmax = np.nanmax(rows_data[spread_row])
            else:
                cmap = cmaps[col]
                vmin, vmax = value_ranges[col]
            images.append(
                ax.pcolormesh(
                    lon,
                    lat,
                    rows_data[row],
                    vmin=vmin,
                    vmax=vmax,
                    cmap=cmap,
                    transform=ccrs.PlateCarree(),
                    shading="auto",
                )
            )
            ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            ax.coastlines(linewidth=0.6)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])

        # Value colorbar above the first row; rows 0-3 share its range.
        cax_top = axes[0, col].inset_axes([0.1, 1.05, 0.8, 0.05])
        cbar = fig.colorbar(images[0], cax=cax_top, orientation="horizontal")
        cbar.set_label(plot_variable_names[col])
        cax_top.xaxis.set_ticks_position("top")
        cax_top.xaxis.set_label_position("top")

        # Spread colorbar below the last row.
        cax_bottom = axes[spread_row, col].inset_axes([0.1, -0.12, 0.8, 0.05])
        fig.colorbar(
            images[spread_row],
            cax=cax_bottom,
            orientation="horizontal",
            label=f"Std {plot_variable_names[col]}",
        )

    # Row labels on the left side of the first column
    for r, label in enumerate(row_labels):
        axes[r, 0].text(
            -0.12,
            0.5,
            label,
            transform=axes[r, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )

    if timestamp is not None:
        print(f"Ensemble predictions — {timestamp}")

    fig.subplots_adjust(
        top=0.90,
        bottom=0.25,
        left=0.10,
        right=0.95,
        wspace=0.1,
        hspace=0.15,
    )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_zoom_comparison(
    predictions,
    targets,
    lat_1d,
    lon_1d,
    variable_names=None,
    filename="zoom_plot.png",
    save_dir=None,
    zoom_box=None,
):
    """
    Plot truth versus prediction with a geographic zoom.

    The figure has four rows and one column per variable: the truth over
    the full domain (with the zoom box outlined in red), then the truth,
    the prediction and their absolute error inside the zoom box.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth data with shape [1, n_vars, H, W].
    lat_1d : array-like or torch.Tensor
        1D latitude coordinates with shape [H]; used directly as the row
        coordinates of the data.
    lon_1d : array-like or torch.Tensor
        1D longitude coordinates with shape [W]; used directly as the
        column coordinates of the data.
    variable_names : list of str, optional
        Variable identifiers; defaults to ``VAR_0``, ``VAR_1``, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to
        ``PlotConfig.DEFAULT_SAVE_DIR``.
    zoom_box : dict, optional
        Zoom region with keys ``lat_min``, ``lat_max``, ``lon_min`` and
        ``lon_max``, expressed in the same convention as ``lat_1d`` /
        ``lon_1d``. Defaults to
        ``{"lat_min": -23, "lat_max": 13, "lon_min": 255, "lon_max": 345}``.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If the zoom box selects no latitude or no longitude of the grid.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if zoom_box is None:
        zoom_box = {"lat_min": -23, "lat_max": 13, "lon_min": 255, "lon_max": 345}

    # Convert tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # np.asarray lets plain lists be compared against the zoom bounds.
    lat = np.asarray(lat_1d)
    lon = np.asarray(lon_1d)

    lon2d, lat2d = np.meshgrid(lon, lat)
    lat_min, lat_max = lat.min(), lat.max()
    lon_min, lon_max = lon.min(), lon.max()
    lon_center = float((lon_min + lon_max) / 2)

    lat_mask = (lat >= zoom_box["lat_min"]) & (lat <= zoom_box["lat_max"])
    lon_mask = (lon >= zoom_box["lon_min"]) & (lon <= zoom_box["lon_max"])
    if not lat_mask.any() or not lon_mask.any():
        # Fail early with a readable message instead of a numpy
        # "zero-size array" reduction error further down.
        raise ValueError(
            f"zoom_box {zoom_box} does not overlap the grid "
            f"(lat {lat_min}..{lat_max}, lon {lon_min}..{lon_max})"
        )
    lat_zoom = lat[lat_mask]
    lon_zoom = lon[lon_mask]
    lon_zoom2d, lat_zoom2d = np.meshgrid(lon_zoom, lat_zoom)
    zoom_extent = [lon_zoom.min(), lon_zoom.max(), lat_zoom.min(), lat_zoom.max()]
    zoom_index = np.ix_(lat_mask, lon_mask)

    n_vars = targets.shape[1]
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = [PlotConfig.get_colormap(v) for v in variable_names]

    # The global row is centered on the domain; zoom rows use the default
    # PlateCarree projection.
    proj_global = ccrs.PlateCarree(central_longitude=lon_center)
    proj_zoom = ccrs.PlateCarree()

    base_width_per_panel = 4.5
    base_height_per_panel = 2.5
    n_rows = 4
    fig = plt.figure(
        figsize=(base_width_per_panel * n_vars, base_height_per_panel * n_rows)
    )

    # Manual axes layout in figure-fraction coordinates.
    left_margin = 0.08
    right_margin = 0.02
    bottom_margin = 0.08
    top_margin = 0.06
    hspace = 0.002
    wspace = 0.008
    col_width = (1 - left_margin - right_margin) / n_vars
    row_height = (1 - bottom_margin - top_margin) / n_rows
    axes = np.empty((n_rows, n_vars), dtype=object)
    for row in range(n_rows):
        for col in range(n_vars):
            x0 = left_margin + col * col_width + wspace / 2
            # Row 0 sits at the top of the figure.
            y0 = 1 - top_margin - (row + 1) * row_height + hspace / 2
            axes[row, col] = fig.add_axes(
                [x0, y0, col_width - wspace, row_height - hspace],
                projection=proj_global if row == 0 else proj_zoom,
            )

    coastline = cfeature.COASTLINE.with_scale("50m")
    borders = cfeature.BORDERS.with_scale("50m")

    for col in range(n_vars):
        var = variable_names[col]
        truth = PlotConfig.convert_units(var, targets[0, col])
        pred = PlotConfig.convert_units(var, predictions[0, col])
        abs_err = np.abs(pred - truth)
        cmap = cmaps[col]

        # Shared value range from the 2%-98% quantiles of truth and
        # prediction; fall back to fixed ranges when everything is NaN.
        pooled = np.concatenate([truth.flatten(), pred.flatten()])
        pooled = pooled[~np.isnan(pooled)]
        if pooled.size > 0:
            vmin, vmax = (float(q) for q in np.quantile(pooled, [0.02, 0.98]))
        else:
            vmin, vmax = -1.0, 1.0
        finite_err = abs_err[~np.isnan(abs_err)]
        mae_vmax = float(np.quantile(finite_err, 0.98)) if finite_err.size > 0 else 1.0

        # ---- Row 0: truth over the full domain ----
        ax = axes[0, col]
        im_global = ax.pcolormesh(
            lon2d,
            lat2d,
            truth,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        # Explicit crs: the extents are plain lon/lat, as in the zoom rows.
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)
        # Red outline of the zoom region
        ax.add_patch(
            patches.Rectangle(
                (zoom_box["lon_min"], zoom_box["lat_min"]),
                zoom_box["lon_max"] - zoom_box["lon_min"],
                zoom_box["lat_max"] - zoom_box["lat_min"],
                linewidth=2,
                edgecolor="red",
                facecolor="none",
                transform=ccrs.PlateCarree(),
            )
        )
        # Connect the top corners of the zoom box to the top corners of
        # the zoomed truth panel below.
        zoom_ax = axes[1, col]
        for box_lon, axes_x in ((zoom_box["lon_min"], 0), (zoom_box["lon_max"], 1)):
            fig.add_artist(
                ConnectionPatch(
                    xyA=(box_lon, zoom_box["lat_max"]),
                    coordsA=ccrs.PlateCarree()._as_mpl_transform(ax),
                    xyB=(axes_x, 1),
                    coordsB=zoom_ax.transAxes,
                    color="red",
                    linewidth=1.5,
                )
            )

        # ---- Rows 1-3: truth, prediction and absolute error (zoom) ----
        zoom_panels = [
            (truth[zoom_index], cmap, vmin, vmax),
            (pred[zoom_index], cmap, vmin, vmax),
            (abs_err[zoom_index], "Reds", 0, mae_vmax),
        ]
        for row, (field, panel_cmap, lo, hi) in enumerate(zoom_panels, start=1):
            ax = axes[row, col]
            im_zoom = ax.pcolormesh(
                lon_zoom2d,
                lat_zoom2d,
                field,
                cmap=panel_cmap,
                vmin=lo,
                vmax=hi,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
            ax.set_extent(zoom_extent, crs=ccrs.PlateCarree())
            ax.add_feature(coastline, linewidth=0.6)
            ax.add_feature(borders, linewidth=0.5)

        # Colorbars: error below the last row, values above the first row.
        # After the loop, ``ax`` / ``im_zoom`` refer to the error panel.
        cax = ax.inset_axes([0.15, -0.25, 0.7, 0.05])
        fig.colorbar(im_zoom, cax=cax, orientation="horizontal").set_label(
            f"MAE {plot_variable_names[col]}"
        )
        cax_top = axes[0, col].inset_axes([0.1, 1.05, 0.8, 0.05])
        cbar = fig.colorbar(im_global, cax=cax_top, orientation="horizontal")
        cbar.set_label(plot_variable_names[col])
        cax_top.xaxis.set_ticks_position("top")
        cax_top.xaxis.set_label_position("top")

        for r in range(n_rows):
            axes[r, col].set_xticks([])
            axes[r, col].set_yticks([])

    # Row labels
    row_labels = ["Truth", "Truth (Zoom)", "Prediction (Zoom)", "MAE"]
    for i, label in enumerate(row_labels):
        axes[i, 0].text(
            -0.15,
            0.5,
            label,
            transform=axes[i, 0].transAxes,
            rotation=90,
            va="center",
            ha="right",
        )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_global_surface_robinson(
    predictions,
    targets,
    coarse_inputs,
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="global_robinson.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot coarse, truth, prediction and difference fields in Robinson projection.

    The figure has four rows (coarse, truth, prediction, prediction minus
    truth) and one column per variable, each drawn on a global map.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth data with shape [1, n_vars, H, W].
    coarse_inputs : torch.Tensor or np.ndarray
        Coarse-resolution input data with shape [1, n_vars, H, W].
    lat_1d : array-like or torch.Tensor
        1D latitude coordinates. Only the minimum and maximum are used;
        the plotting grid is rebuilt as a regular grid between them.
    lon_1d : array-like or torch.Tensor
        1D longitude coordinates. Only the minimum and maximum are used.
    timestamp : datetime.datetime, optional
        Timestamp shown in the figure title.
    variable_names : list of str, optional
        Variable identifiers; defaults to ``VAR_0``, ``VAR_1``, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to
        ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Accepted for backward compatibility only; the panel size is fixed
        and this value is not used.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If ``targets`` or ``predictions`` do not have the same number of
        variables as ``coarse_inputs``.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Convert tensors to numpy if needed
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # np.asarray lets plain lists/tuples be used as coordinates.
    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)

    # Rebuild a regular 2D grid from the coordinate extremes.
    # Latitude runs from max to min, i.e. row 0 is presumed to be the
    # northernmost row of the data -- TODO confirm against the data loader.
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    h, w = coarse_inputs[0, 0].shape
    lat2d, lon2d = np.meshgrid(
        np.linspace(lat_max, lat_min, h),
        np.linspace(lon_min, lon_max, w),
        indexing="ij",
    )
    lon2d = ((lon2d + 180) % 360) - 180  # wrap longitudes into [-180, 180)

    # Check data dimensions
    n_vars = coarse_inputs.shape[1]
    if targets.shape[1] != n_vars:
        raise ValueError(
            f"targets data has {targets.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    if predictions.shape[1] != n_vars:
        raise ValueError(
            f"predictions data has {predictions.shape[1]} variables but coarse_inputs has {n_vars}"
        )

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = [PlotConfig.get_colormap(var) for var in variable_names]

    # NOTE(review): unlike plot_surface, the fields here are plotted without
    # PlotConfig.convert_units while the colorbars use the same plot names.
    # Left unchanged -- confirm whether raw values are intended.
    value_ranges = []  # per variable: (vmin, vmax) shared by rows 0-2
    diff_ranges = []  # per variable: (vmin, vmax) of the difference row
    for i in range(n_vars):
        pooled = np.concatenate(
            [
                coarse_inputs[0, i].flatten(),
                targets[0, i].flatten(),
                predictions[0, i].flatten(),
            ]
        )
        finite = pooled[~np.isnan(pooled)]
        if finite.size > 0:
            # 2%-98% quantiles keep outliers from washing out the colours.
            q_low, q_high = np.quantile(finite, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
            if vmin >= vmax:
                # Degenerate quantiles: fall back to the full data range.
                vmin, vmax = float(finite.min()), float(finite.max())
        else:
            vmin, vmax = -1, 1
        value_ranges.append((vmin, vmax))

        # Signed difference between prediction and truth
        finite_diff = (predictions[0, i] - targets[0, i]).flatten()
        finite_diff = finite_diff[~np.isnan(finite_diff)]
        if finite_diff.size > 0:
            # Symmetric range around 0 with 10% padding.
            max_abs_diff = np.max(np.abs(finite_diff))
            diff_vmin = -max_abs_diff * 1.1
            diff_vmax = max_abs_diff * 1.1
            # If all differences are zero or very small
            if diff_vmax <= 0.001:
                diff_vmin, diff_vmax = -0.1, 0.1
        else:
            diff_vmin, diff_vmax = -1, 1
        diff_ranges.append((diff_vmin, diff_vmax))

    # squeeze=False keeps ``axes`` 2D even for a single variable.
    fig, axes = plt.subplots(
        4,
        n_vars,
        figsize=(4.5 * n_vars, 3.2 * 4),
        subplot_kw={"projection": ccrs.Robinson()},
        gridspec_kw={"hspace": 0.12, "wspace": 0.05},
        squeeze=False,
    )
    row_labels = ["Coarse", "Truth", "Prediction", "Pred − Truth"]

    for col in range(n_vars):
        truth = targets[0, col]
        pred = predictions[0, col]
        vmin, vmax = value_ranges[col]
        diff_vmin, diff_vmax = diff_ranges[col]
        panels = [
            (coarse_inputs[0, col], cmaps[col], vmin, vmax),
            (truth, cmaps[col], vmin, vmax),
            (pred, cmaps[col], vmin, vmax),
            # Diverging colormap for the signed difference.
            (pred - truth, "RdBu_r", diff_vmin, diff_vmax),
        ]
        for row, (data, panel_cmap, lo, hi) in enumerate(panels):
            ax = axes[row, col]
            ax.set_global()
            im = ax.pcolormesh(
                lon2d,
                lat2d,
                data,
                transform=ccrs.PlateCarree(),
                cmap=panel_cmap,
                vmin=lo,
                vmax=hi,
                shading="auto",
            )
            ax.coastlines(linewidth=0.9)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])

            # Row labels on the left side of the first column
            if col == 0:
                ax.text(
                    -0.08,
                    0.5,
                    row_labels[row],
                    transform=ax.transAxes,
                    va="center",
                    ha="right",
                    rotation=90,
                    fontsize=12,
                )

            # Value colorbar above the first row
            if row == 0:
                cax = ax.inset_axes([0.1, 1.08, 0.8, 0.05])
                cb = fig.colorbar(im, cax=cax, orientation="horizontal")
                cb.set_label(plot_variable_names[col])
                cax.xaxis.set_ticks_position("top")
                cax.xaxis.set_label_position("top")
            # Difference colorbar below the last row
            if row == 3:
                cax = ax.inset_axes([0.1, -0.12, 0.8, 0.05])
                fig.colorbar(
                    im,
                    cax=cax,
                    orientation="horizontal",
                    label=f"Δ {plot_variable_names[col]} (Pred - Truth)",
                )

    if timestamp is not None:
        fig.suptitle(
            f"Global Robinson diagnostic – {timestamp.strftime('%Y-%m-%d %H:%M')}",
            fontsize=16,
            y=0.96,
        )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_MAE_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_mae_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Unused; kept for backward compatibility
):
    """
    Plot spatial MAE maps averaged over the first (time/batch) axis.

    MAE(x, y) = mean_t(abs(prediction - target)), computed after both
    fields are passed through ``PlotConfig.convert_units``. One panel is
    drawn per variable.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions of shape [batch_size, num_variables, h, w].
    targets : torch.Tensor or np.ndarray
        Ground truth of shape [batch_size, num_variables, h, w].
    lat_1d : array-like or torch.Tensor
        1D latitude coordinates. Only the minimum and maximum are used;
        the plotting grid is rebuilt as a regular grid between them.
    lon_1d : array-like or torch.Tensor
        1D longitude coordinates. Only the minimum and maximum are used.
    timestamp : datetime.datetime, optional
        Timestamp shown in the figure title.
    variable_names : list of str, optional
        Variable identifiers; defaults to ``VAR_0``, ``VAR_1``, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to
        ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Accepted for backward compatibility only; the panel size is fixed
        and this value is not used.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If ``targets`` and ``predictions`` differ in number of variables.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # np.asarray lets plain lists/tuples be used as coordinates.
    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, n_vars, h, w = predictions.shape
    # Latitude runs from max to min, i.e. row 0 is presumed to be the
    # northernmost row of the data -- TODO confirm against the data loader.
    lat, lon = np.meshgrid(
        np.linspace(lat_max, lat_min, h),
        np.linspace(lon_min, lon_max, w),
        indexing="ij",
    )
    lon_center = float((lon_min + lon_max) / 2)

    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    mae_cmap = PlotConfig.get_colormap("mae")

    # Compute each MAE map once; it feeds both the colour range and the plot.
    mae_maps = []
    value_ranges = []
    for i in range(n_vars):
        pred_i = PlotConfig.convert_units(variable_names[i], predictions[:, i])
        tgt_i = PlotConfig.convert_units(variable_names[i], targets[:, i])
        mae_data = np.mean(np.abs(pred_i - tgt_i), axis=0)
        mae_maps.append(mae_data)

        mae_flat = mae_data.flatten()
        mae_flat = mae_flat[~np.isnan(mae_flat)]
        fixed_range = PlotConfig.get_fixed_mae_range(variable_names[i])
        if fixed_range is not None:
            vmin, vmax = fixed_range
        elif mae_flat.size > 0:
            # 2%-98% quantiles keep outliers from washing out the colours.
            q_low, q_high = np.quantile(mae_flat, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = 0.0, 1.0
        # Degenerate range: fall back to the full data range when there is
        # any finite value to take it from.
        if vmin >= vmax and mae_flat.size > 0:
            vmin, vmax = float(mae_flat.min()), float(mae_flat.max())
        value_ranges.append((vmin, vmax))

    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    # squeeze=False keeps ``axes`` 2D even for a single variable.
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(base_width_per_panel * n_vars, base_height_per_panel),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
        squeeze=False,
    )

    if timestamp is not None:
        fig.suptitle(
            f"MAE Map (time-averaged) - {timestamp.strftime('%Y-%m-%d %H:%M')}",
            fontsize=16,
            y=1.05,
        )

    for col_idx in range(n_vars):
        ax = axes[0, col_idx]
        vmin, vmax = value_ranges[col_idx]
        im = ax.pcolormesh(
            lon,
            lat,
            mae_maps[col_idx],
            vmin=vmin,
            vmax=vmax,
            cmap=mae_cmap,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])

        # Colorbar below the panel
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=f"MAE {plot_variable_names[col_idx]}",
        )

    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_error_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_error_map.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot spatial error maps averaged over all time steps, one panel per variable.

    Variables whose lower-cased name is ``"var_tp"``, ``"precip"`` or
    ``"precipitation"`` are drawn as the time-mean absolute error. Every other
    variable is drawn as the time-mean relative error
    ``|pred - target| / (|target| + 1e-6)``.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat_1d : array-like
        Latitude coordinates. Only the minimum and maximum are used: rows of
        the map are spread linearly from the maximum latitude (row 0) down to
        the minimum latitude (last row).
    lon_1d : array-like
        Longitude coordinates. Only the minimum and maximum are used: columns
        are spread linearly from the minimum to the maximum longitude.
    timestamp : datetime.datetime, optional
        Accepted for signature compatibility; not used by this function.
    variable_names : list of str, optional
        Variable names or identifiers. Defaults to ``VAR_0``, ``VAR_1``, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot. Defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Accepted for signature compatibility; not used. Each panel is
        4.5 x 3.0 inches.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # Coordinates are documented as array-like; plain lists have no .min()/.max().
    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, n_vars, h, w = predictions.shape
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = PlotConfig.get_colormap("error")

    absolute_error_vars = ("var_tp", "precip", "precipitation")
    eps = 1e-6  # keeps the relative error finite where the target is zero

    # Compute each time-averaged error map once; it is used both to derive the
    # colour limits and to draw the panel.
    err_maps, labels, vmin_list, vmax_list = [], [], [], []
    for i in range(n_vars):
        var = variable_names[i]
        pred_i = PlotConfig.convert_units(var, predictions[:, i])
        tgt_i = PlotConfig.convert_units(var, targets[:, i])
        abs_err = np.abs(pred_i - tgt_i)
        if var.lower() in absolute_error_vars:
            err_map = np.mean(abs_err, axis=0)
            label = "Absolute Error (mm/h)"
        else:
            err_map = np.mean(abs_err / (np.abs(tgt_i) + eps), axis=0)
            label = "Relative Error"

        fixed_range = PlotConfig.get_fixed_diff_range_errors(var)
        if fixed_range is not None:
            vmin, vmax = fixed_range
        else:
            # Ignore NaN/inf cells so they cannot poison the colour limits.
            finite_err = err_map[np.isfinite(err_map)]
            if finite_err.size > 0:
                vmin, vmax = 0, 1.1 * np.max(finite_err)
                if vmax <= vmin:
                    # All errors are exactly zero: avoid a zero-width range.
                    vmax = 1
            else:
                vmin, vmax = 0, 1

        err_maps.append(err_map)
        labels.append(label)
        vmin_list.append(vmin)
        vmax_list.append(vmax)

    base_w, base_h = 4.5, 3.0
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(base_w * n_vars, base_h),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    if n_vars == 1:
        axes = [axes]

    for i in range(n_vars):
        ax = axes[i]
        im = ax.pcolormesh(
            lon,
            lat,
            err_maps[i],
            vmin=vmin_list[i],
            vmax=vmax_list[i],
            cmap=cmaps,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(plot_variable_names[i])
        # Horizontal colourbar placed just below the panel, in axes coordinates.
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=labels[i],
        )

    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def spread_skill_ratio(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    variable_names,
    pixel_wise=False,
):
    """
    Compute the spread skill ratio (SSR) of an ensemble with respect to targets.

    Follows "Why Should Ensemble Spread Match the RMSE of the Ensemble Mean?",
    Fortin et al.: for each variable

        SSR = sqrt((E + 1) / E) * sqrt(mean(s^2)) / RMSE(ensemble mean)

    where ``s^2`` is the unbiased (``ddof=1``) variance across the ``E``
    ensemble members and the means are taken over time (``pixel_wise=True``)
    or over time and space (``pixel_wise=False``).

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
        It is very important not to switch dimensions order.
        ensemble_size must be greater than or equal to 2.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str or None
        Variable names or identifiers passed to ``PlotConfig.convert_units``.
        If None, ``VAR_0``, ``VAR_1``, ... are used.
    pixel_wise : bool
        If True, computes and returns the SSR for each pixel independently.
        If False, computes and returns the SSR averaged over all pixels and
        all timesteps. Defaults to False.

    Returns
    -------
    np.array of shape [num_variables, h, w] if pixel_wise == True
    or of shape [num_variables,] if pixel_wise == False (default).
    Entries are inf or NaN where the RMSE is zero.

    Raises
    ------
    ValueError
        If ``predictions`` is not 5-dimensional, has fewer than 2 members, or
        ``targets`` does not have the same number of variables.
    """
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] < 2:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    n_members, _, n_vars, _, _ = predictions.shape
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]

    # Finite-ensemble correction: the squared error of the ensemble mean is
    # expected to equal (E + 1) / E times the unbiased ensemble variance.
    correction = np.sqrt((n_members + 1) / n_members)
    # Averaging over axis 0 (time) keeps the map; axis=None collapses to a scalar.
    mean_axis = 0 if pixel_wise else None

    ssr_list = []
    for i in range(n_vars):
        pred_i = PlotConfig.convert_units(
            variable_names[i], predictions[:, :, i]
        )  # members first, then time and space
        tgt_i = PlotConfig.convert_units(variable_names[i], targets[:, i])
        mean_pred_i = np.mean(pred_i, axis=0)  # ensemble mean
        # ddof=1: the (E + 1) / E factor is only valid for the unbiased variance.
        ens_var_i = np.var(pred_i, axis=0, ddof=1)
        rmse_i = np.sqrt(np.mean((mean_pred_i - tgt_i) ** 2, axis=mean_axis))
        spread_i = np.sqrt(np.mean(ens_var_i, axis=mean_axis)) * correction
        ssr_list.append(np.divide(spread_i, rmse_i))
    return np.array(ssr_list)
[docs]
def plot_spread_skill_ratio_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_spread_skill_ratio_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot per-pixel spread skill ratio (SSR) maps, one panel per variable.

    The per-pixel maps and the global value shown in each panel title both come
    from ``spread_skill_ratio`` (equation (15) in "Why Should Ensemble Spread
    Match the RMSE of the Ensemble Mean?", Fortin et al.). The colour scale is
    always centred on SSR = 1.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
        It is very important not to switch dimensions order.
        ensemble_size must be greater than or equal to 2.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat_1d : array-like
        Latitude coordinates. Only the minimum and maximum are used: rows are
        spread linearly from the maximum latitude (row 0) to the minimum.
    lon_1d : array-like
        Longitude coordinates. Only the minimum and maximum are used: columns
        are spread linearly from the minimum to the maximum longitude.
    timestamp : datetime.datetime, optional
        Accepted for signature compatibility; not used by this function.
    variable_names : list of str, optional
        Variable names or identifiers. Defaults to ``VAR_0``, ``VAR_1``, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot. Defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Accepted for signature compatibility; not used. Each panel is
        4.5 x 3.0 inches.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If ``predictions`` is not 5-dimensional, has a single member, or
        ``targets`` does not have the same number of variables.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # Coordinates are documented as array-like; plain lists have no .min()/.max().
    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()

    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    _, _, n_vars, h, w = predictions.shape
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = PlotConfig.get_colormap("SSR")

    ssr_list = spread_skill_ratio(predictions, targets, variable_names, pixel_wise=True)
    ssr_mean_list = spread_skill_ratio(
        predictions, targets, variable_names, pixel_wise=False
    )

    # Colour limits per variable: fixed range if configured, otherwise the
    # 2nd-98th percentile of the finite per-pixel SSR values.
    vmin_list, vmax_list = [], []
    for i in range(n_vars):
        fixed_range = PlotConfig.get_fixed_ssr_range(variable_names[i])
        if fixed_range is not None:
            vmin, vmax = fixed_range
        else:
            # Pixels with zero RMSE give inf/NaN SSR; np.quantile would then
            # return NaN limits, so only finite values are considered.
            ssr_finite = ssr_list[i][np.isfinite(ssr_list[i])]
            if ssr_finite.size > 0:
                q_low, q_high = np.quantile(ssr_finite, [0.02, 0.98])
                vmin, vmax = float(q_low), float(q_high)
                if vmin >= vmax:
                    vmin, vmax = float(ssr_finite.min()), float(ssr_finite.max())
            else:
                vmin, vmax = 0.0, 2.0
        vmin_list.append(vmin)
        vmax_list.append(vmax)

    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(base_width_per_panel * n_vars, base_height_per_panel),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    if n_vars == 1:
        axes = [axes]

    for col_idx in range(n_vars):
        ax = axes[col_idx]
        vmin = vmin_list[col_idx]
        vmax = vmax_list[col_idx]
        # TwoSlopeNorm requires strictly vmin < vcenter < vmax, so the
        # comparisons must be strict: a limit equal to 1 would raise ValueError.
        if vmin < 1 < vmax:
            norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=1, vmax=vmax)
        elif vmax > 1:
            # Lower limit at or above 1: extend the scale down to 0.
            norm = mcolors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=vmax)
        else:
            # Upper limit at or below 1: use the default 0..2 scale.
            norm = mcolors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=2)
        im = ax.pcolormesh(
            lon,
            lat,
            ssr_list[col_idx],
            cmap=cmaps,
            norm=norm,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(
            f"{plot_variable_names[col_idx]} (SSR={ssr_mean_list[col_idx]:.2f})",
            pad=10,
        )
        # Horizontal colourbar placed just below the panel, in axes coordinates.
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=f"SSR {plot_variable_names[col_idx]}",
        )

    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_spread_skill_ratio_hexbin(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    variable_names=None,
    filename="validation_spread_skill_ratio_hexbin.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spread-versus-error hexbin density plots, one panel per variable.

    Every sample is one (time step, pixel) pair:

    - x: absolute error of the ensemble mean, ``|mean(pred) - target|``
      (the axis is labelled "RMSE");
    - y: standard deviation across ensemble members (``np.std``).

    Each panel also shows the identity line and the global spread skill ratio
    returned by ``spread_skill_ratio(..., pixel_wise=False)``. Samples where
    either quantity is NaN or infinite are left out of the density plot.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
        It is very important not to switch dimensions order.
        ensemble_size must be greater than or equal to 2.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Variable names or identifiers. Defaults to ``VAR_0``, ``VAR_1``, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot. Defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Side length of each (square) panel in inches. Defaults to
        ``PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER``.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If ``predictions`` is not 5-dimensional, has a single member, or
        ``targets`` does not have the same number of variables.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    n_vars = predictions.shape[2]
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]

    mean_ssr_list = spread_skill_ratio(
        predictions, targets, variable_names, pixel_wise=False
    )

    # Single row with one square panel per variable.
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(n_vars * figsize_multiplier, figsize_multiplier),
        squeeze=False,
    )
    axes = axes.ravel()
    for ax in axes:
        ax.set_box_aspect(1)

    last_hb = None
    for i, ax in enumerate(axes):
        pred_i = PlotConfig.convert_units(variable_names[i], predictions[:, :, i])
        tgt_i = PlotConfig.convert_units(variable_names[i], targets[:, i])
        abs_err = np.abs(np.mean(pred_i, axis=0) - tgt_i).flatten()
        spread = np.std(pred_i, axis=0).flatten()

        # NaN/inf samples would turn the axis limits into NaN, which
        # set_xlim/set_ylim reject; keep only pairs that are fully finite.
        finite = np.isfinite(abs_err) & np.isfinite(spread)
        abs_err = abs_err[finite]
        spread = spread[finite]

        if abs_err.size > 0:
            last_hb = ax.hexbin(
                abs_err,
                spread,
                gridsize=100,
                cmap="jet",
                bins="log",
                mincnt=1,
            )
            var_min = min(abs_err.min(), spread.min())
            var_max = max(abs_err.max(), spread.max())
        else:
            # Nothing to draw for this variable: keep an empty unit panel.
            var_min, var_max = 0.0, 1.0

        ax.text(
            0.05,
            0.95,
            f"SSR: {mean_ssr_list[i]:.3f}",
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
        )

        # Same limits on both axes so the identity line is the diagonal.
        margin = 0.05 * (var_max - var_min)
        plot_min = var_min - margin
        plot_max = var_max + margin
        ax.plot([plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7)
        ax.set_xlim(plot_min, plot_max)
        ax.set_ylim(plot_min, plot_max)
        ax.set_title(plot_variable_names[i])
        ax.xaxis.set_major_locator(ticker.MaxNLocator(5))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(5))
        # Only the leftmost panel carries the y label.
        ax.set_ylabel("Ensemble std" if i == 0 else "")
        ax.set_xlabel("RMSE")

    # One colourbar, attached to the right of the last panel and built from the
    # most recently drawn hexbin.
    if last_hb is not None:
        cax = axes[-1].inset_axes([1.05, 0.0, 0.04, 1.0])
        cbar = fig.colorbar(last_hb, cax=cax)
        cbar.set_label(r"$\log_{10}[\mathrm{Count}]$")

    fig.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_validation_pdfs(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_pdfs.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
    save_npz=False,
):
    """
    Create PDF (Probability Density Function) plots comparing distributions of
    model predictions vs ground truth for all variables.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables; used for unit conversion, axis labels and
        the keys of the optional .npz file
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots
    save_npz : bool, optional
        If True, saves the PDF diagnostics to a compressed .npz file next to
        the figure (same name, ``.npz`` extension).

    Returns
    -------
    str
        Path of the saved figure.

    Notes
    -----
    - Creates horizontal subplots (one per variable) showing log10 PDFs
    - Each subplot shows up to 3 lines: Predictions, Ground Truth, and Coarse Inputs
    - Histograms use 100 bins spanning the 0.25th to 99.5th percentile of the
      pooled data, widened by 5 % of that range on each side
    - Mean, standard deviation, KL divergence (truth vs prediction histograms)
      and correlation are printed to stdout, not drawn on the figure
    - Handles both PyTorch tensors and numpy arrays

    Examples
    --------
    >>> predictions = np.random.randn(10, 3, 64, 64)  # 10 samples, 3 variables
    >>> targets = np.random.randn(10, 3, 64, 64)
    >>> plot_validation_pdfs(predictions, targets, variable_names=['Temp', 'Pres', 'Humid'])
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    num_vars = predictions.shape[1]
    has_coarse = coarse_inputs is not None

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"Variable {i + 1}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]

    # Single row, one square panel per variable. squeeze=False always yields
    # an array, so no special case is needed for a single variable.
    fig, axes = plt.subplots(
        1,
        num_vars,
        figsize=(num_vars * figsize_multiplier, figsize_multiplier),
        squeeze=False,
    )
    axes = axes.ravel()
    for ax in axes:
        ax.set_box_aspect(1)
    fig.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    pdf_npz_data = {}
    n_bins = 100
    epsilon = 1e-12  # small offset to avoid log(0) on empty bins

    # Plot PDF for each variable
    for i, (var_name, ax) in enumerate(zip(variable_names, axes)):
        # Fresh generator per panel so line styles repeat across panels.
        linestyles = mpltex.linestyle_generator(markers=[])
        plot_name = plot_variable_names[i]

        # Flatten the batch and spatial dimensions
        pred_flat = PlotConfig.convert_units(var_name, predictions[:, i]).reshape(-1)
        target_flat = PlotConfig.convert_units(var_name, targets[:, i]).reshape(-1)
        all_data = [pred_flat, target_flat]
        if has_coarse:
            coarse_flat = PlotConfig.convert_units(
                var_name, coarse_inputs[:, i]
            ).reshape(-1)
            all_data.append(coarse_flat)

        # Common x-range for all curves of this variable
        all_values = np.concatenate(all_data)
        # NOTE(review): the lower cut is the 0.25th percentile while the upper
        # cut is the 99.5th; confirm whether the asymmetry is intended.
        data_min = np.percentile(all_values, 0.25)  # 0.25th percentile
        data_max = np.percentile(all_values, 99.5)  # 99.5th percentile
        data_range = data_max - data_min
        if data_range <= 0:
            # Percentiles collapsed onto one value (e.g. a mostly-zero field).
            # Zero-width bins would give NaN densities, so fall back to the
            # full data extent and, failing that, to a unit-wide range.
            data_min, data_max = all_values.min(), all_values.max()
            data_range = data_max - data_min
            if data_range <= 0:
                data_range = 1.0
        # Extend range slightly for better visualization
        x_min = data_min - 0.05 * data_range
        x_max = data_max + 0.05 * data_range
        bins = np.linspace(x_min, x_max, n_bins + 1)

        # Predictions
        hist_pred, bin_edges = np.histogram(pred_flat, bins=bins, density=True)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        log_hist_pred = np.log10(hist_pred + epsilon)
        ax.plot(bin_centers, log_hist_pred, label="Pred", **next(linestyles))
        # Ground truth
        hist_target, _ = np.histogram(target_flat, bins=bins, density=True)
        log_hist_target = np.log10(hist_target + epsilon)
        ax.plot(bin_centers, log_hist_target, label="Truth", **next(linestyles))
        # Coarse inputs if available
        if has_coarse:
            hist_coarse, _ = np.histogram(coarse_flat, bins=bins, density=True)
            log_hist_coarse = np.log10(hist_coarse + epsilon)
            ax.plot(bin_centers, log_hist_coarse, label="Coarse", **next(linestyles))

        # Summary statistics
        pred_mean = np.mean(pred_flat)
        pred_std = np.std(pred_flat)
        target_mean = np.mean(target_flat)
        target_std = np.std(target_flat)
        if has_coarse:
            coarse_mean = np.mean(coarse_flat)
            coarse_std = np.std(coarse_flat)

        # KL divergence of the truth histogram relative to the prediction
        # histogram, after renormalizing both to sum to one.
        hist_pred_safe = hist_pred + epsilon
        hist_target_safe = hist_target + epsilon
        hist_pred_safe = hist_pred_safe / np.sum(hist_pred_safe)
        hist_target_safe = hist_target_safe / np.sum(hist_target_safe)
        kl_divergence = np.sum(
            hist_target_safe * np.log(hist_target_safe / hist_pred_safe)
        )
        # Point-wise correlation between predictions and ground truth
        correlation = np.corrcoef(pred_flat, target_flat)[0, 1]

        # Log statistics instead of plotting them
        print(f"[PDF stats] {plot_name}")
        print(f"  Predictions: μ={pred_mean:.3f}, σ={pred_std:.3f}")
        print(f"  Ground Truth: μ={target_mean:.3f}, σ={target_std:.3f}")
        if has_coarse:
            print(f"  Coarse: μ={coarse_mean:.3f}, σ={coarse_std:.3f}")
        print(f"  KL Divergence: {kl_divergence:.4f}")
        print(f"  Correlation: {correlation:.4f}")

        ax.set_xlabel(plot_name)
        # Only show y-label for leftmost subplot
        if i == 0:
            ax.set_ylabel(r"$\log_{10}(\mathrm{PDF})$")
        ax.grid(True, alpha=0.3, linestyle="--")
        ax.legend()

        # y-limits span all plotted curves, with a 10 % margin
        y_min = min(log_hist_pred.min(), log_hist_target.min())
        y_max = max(log_hist_pred.max(), log_hist_target.max())
        if has_coarse:
            y_min = min(y_min, log_hist_coarse.min())
            y_max = max(y_max, log_hist_coarse.max())
        y_margin = 0.1 * (y_max - y_min) if y_max > y_min else 0.1
        ax.set_ylim(y_min - y_margin, y_max + y_margin)

        # Use scientific notation for large ranges
        if data_range > 1000:
            ax.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))

        if save_npz:
            key = f"{var_name}__pdf__"
            pdf_npz_data[key + "bin_centers"] = bin_centers
            pdf_npz_data[key + "log_pred"] = log_hist_pred
            pdf_npz_data[key + "log_truth"] = log_hist_target
            pdf_npz_data[key + "mean_pred"] = pred_mean
            pdf_npz_data[key + "std_pred"] = pred_std
            pdf_npz_data[key + "mean_truth"] = target_mean
            pdf_npz_data[key + "std_truth"] = target_std
            pdf_npz_data[key + "kl"] = kl_divergence
            pdf_npz_data[key + "corr"] = correlation
            if has_coarse:
                pdf_npz_data[key + "log_coarse"] = log_hist_coarse
                pdf_npz_data[key + "mean_coarse"] = coarse_mean
                pdf_npz_data[key + "std_coarse"] = coarse_std

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    if save_npz:
        npz_path = os.path.splitext(save_path)[0] + ".npz"
        np.savez_compressed(npz_path, **pdf_npz_data)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_power_spectra(
predictions, # Model predictions
targets, # Ground truth
dlat, # Grid spacing in latitude (degrees)
dlon, # Grid spacing in longitude (degrees)
coarse_inputs=None, # Coarse inputs for comparison (optional)
variable_names=None, # List of variable names
filename="power_spectra_physical.png",
save_dir="./results",
figsize_multiplier=4,
save_npz=False,
):
"""
Calculate and plot power spectra with proper physical wavenumbers.
Parameters
----------
predictions : torch.Tensor or np.array
Model predictions of shape [batch_size, num_variables, nh, nw]
targets : torch.Tensor or np.array
Ground truth of shape [batch_size, num_variables, nh, nw]
dlat : float
Grid spacing in latitude (degrees)
dlon : float
Grid spacing in longitude (degrees)
coarse_inputs : torch.Tensor or np.array, optional
Coarse inputs of shape [batch_size, num_variables, nh, nw]
variable_names : list of str, optional
Names of the variable names for subplot titles
filename : str, optional
Output filename
save_dir : str, optional
Directory to save the plot
figsize_multiplier : int, optional
Base size multiplier for subplots
save_npz : bool, optional
If True, saves the PDF diagnostics to a compressed .npz file.
Returns
-------
None
"""
# Convert to numpy if they're tensors
if hasattr(predictions, "detach"):
predictions = predictions.detach().cpu().numpy()
if hasattr(targets, "detach"):
targets = targets.detach().cpu().numpy()
if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
coarse_inputs = coarse_inputs.detach().cpu().numpy()
batch_size, num_vars, nh, nw = predictions.shape
# Default variable names if not provided
if variable_names is None:
variable_names = [f"Variable {i + 1}" for i in range(num_vars)]
# Calculate wavenumbers
# FFT frequencies are in cycles per grid spacing
fft_freq_lat = np.fft.fftfreq(nh, d=dlat) # cycles per degree in lat direction
fft_freq_lon = np.fft.fftfreq(nw, d=dlon) # cycles per degree in lon direction
# Shift frequencies so zero is at center
fft_freq_lat_shifted = np.fft.fftshift(fft_freq_lat)
fft_freq_lon_shifted = np.fft.fftshift(fft_freq_lon)
# Create 2D wavenumber grid
k_lat, k_lon = np.meshgrid(fft_freq_lon_shifted, fft_freq_lat_shifted)
# Calculate magnitude of wavenumber vector (in cycles/degree)
k_mag = np.sqrt(k_lat**2 + k_lon**2)
# Create bins for radial averaging
max_k = np.min([np.max(np.abs(fft_freq_lat)), np.max(np.abs(fft_freq_lon))])
k_bins = np.linspace(0, max_k, min(nh, nw) // 2)
k_centers = 0.5 * (k_bins[1:] + k_bins[:-1])
# Create figure
ncols = num_vars
nrows = 1 # 2 Two rows: one for 2D spectrum, one for 1D spectrum
fig, axes = plt.subplots(
nrows,
ncols,
figsize=(ncols * figsize_multiplier, nrows * figsize_multiplier),
squeeze=False,
) # nrows * figsize_multiplier
plt.subplots_adjust(
hspace=0.2, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
)
axes = axes.ravel()
for ax in axes:
ax.set_box_aspect(1)
if save_npz:
spectra_npz_data = {}
spectra_npz_data["__meta__dlat"] = dlat
spectra_npz_data["__meta__dlon"] = dlon
spectra_npz_data["__meta__variables"] = np.array(variable_names)
# Process each variable
for i, var_name in enumerate(variable_names):
if i >= num_vars:
continue
linestyles = mpltex.linestyle_generator(markers=[])
# plot_name = plot_variable_names[i]
# Initialize arrays for averaged PSDs
psd2d_pred_sum = np.zeros((nh, nw))
psd2d_target_sum = np.zeros((nh, nw))
if coarse_inputs is not None:
psd2d_coarse_sum = np.zeros((nh, nw))
# Process each sample in the batch
for b in range(batch_size):
# Predictions
# field_pred = predictions[b, i]
field_pred = PlotConfig.convert_units(var_name, predictions[b, i])
psd2d_pred = calculate_psd2d_simple(field_pred)
psd2d_pred_sum += psd2d_pred
# Targets
# field_target = targets[b, i]
field_target = PlotConfig.convert_units(var_name, targets[b, i])
psd2d_target = calculate_psd2d_simple(field_target)
psd2d_target_sum += psd2d_target
# Coarse inputs
if coarse_inputs is not None:
# field_coarse = coarse_inputs[b, i]
field_coarse = PlotConfig.convert_units(var_name, coarse_inputs[b, i])
psd2d_coarse = calculate_psd2d_simple(field_coarse)
psd2d_coarse_sum += psd2d_coarse
# Average over batch
psd2d_pred_avg = psd2d_pred_sum / batch_size
psd2d_target_avg = psd2d_target_sum / batch_size
if coarse_inputs is not None:
psd2d_coarse_avg = psd2d_coarse_sum / batch_size
# Calculate 1D radial spectra
psd1d_pred = radial_average_psd(psd2d_pred_avg, k_mag, k_bins)
psd1d_target = radial_average_psd(psd2d_target_avg, k_mag, k_bins)
if coarse_inputs is not None:
psd1d_coarse = radial_average_psd(psd2d_coarse_avg, k_mag, k_bins)
if save_npz:
key = f"{var_name}__spectra__"
spectra_npz_data[key + "k"] = k_centers
spectra_npz_data[key + "psd_pred"] = psd1d_pred
spectra_npz_data[key + "psd_truth"] = psd1d_target
if coarse_inputs is not None:
spectra_npz_data[key + "psd_coarse"] = psd1d_coarse
"""
# --- Plot 2D PSD (top row) ---
ax_top = axes[0, i] if num_vars > 1 else axes[0]
# Use k_lon and k_lat for the axes instead of lat/lon
k_lon_min, k_lon_max = fft_freq_lon_shifted[0], fft_freq_lon_shifted[-1]
k_lat_min, k_lat_max = fft_freq_lat_shifted[0], fft_freq_lat_shifted[-1]
im = ax_top.imshow(np.log10(psd2d_pred_avg + 1e-12),
cmap=cmap_white_jet,
aspect='auto',
origin='lower',
extent=[k_lon_min, k_lon_max, k_lat_min, k_lat_max])
#ax_top.set_title(f'{var_name}')
ax_top.set_title(plot_name)
# Only add y-axis label for leftmost column
if i == 0:
ax_top.set_ylabel(r'$\mathrm{k_{lat}}$ (cycles/°)')
else:
ax_top.set_ylabel('')
# Remove y-axis tick labels for non-leftmost columns
ax_top.tick_params(axis='y', labelleft=False)
# Always show x-axis label
ax_top.set_xlabel(r'$\mathrm{k_{lon}}$ (cycles/°)')
# Add grid for better readability
ax_top.grid(True, alpha=0.3, linestyle='--')
# Add colorbar for the last column only
if i == num_vars - 1:
cax = ax_top.inset_axes([1.05, 0, 0.05, 1]) # [x, y, w, h] relative to axes
cbar = plt.colorbar(im, cax=cax, orientation='vertical')
cbar.set_label('log₁₀(PSD)')
"""
# --- Plot 1D Radial Spectrum (bottom row) ---
# ax_bottom = axes[1, i] if num_vars > 1 else axes[1]
ax_bottom = axes[i]
# Plot all spectra
ax_bottom.loglog(k_centers, psd1d_pred, label="Pred", **next(linestyles))
ax_bottom.loglog(k_centers, psd1d_target, label="Truth", **next(linestyles))
if coarse_inputs is not None:
ax_bottom.loglog(
k_centers, psd1d_coarse, label="Coarse", **next(linestyles)
)
# Only add y-axis label for leftmost column
if i == 0:
ax_bottom.set_ylabel("PSD(k)")
else:
ax_bottom.set_ylabel("")
# Always show x-axis label
ax_bottom.set_xlabel("Wavenumber k [cycles/°]")
ax_bottom.legend()
ax_bottom.grid(True, alpha=0.3, which="both")
# Set reasonable axis limits
valid = (k_centers > 0) & (psd1d_target > 0)
if np.any(valid):
ax_bottom.set_xlim(k_centers[valid][0] * 0.8, k_centers[valid][-1] * 1.2)
# Find y-range
y_min = min(psd1d_pred[valid].min(), psd1d_target[valid].min())
y_max = max(psd1d_pred[valid].max(), psd1d_target[valid].max())
if coarse_inputs is not None:
y_min = min(y_min, psd1d_coarse[valid].min())
y_max = max(y_max, psd1d_coarse[valid].max())
ax_bottom.set_ylim(y_min * 0.5, y_max * 2.0)
# Save figure
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, filename)
if save_npz:
npz_path = os.path.splitext(save_path)[0] + ".npz"
np.savez_compressed(npz_path, **spectra_npz_data)
plt.savefig(save_path, bbox_inches="tight")
plt.close(fig)
return save_path
[docs]
def calculate_psd2d_simple(field):
    """
    Return the raw 2D power spectrum of ``field``.

    The field is transformed with a 2D FFT, the zero-frequency component is
    moved to the centre of the array, and the squared magnitude is taken.
    No detrending, windowing or normalisation is applied.

    Parameters
    ----------
    field : array-like
        2D field to transform.

    Returns
    -------
    np.ndarray
        Squared magnitude of the centred spectrum, same shape as ``field``.
    """
    centred_spectrum = np.fft.fftshift(np.fft.fft2(field))
    return np.abs(centred_spectrum) ** 2
[docs]
def radial_average_psd(psd2d, k_mag, k_bins):
    """
    Collapse a 2D PSD into a 1D spectrum over wavenumber-magnitude bins.

    Parameters
    ----------
    psd2d : np.ndarray
        2D power spectrum.
    k_mag : np.ndarray
        Wavenumber magnitude of every spectral sample (same number of
        elements as ``psd2d``).
    k_bins : np.ndarray
        Monotonic bin edges in wavenumber magnitude.

    Returns
    -------
    np.ndarray
        One value per bin: the mean PSD in the bin multiplied by the annulus
        area ``2*pi*k_center*delta_k``. Bins that receive no sample are NaN
        (mean of an empty bin in ``scipy.stats.binned_statistic``).
    """
    # Mean PSD of all samples whose |k| falls into each bin
    binned = stats.binned_statistic(
        k_mag.flatten(), psd2d.flatten(), statistic="mean", bins=k_bins
    )
    spectrum = binned.statistic
    # Annulus area of each bin from its lower/upper edges
    lower_edges = k_bins[:-1]
    upper_edges = k_bins[1:]
    annulus_area = 2 * np.pi * (0.5 * (upper_edges + lower_edges)) * (
        upper_edges - lower_edges
    )
    # Scale only bins with a strictly positive area; others keep the bin mean
    positive_area = annulus_area > 0
    spectrum[positive_area] = spectrum[positive_area] * annulus_area[positive_area]
    return spectrum
[docs]
def plot_qq_quantiles(
    predictions,  # Model predictions
    targets,  # Ground truth
    coarse_inputs,  # Coarse inputs (None to skip the coarse comparison)
    variable_names=None,  # List of variable names
    units=None,  # List of units for each variable
    quantiles=(0.90, 0.95, 0.975, 0.99, 0.995),
    filename="qq_quantiles.png",
    save_dir="./results",
    figsize_multiplier=4,
    save_npz=False,
):
    """
    Create QQ-plots at selected quantiles comparing model predictions and
    coarse inputs against ground truth.

    For each variable, the quantiles of the predictions (one marker per
    quantile) and of the coarse inputs (black squares) are plotted against
    the quantiles of the ground truth, together with a 1:1 reference line.
    The quantile values are also printed to stdout.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array or None
        Coarse inputs of shape [batch_size, num_variables, h, w].
        If None, the coarse series is omitted from the plot, the printout
        and the .npz file.
    variable_names : list of str, optional
        Names of the variables for subplot titles.
        If None, uses ["VAR_0", "VAR_1", ...]
    units : list of str, optional
        Units for each variable for axis labels.
        If None, uses empty strings.
    quantiles : sequence of float, optional
        Quantile values to plot (e.g., [0.90, 0.95, 0.975, 0.99, 0.995])
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots
    save_npz : bool, optional
        If True, saves the QQ diagnostics to a compressed .npz file next to
        the figure (same base name).

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Convert tensors -> numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    has_coarse = coarse_inputs is not None
    if has_coarse and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    _, num_vars, _, _ = predictions.shape

    # Default variable names / units if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    if units is None:
        units = [""] * num_vars

    # Figure setup: squeeze=False gives a 2D axes array even for one variable
    fig, axes = plt.subplots(
        1,
        num_vars,
        figsize=(num_vars * figsize_multiplier, figsize_multiplier),
        constrained_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()
    for ax in axes:
        ax.set_box_aspect(1)

    qq_npz_data = {}
    if save_npz:
        qq_npz_data["__meta__variables"] = np.array(variable_names)
        qq_npz_data["__meta__quantiles"] = np.array(quantiles)

    for i, var_name in enumerate(variable_names):
        # Fresh marker cycle per variable so quantile markers match across panels
        linestyles = mpltex.linestyle_generator(lines=[])
        ax = axes[i]
        plot_name = plot_variable_names[i]

        pred_vals = PlotConfig.convert_units(var_name, predictions[:, i])
        target_vals = PlotConfig.convert_units(var_name, targets[:, i])

        # Compute quantiles
        qs_target = np.quantile(target_vals, quantiles)
        qs_pred = np.quantile(pred_vals, quantiles)
        qs_coarse = None
        if has_coarse:
            coarse_vals = PlotConfig.convert_units(var_name, coarse_inputs[:, i])
            qs_coarse = np.quantile(coarse_vals, quantiles)

        if save_npz:
            key = f"{var_name}__qq__"
            qq_npz_data[key + "quantiles"] = np.array(quantiles)
            qq_npz_data[key + "truth"] = qs_target
            qq_npz_data[key + "pred"] = qs_pred
            if has_coarse:
                qq_npz_data[key + "coarse"] = qs_coarse

        print(f"[QQ Quantiles] {plot_name}")
        for q_idx, q in enumerate(quantiles):
            line = (
                f" q={q:.3f} | Truth={qs_target[q_idx]:.4f}"
                f" | Pred={qs_pred[q_idx]:.4f}"
            )
            if has_coarse:
                line += f" | Coarse={qs_coarse[q_idx]:.4f}"
            print(line + " ")

        # ---- Plot predicted quantiles (one styled marker per quantile) ----
        for q_idx, q in enumerate(quantiles):
            ax.plot(
                qs_target[q_idx],
                qs_pred[q_idx],
                label=f"{q * 100:.1f}%",
                **next(linestyles),
            )

        # ---- Plot coarse quantiles ----
        if has_coarse:
            ax.plot(
                qs_target,
                qs_coarse,
                c="black",
                marker="s",
                label="Coarse",
                linestyle="None",
            )

        # ---- 1:1 reference line spanning all plotted quantiles ----
        plotted = [qs_target, qs_pred]
        if has_coarse:
            plotted.append(qs_coarse)
        plot_min = min(values.min() for values in plotted)
        plot_max = max(values.max() for values in plotted)
        ax.plot(
            [plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7, label="1:1"
        )

        ax.xaxis.set_major_locator(ticker.MaxNLocator(4))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(4))

        # Labels and formatting
        ax.set_title(plot_name)
        unit_str = f" ({units[i]})" if units[i] else ""
        # Y-axis label and legend only on the leftmost plot
        if i == 0:
            y_label = "Predicted/Coarse quantiles" if has_coarse else "Predicted quantiles"
            ax.set_ylabel(f"{y_label}{unit_str}")
        ax.set_xlabel(f"True quantiles{unit_str}")
        ax.grid(True, linestyle="--", alpha=0.3)
        if i == 0:
            ax.legend()

    # Save figure (and optional diagnostics)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    if save_npz:
        npz_path = os.path.splitext(save_path)[0] + ".npz"
        np.savez_compressed(npz_path, **qq_npz_data)
    # Save through the figure handle so the right figure is written even if
    # another figure has become pyplot's "current" one.
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def dry_frequency_map(array, threshold):
    """
    Compute the per-pixel frequency of dry samples.

    A sample is counted as dry when its value is strictly below
    ``threshold``; the result is the fraction of dry samples along the
    first axis.

    Parameters
    ----------
    array : torch.Tensor or np.array
        Values of shape [batch_size, h, w]
    threshold : float
        Dry/wet threshold, in the same units as ``array``.

    Returns
    -------
    np.ndarray(np.float64) of shape [h,w]
    """
    # Tensors are detached and moved to host memory first
    values = array.detach().cpu().numpy() if hasattr(array, "detach") else array
    is_dry = values < threshold
    return is_dry.astype(np.float64).mean(axis=0)
[docs]
def _draw_dry_frequency_panel(ax, lon, lat, field, extent, title, **mesh_kwargs):
    """Draw one map panel (field, coastlines, borders, lakes) and return its mesh."""
    plate_carree = ccrs.PlateCarree()
    mesh = ax.pcolormesh(
        lon,
        lat,
        field,
        transform=plate_carree,
        shading="auto",
        **mesh_kwargs,
    )
    ax.set_extent(extent, crs=plate_carree)
    ax.coastlines(linewidth=0.6)
    ax.add_feature(
        cfeature.BORDERS.with_scale("50m"),
        linewidth=0.6,
        linestyle="--",
        edgecolor="black",
        zorder=11,
    )
    ax.add_feature(
        cfeature.LAKES.with_scale("50m"),
        edgecolor="black",
        facecolor="none",
        linewidth=0.6,
        zorder=9,
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    return mesh


def plot_dry_frequency_map(
    predictions,  # Model predictions precipitation (fine predicted)
    targets,  # Ground truth precipitation (fine true)
    threshold,  # threshold to define dry and wet (in mm)
    lat_1d,
    lon_1d,
    filename="validation_dry_frequency_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Accepted for interface compatibility; unused
):
    """
    Plot spatial dry-frequency maps for predictions, targets and their difference.

    Three stacked panels are drawn: predicted dry frequency, target dry
    frequency (both on a shared 0-1 colour scale) and predicted minus target
    (diverging scale from -1 to 1). The value of each pixel is the fraction
    of samples in which that pixel is below ``threshold``.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w]
    threshold : float
        Threshold for precipitation (expressed in mm): under it, a pixel is
        considered dry. It is applied after ``PlotConfig.convert_units``
        has been called with the "precipitation" variable.
    lat_1d : array-like
        Latitude coordinates. Only the minimum and maximum are used; the
        plotting grid is rebuilt as ``h`` evenly spaced values running from
        the maximum down to the minimum.
    lon_1d : array-like
        Longitude coordinates. Only the minimum and maximum are used; the
        plotting grid is rebuilt as ``w`` evenly spaced values running from
        the minimum up to the maximum.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot. Defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Currently not used: the figure size is fixed per panel.

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # Accept plain sequences as coordinates as well
    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = targets.shape
    # Regular plotting grid; latitude decreases with the row index
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    extent = [lon_min, lon_max, lat_min, lat_max]

    cmap = PlotConfig.get_colormap("dry frequency")

    # Convert units before thresholding
    predictions = PlotConfig.convert_units("precipitation", predictions)
    targets = PlotConfig.convert_units("precipitation", targets)
    dry_freq_pred_map = dry_frequency_map(predictions, threshold)
    dry_freq_targ_map = dry_frequency_map(targets, threshold)

    # Frequencies are fractions, so the colour scale is fixed to [0, 1]
    vmin = 0
    vmax = 1
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        3,
        figsize=(base_width_per_panel, 3 * base_height_per_panel),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )

    _draw_dry_frequency_panel(
        axes[0],
        lon,
        lat,
        dry_freq_pred_map,
        extent,
        "Predicted",
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
    )
    freq_mesh = _draw_dry_frequency_panel(
        axes[1],
        lon,
        lat,
        dry_freq_targ_map,
        extent,
        "Target",
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
    )

    # One colorbar spanning the two frequency panels (they share the scale)
    pos0 = axes[0].get_position()
    pos1 = axes[1].get_position()
    bottom = pos1.y0
    height = pos0.y1 - bottom
    cax1 = fig.add_axes([0.92, bottom, 0.03, height])
    fig.colorbar(freq_mesh, cax=cax1, label="frequency")

    diff_mesh = _draw_dry_frequency_panel(
        axes[2],
        lon,
        lat,
        dry_freq_pred_map - dry_freq_targ_map,
        extent,
        "Predicted frequency - Target frequency",
        norm=mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax),
        cmap="seismic",
    )
    pos2 = axes[2].get_position()
    cax2 = fig.add_axes([0.92, pos2.y0, 0.03, pos2.height])
    fig.colorbar(diff_mesh, cax=cax2, label="frequency")

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    # Save through the figure handle rather than pyplot's current figure
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def calculate_pearsoncorr_nparray(arr1, arr2, axis=0):
    """
    Pearson correlation between two equally shaped N-dimensional arrays.

    Parameters:
    -----------
    arr1 : numpy.ndarray
        First N-dimensional array
    arr2 : numpy.ndarray
        Second N-dimensional array (must have the same shape as arr1)
    axis : int or tuple of int, default=0
        Axis or axes along which the correlation is computed

    Returns:
    --------
    numpy.ndarray
        Correlation coefficients with the reduced axis/axes removed.
        Entries whose denominator is exactly zero (e.g. a constant series)
        are returned as 0.0 rather than inf/nan.

    Raises:
    -------
    ValueError
        If the two arrays do not have the same shape.
    """
    if arr1.shape != arr2.shape:
        raise ValueError(
            f"Arrays must have the same shape. Got {arr1.shape} and {arr2.shape}"
        )
    # Deviations from the mean along the reduced axis/axes
    dev1 = arr1 - arr1.mean(axis=axis, keepdims=True)
    dev2 = arr2 - arr2.mean(axis=axis, keepdims=True)
    # Covariance-like numerator and product of the two spreads
    cross = (dev1 * dev2).sum(axis=axis)
    spread = np.sqrt((dev1**2).sum(axis=axis) * (dev2**2).sum(axis=axis))
    # Divide only where the spread is non-zero; the rest keeps the 0.0 fill
    return np.divide(cross, spread, out=np.zeros_like(cross), where=spread != 0)
[docs]
def plot_validation_mvcorr_space(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_mvcorr_space.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Plot the spatial Pearson correlation between variable pairs as time series.

    For every distinct pair of variables, the correlation between the two
    fields is computed over the two spatial axes for each sample, and drawn
    against the sample index for the truth, the prediction and (optionally)
    the coarse input. One subplot row is produced per pair.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables, used for the y-axis label of each row.
        If None, uses ["VAR_0", "VAR_1", ...]
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Height of each subplot row

    Returns
    -------
    save_path : str
        Path to the saved figure. If fewer than two variables are present,
        an error message is printed and the string "0" is returned instead.
    """

    def _to_numpy(data):
        # Tensors expose ``detach``; anything else is passed through unchanged
        return data.detach().cpu().numpy() if hasattr(data, "detach") else data

    predictions = _to_numpy(predictions)
    targets = _to_numpy(targets)
    with_coarse = coarse_inputs is not None
    if with_coarse:
        coarse_inputs = _to_numpy(coarse_inputs)

    n_steps, num_vars, _, _ = predictions.shape
    if num_vars < 2:
        print("ERROR: need at least 2 variables but num_vars < 2")
        return "0"

    if variable_names is None:
        variable_names = [f"VAR_{idx}" for idx in range(num_vars)]

    # All distinct (first, second) index pairs with first < second
    pairs = [
        (first, second)
        for first in range(num_vars - 1)
        for second in range(first + 1, num_vars)
    ]
    n_rows = len(pairs)

    fig, axes = plt.subplots(
        n_rows, 1, figsize=(6, n_rows * figsize_multiplier), squeeze=False, sharex=True
    )
    axes = axes.flatten()

    # Styles are drawn in the order truth, prediction, coarse
    style_cycle = mpltex.linestyle_generator(markers=[])
    series = [
        ("Truth", targets, next(style_cycle)),
        ("Prediction", predictions, next(style_cycle)),
    ]
    if with_coarse:
        series.append(("Coarse", coarse_inputs, next(style_cycle)))

    time_index = range(n_steps)
    for row, (first, second) in enumerate(pairs):
        ax = axes[row]
        pair_label = variable_names[first] + " & " + variable_names[second]
        for label, data, style in series:
            # Correlation over the spatial axes -> one value per sample
            corr = calculate_pearsoncorr_nparray(
                data[:, first, :, :], data[:, second, :, :], axis=(1, 2)
            )
            ax.plot(time_index, corr, label=label, linewidth=1.0, **style)
        ax.grid(True, alpha=0.3)
        ax.set_ylabel(pair_label)
        ax.set_xlim(0, n_steps - 1)
        if row == 0:
            ax.legend()

    axes[0].set_title("Spatial Pearson Correlation Over Time")
    axes[-1].set_xlabel("Time Step")

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_validation_mvcorr(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat,
    lon,
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_mvcorr_time.png",
    save_dir="./results",
    figsize_multiplier=4,  # Accepted for interface compatibility; unused
):
    """
    Plot maps of the temporal Pearson correlation between variable pairs.

    For every distinct pair of variables, the correlation between the two
    variables is computed along the first (batch/time) axis at each pixel
    and drawn as a map. One row is produced per pair, with columns for the
    truth, the prediction and (optionally) the coarse input, all on a
    shared [-1, 1] colour scale.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat : array-like
        Latitude coordinates. Only the minimum and maximum are used; the
        plotting grid is rebuilt as ``h`` evenly spaced values running from
        the maximum down to the minimum.
    lon : array-like
        Longitude coordinates. Only the minimum and maximum are used; the
        plotting grid is rebuilt as ``w`` evenly spaced values running from
        the minimum up to the maximum.
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables, used for the row labels.
        If None, uses ["VAR_0", "VAR_1", ...]
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot. None selects
        ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Currently not used: the figure size is fixed per panel.

    Returns
    -------
    save_path : str
        Path to the saved figure. If fewer than two variables are present,
        an error message is printed and the string "0" is returned instead.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    has_coarse = coarse_inputs is not None
    if has_coarse and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(lat, "detach"):
        lat = lat.detach().cpu().numpy()
    if hasattr(lon, "detach"):
        lon = lon.detach().cpu().numpy()
    # Accept plain sequences as coordinates as well
    lat = np.asarray(lat)
    lon = np.asarray(lon)

    lat_min, lat_max = lat.min(), lat.max()
    lon_min, lon_max = lon.min(), lon.max()
    _, num_vars, h, w = predictions.shape

    # Regular plotting grid; latitude decreases with the row index
    lat_grid, lon_grid = np.meshgrid(
        np.linspace(lat_max, lat_min, h),
        np.linspace(lon_min, lon_max, w),
        indexing="ij",
    )
    lon_center = float((lon_min + lon_max) / 2)
    extent = [lon_min, lon_max, lat_min, lat_max]

    if num_vars < 2:
        print("ERROR: need at least 2 variables but num_vars < 2")
        return "0"

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]

    # All distinct (first, second) index pairs with first < second
    var_pairs = [
        (first, second)
        for first in range(num_vars - 1)
        for second in range(first + 1, num_vars)
    ]

    # One column per data source, in display order
    panel_sources = [("Truth", targets), ("Prediction", predictions)]
    if has_coarse:
        panel_sources.append(("Coarse", coarse_inputs))

    nrows = len(var_pairs)
    ncols = len(panel_sources)
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0

    # Set up figure
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(base_width_per_panel * ncols, base_height_per_panel * nrows),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        squeeze=False,
        gridspec_kw={"wspace": 0.1},
    )

    # Define geographic features
    coastline = cfeature.COASTLINE.with_scale("50m")
    borders = cfeature.BORDERS.with_scale("50m")
    plate_carree = ccrs.PlateCarree()

    # NOTE: a summary heatmap (spatial correlation / RMSE of each map against
    # the truth map) used to follow the figure but was disabled; the
    # statistics that only fed it are no longer computed here.
    corr_mesh = None
    for row, (first, second) in enumerate(var_pairs):
        for col, (_, data) in enumerate(panel_sources):
            # Correlation along the first axis -> one value per pixel
            corr_map = calculate_pearsoncorr_nparray(
                data[:, first, :, :], data[:, second, :, :], axis=0
            )
            ax = axes[row, col]
            corr_mesh = ax.pcolormesh(
                lon_grid,
                lat_grid,
                corr_map,
                vmin=-1.0,
                vmax=1.0,
                cmap="RdBu",
                transform=plate_carree,
                shading="auto",
            )
            ax.add_feature(coastline, linewidth=PlotConfig.COASTLINE_w)
            ax.add_feature(
                borders,
                linewidth=PlotConfig.BORDER_w,
                edgecolor="black",
                linestyle=PlotConfig.BORDER_STYLE,
            )
            ax.set_extent(extent, crs=plate_carree)

        # Row label: the variable pair, written vertically left of column 0
        axes[row, 0].text(
            -0.1,
            0.5,
            variable_names[first] + " & " + variable_names[second],
            transform=axes[row, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )

    # Column labels on the top row
    for col, (label, _) in enumerate(panel_sources):
        axes[0, col].set_title(label)

    # Shared colorbar spanning all rows (every panel uses the same scale)
    fig.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1)
    pos_top = axes[0, 0].get_position()
    pos_bottom = axes[-1, 0].get_position()
    bottom = pos_bottom.y0
    height = pos_top.y1 - bottom
    cbar_ax = fig.add_axes([0.92, bottom, 0.015, height])
    fig.colorbar(corr_mesh, cax=cbar_ax, label=r"Temporal Pearson Correlation")

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    # Save through the figure handle rather than pyplot's current figure
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_temporal_series_comparison(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_temp_series.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot the spatial mean of each variable against the sample index.

    One subplot row is drawn per variable, showing the truth, the prediction
    and (optionally) the coarse input after ``PlotConfig.convert_units`` has
    been applied.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles.
        If None, uses ["VAR_0", "VAR_1", ...]
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Height of each subplot row

    Returns
    -------
    save_path : str
        Path to the saved figure

    Raises
    ------
    ValueError
        If predictions, targets or coarse inputs differ in shape, or if the
        number of variable names does not match the number of variables.
    """

    def _to_numpy(data):
        # Tensors expose ``detach``; anything else is passed through unchanged
        return data.detach().cpu().numpy() if hasattr(data, "detach") else data

    predictions = _to_numpy(predictions)
    targets = _to_numpy(targets)
    with_coarse = coarse_inputs is not None
    if with_coarse:
        coarse_inputs = _to_numpy(coarse_inputs)

    if predictions.shape != targets.shape:
        raise ValueError(f"Shape mismatch: {predictions.shape} vs {targets.shape}")
    if with_coarse and coarse_inputs.shape != targets.shape:
        raise ValueError(
            f"Coarse shape mismatch: {coarse_inputs.shape} vs {targets.shape}"
        )

    n_steps, num_vars, _, _ = predictions.shape

    if variable_names is None:
        variable_names = [f"VAR_{idx}" for idx in range(num_vars)]
    if len(variable_names) != num_vars:
        raise ValueError(
            f"{len(variable_names)} variable names but num_vars={num_vars}"
        )

    fig = plt.figure(figsize=(6, figsize_multiplier * num_vars))

    # Styles are drawn in the order truth, prediction, coarse
    style_cycle = mpltex.linestyle_generator(markers=[])
    series = [
        ("Truth", targets, next(style_cycle)),
        ("Prediction", predictions, next(style_cycle)),
    ]
    if with_coarse:
        series.append(("Coarse", coarse_inputs, next(style_cycle)))

    time_index = range(n_steps)
    last_row = num_vars - 1
    for row, var in enumerate(variable_names):
        ax = fig.add_subplot(num_vars, 1, row + 1)
        for label, data, style in series:
            converted = PlotConfig.convert_units(var, data[:, row])
            # Average over the two spatial axes -> one value per sample
            spatial_mean = converted.mean(axis=(1, 2))
            ax.plot(time_index, spatial_mean, label=label, linewidth=1.0, **style)
        ax.set_title(var)
        ax.grid(True, alpha=0.3)
        ax.set_ylabel("Spatial mean")
        # Only the bottom panel carries the x label and tick labels
        if row != last_row:
            ax.tick_params(labelbottom=False)
        else:
            ax.set_xlabel("Time index")
        ax.legend()

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def ranks(
    predictions,  # Model predictions precipitation (fine predicted)
    targets,  # Ground truth precipitation (fine true)
):
    """
    Compute the rank of each target value within its prediction ensemble.

    For every flattened (sample, pixel) position, each ensemble member
    strictly below the target counts 1, each member equal to the target
    counts 0.5, and members above count 0; the counts are summed over the
    ensemble.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, h, w]
    targets : torch.Tensor or np.array
        Targets of shape [batch_size, h, w]

    Returns
    -------
    np.ndarray(np.float32) of shape [batch_size*h*w,]
    """
    # Convert to numpy if tensor
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    n_members, n_samples, n_rows, n_cols = predictions.shape
    n_points = n_samples * n_rows * n_cols

    # One row per ensemble member, one column per (sample, pixel) position;
    # the single target row broadcasts against all members.
    member_values = predictions.reshape(n_members, n_points)
    target_row = targets.reshape(1, n_points)
    offset = member_values - target_row

    # below -> (1 + 1) / 2 = 1, tie -> (1 + 0) / 2 = 0.5, above -> 0
    below_or_tied = (offset <= 0).astype(np.float32)
    strictly_below = (offset < 0).astype(np.float32)
    contribution = (below_or_tied + strictly_below) / 2
    return contribution.sum(axis=0)
[docs]
def plot_ranks(
    predictions,  # model predictions
    targets,  # ground truth
    variable_names=None,  # list of variable names
    filename="ranks.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Draw one rank histogram per variable and save the figure.

    Each panel shows the normalised histogram of the ranks returned by
    ``ranks`` for that variable, together with a dashed red line at
    ``1 / (ensemble_size + 1)``.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables; they are mapped through
        ``PlotConfig.get_plot_name`` to obtain the subplot titles.
        If None, uses ["VAR_0", "VAR_1", ...]
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot (created if missing)
    figsize_multiplier : int, optional
        Side length of each (square) subplot

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Tensors -> numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    n_members, _n_times, n_vars, _n_rows, _n_cols = predictions.shape

    if variable_names is None:
        variable_names = [f"VAR_{k}" for k in range(n_vars)]
    titles = [PlotConfig.get_plot_name(name) for name in variable_names]

    # One row of square panels, one per variable
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(n_vars * figsize_multiplier, figsize_multiplier),
        constrained_layout=True,
    )
    # plt.subplots returns a bare Axes when there is a single panel
    axes = axes.ravel() if n_vars > 1 else np.array([axes])
    for panel in axes:
        panel.set_box_aspect(1)

    bin_edges = np.arange(n_members + 2)
    uniform_level = 1 / (n_members + 1)
    for idx, title in enumerate(titles):
        panel = axes[idx]
        var_ranks = ranks(
            predictions=predictions[:, :, idx, :, :],
            targets=targets[:, idx, :, :],
        )
        panel.hist(var_ranks, bins=bin_edges, density=True)
        # Reference level at 1 / (ensemble_size + 1)
        panel.plot(
            [0, n_members + 1],
            [uniform_level, uniform_level],
            linestyle="--",
            color="red",
        )
        panel.set_title(title)
        panel.set_xlabel("ranks")
        panel.set_ylabel("frequency")

    # Write the figure to disk
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def get_divergence(u_tensor, v_tensor, spacing):
    """
    Compute the horizontal divergence of a wind field.

    The zonal component is differentiated along the second-to-last axis and
    the meridional component along the last axis; the two derivatives are
    summed.

    Parameters
    ----------
    u_tensor : torch.Tensor or np.array, shape [...,h,w]
        Zonal component of the wind field. Any number of leading dimensions
        is allowed; the last two dimensions have to correspond to longitude
        and latitude. Must have the same shape as v_tensor.
    v_tensor : torch.Tensor or np.array
        Meridional component of the wind field, same layout and shape as
        u_tensor.
    spacing : float
        Grid resolution, passed to ``torch.gradient`` as the sample spacing
        for both axes.

    Returns
    -------
    np.ndarray of the same shape as u_tensor and v_tensor
    """
    # numpy inputs are wrapped as torch tensors; anything else is passed through
    components = []
    for component in (u_tensor, v_tensor):
        if isinstance(component, np.ndarray):
            component = torch.from_numpy(component)
        components.append(component)
    zonal, meridional = components

    # torch.gradient returns a 1-tuple when a single dim is requested
    (du_dx,) = torch.gradient(zonal, spacing=spacing, dim=-2)
    (dv_dy,) = torch.gradient(meridional, spacing=spacing, dim=-1)

    divergence = du_dx + dv_dy
    return divergence.detach().cpu().numpy()
[docs]
def get_curl(u_tensor, v_tensor, spacing):
    """
    Compute the curl (vertical vorticity component) of a wind field.

    The meridional component is differentiated along the second-to-last
    axis and the zonal component along the last axis; the result is the
    former minus the latter.

    Parameters
    ----------
    u_tensor : torch.Tensor or np.array, shape [...,h,w]
        Zonal component of the wind field. Any number of leading dimensions
        is allowed; the last two dimensions have to correspond to longitude
        and latitude. Must have the same shape as v_tensor.
    v_tensor : torch.Tensor or np.array
        Meridional component of the wind field, same layout and shape as
        u_tensor.
    spacing : float
        Grid resolution, passed to ``torch.gradient`` as the sample spacing
        for both axes.

    Returns
    -------
    np.ndarray of the same shape as u_tensor and v_tensor
    """
    # numpy inputs are wrapped as torch tensors; anything else is passed through
    components = []
    for component in (u_tensor, v_tensor):
        if isinstance(component, np.ndarray):
            component = torch.from_numpy(component)
        components.append(component)
    zonal, meridional = components

    # torch.gradient returns a 1-tuple when a single dim is requested
    (du_dy,) = torch.gradient(zonal, spacing=spacing, dim=-1)
    (dv_dx,) = torch.gradient(meridional, spacing=spacing, dim=-2)

    curl = dv_dx - du_dy
    return curl.detach().cpu().numpy()
[docs]
def plot_mean_divergence_map(
    u_prediction,  # Model predictions, zonal wind component (fine predicted)
    v_prediction,  # Model predictions, meridional wind component (fine predicted)
    u_target,  # Ground truth, zonal wind component (fine true)
    v_target,  # Ground truth, meridional wind component (fine true)
    spacing,
    lat_1d,
    lon_1d,
    filename="mean_divergence.png",
    save_dir=None,
    figsize_multiplier=None,  # Accepted for interface compatibility; unused
):
    """
    Plot maps of the batch-mean horizontal wind divergence.

    Three vertically stacked map panels are drawn: the mean divergence of
    the target, the mean divergence of the prediction, and their difference
    (prediction minus target). All panels share one symmetric colour scale
    centred on zero.

    Parameters
    ----------
    u_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for zonal component of wind
        u_prediction and v_prediction need to have the same shape
    v_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for meridional component of wind
        u_prediction and v_prediction need to have the same shape
    u_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for zonal component of wind
        u_target and v_target need to have the same shape
    v_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for meridional component of wind
        u_target and v_target need to have the same shape
    spacing : float
        spatial resolution of the windfield. Used to compute the gradients.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H]. Only its min and
        max are used; the plotted grid is evenly spaced between them.
    lon_1d : array-like
        1D array of longitude coordinates with shape [W]. Only its min and
        max are used; the plotted grid is evenly spaced between them.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot. Defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Currently has no effect: the figure size is fixed at 4.5 x 9 inches.

    Returns
    -------
    save_path : str
        Path to the saved figure

    Notes
    -----
    NOTE(review): the plotting grid places latitude along axis -2 (rows,
    north to south) and longitude along axis -1, whereas ``get_divergence``
    differentiates u along axis -2 and v along axis -1. Confirm that the
    wind arrays passed here follow the layout ``get_divergence`` expects.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Regular lat/lon grid spanning the coordinate extent; rows run north -> south.
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = u_target.shape
    lat, lon = np.meshgrid(
        np.linspace(lat_max, lat_min, h),
        np.linspace(lon_min, lon_max, w),
        indexing="ij",
    )
    lon_center = float((lon_min + lon_max) / 2)

    # NOTE: PlotConfig must provide a colormap under the "divergence" key.
    cmap = PlotConfig.get_colormap("divergence")

    # Convert units before differentiating.
    u_prediction = PlotConfig.convert_units("wind", u_prediction)
    v_prediction = PlotConfig.convert_units("wind", v_prediction)
    u_target = PlotConfig.convert_units("wind", u_target)
    v_target = PlotConfig.convert_units("wind", v_target)

    mean_div_prediction = np.mean(
        get_divergence(u_prediction, v_prediction, spacing), axis=0
    )
    mean_div_target = np.mean(get_divergence(u_target, v_target, spacing), axis=0)

    # Symmetric colour limits from the largest finite magnitude of both fields.
    vmax = 0.0
    for field in (mean_div_prediction, mean_div_target):
        finite = np.abs(field[np.isfinite(field)])
        if finite.size:
            vmax = max(vmax, float(finite.max()))
    if vmax <= 0.0:
        # TwoSlopeNorm requires vmin < vcenter < vmax and raises ValueError
        # otherwise; fall back to a unit range for all-zero / all-NaN fields.
        vmax = 1.0
    norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)

    data_crs = ccrs.PlateCarree()
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(4.5, 3 * 3.0),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )

    def _draw_panel(ax, field, title):
        """Draw one field on a map axis with coastlines, borders and lakes."""
        mesh = ax.pcolormesh(
            lon,
            lat,
            field,
            norm=norm,
            cmap=cmap,
            transform=data_crs,
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=data_crs)
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        return mesh

    _draw_panel(axes[0], mean_div_target, "Target")
    mesh_fields = _draw_panel(axes[1], mean_div_prediction, "Prediction")

    # One colorbar spanning the two field panels (axes are stacked vertically).
    pos_top = axes[0].get_position()
    pos_mid = axes[1].get_position()
    cax_fields = fig.add_axes(
        [0.92, pos_mid.y0, 0.02, pos_top.y1 - pos_mid.y0]
    )  # [left, bottom, width, height]
    fig.colorbar(mesh_fields, cax=cax_fields, orientation="vertical", label="divergence")

    # The error panel reuses the field norm, so errors larger than the field
    # range saturate the colormap.
    mesh_error = _draw_panel(
        axes[2], mean_div_prediction - mean_div_target, "Predicted - Target"
    )
    pos_bottom = axes[2].get_position()
    cax_error = fig.add_axes([0.92, pos_bottom.y0, 0.02, pos_bottom.height])
    fig.colorbar(
        mesh_error, cax=cax_error, orientation="vertical", label="divergence error"
    )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_mean_curl_map(
    u_prediction,  # Model predictions, zonal wind component (fine predicted)
    v_prediction,  # Model predictions, meridional wind component (fine predicted)
    u_target,  # Ground truth, zonal wind component (fine true)
    v_target,  # Ground truth, meridional wind component (fine true)
    spacing,
    lat_1d,
    lon_1d,
    filename="mean_curl.png",
    save_dir=None,
    figsize_multiplier=None,  # Accepted for interface compatibility; unused
):
    """
    Plot maps of the batch-mean curl of the wind field.

    Three vertically stacked map panels are drawn: the mean curl of the
    target, the mean curl of the prediction, and their difference
    (prediction minus target). All panels share one symmetric colour scale
    centred on zero.

    Parameters
    ----------
    u_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for zonal component of wind
        u_prediction and v_prediction need to have the same shape
    v_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for meridional component of wind
        u_prediction and v_prediction need to have the same shape
    u_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for zonal component of wind
        u_target and v_target need to have the same shape
    v_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for meridional component of wind
        u_target and v_target need to have the same shape
    spacing : float
        spatial resolution of the windfield. Used to compute the gradients.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H]. Only its min and
        max are used; the plotted grid is evenly spaced between them.
    lon_1d : array-like
        1D array of longitude coordinates with shape [W]. Only its min and
        max are used; the plotted grid is evenly spaced between them.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot. Defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Currently has no effect: the figure size is fixed at 4.5 x 9 inches.

    Returns
    -------
    save_path : str
        Path to the saved figure

    Notes
    -----
    NOTE(review): the plotting grid places latitude along axis -2 (rows,
    north to south) and longitude along axis -1, whereas ``get_curl``
    differentiates v along axis -2 and u along axis -1. Confirm that the
    wind arrays passed here follow the layout ``get_curl`` expects.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR

    # Regular lat/lon grid spanning the coordinate extent; rows run north -> south.
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = u_target.shape
    lat, lon = np.meshgrid(
        np.linspace(lat_max, lat_min, h),
        np.linspace(lon_min, lon_max, w),
        indexing="ij",
    )
    lon_center = float((lon_min + lon_max) / 2)

    # NOTE: PlotConfig must provide a colormap under the "curl" key.
    cmap = PlotConfig.get_colormap("curl")

    # Convert units before differentiating.
    u_prediction = PlotConfig.convert_units("wind", u_prediction)
    v_prediction = PlotConfig.convert_units("wind", v_prediction)
    u_target = PlotConfig.convert_units("wind", u_target)
    v_target = PlotConfig.convert_units("wind", v_target)

    mean_curl_prediction = np.mean(
        get_curl(u_prediction, v_prediction, spacing), axis=0
    )
    mean_curl_target = np.mean(get_curl(u_target, v_target, spacing), axis=0)

    # Symmetric colour limits from the largest finite magnitude of both fields.
    vmax = 0.0
    for field in (mean_curl_prediction, mean_curl_target):
        finite = np.abs(field[np.isfinite(field)])
        if finite.size:
            vmax = max(vmax, float(finite.max()))
    if vmax <= 0.0:
        # TwoSlopeNorm requires vmin < vcenter < vmax and raises ValueError
        # otherwise; fall back to a unit range for all-zero / all-NaN fields.
        vmax = 1.0
    norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)

    data_crs = ccrs.PlateCarree()
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(4.5, 3 * 3.0),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )

    def _draw_panel(ax, field, title):
        """Draw one field on a map axis with coastlines, borders and lakes."""
        mesh = ax.pcolormesh(
            lon,
            lat,
            field,
            norm=norm,
            cmap=cmap,
            transform=data_crs,
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=data_crs)
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        return mesh

    _draw_panel(axes[0], mean_curl_target, "Target")
    mesh_fields = _draw_panel(axes[1], mean_curl_prediction, "Prediction")

    # One colorbar spanning the two field panels (axes are stacked vertically).
    pos_top = axes[0].get_position()
    pos_mid = axes[1].get_position()
    cax_fields = fig.add_axes(
        [0.92, pos_mid.y0, 0.02, pos_top.y1 - pos_mid.y0]
    )  # [left, bottom, width, height]
    fig.colorbar(mesh_fields, cax=cax_fields, orientation="vertical", label="curl")

    # The error panel reuses the field norm, so errors larger than the field
    # range saturate the colormap.
    mesh_error = _draw_panel(
        axes[2], mean_curl_prediction - mean_curl_target, "Predicted - Target"
    )
    pos_bottom = axes[2].get_position()
    cax_error = fig.add_axes([0.92, pos_bottom.y0, 0.02, pos_bottom.height])
    fig.colorbar(mesh_error, cax=cax_error, orientation="vertical", label="curl error")

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path