# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh, Kishanthan Kingston, Pierre Chapel
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
import torch
import numpy as np
import xarray as xr
import pandas as pd
from datetime import datetime
from IPSL_AID.utils import EasyDict
from torch.utils.data import Dataset
import torchvision
import glob
import os
import json
[docs]
def stats(ds, logger, input_dir, norm_mapping=None):
    """
    Load normalization statistics and compute coordinate metadata for a dataset.

    Variable statistics are read from ``<input_dir>/statistics.json`` when that
    file exists. Otherwise a built-in table of mean/std constants for the fine
    and coarse variables is used. Coordinate statistics are always computed
    from ``ds`` itself.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose coordinates are inspected.
    logger : logging.Logger
        Logger used for progress and summary messages.
    input_dir : str or None
        Directory that may contain ``statistics.json``. ``None`` skips the
        lookup and forces the built-in constants.
    norm_mapping : dict, optional
        Mapping to populate in place. A new dict is created when omitted, so
        separate calls never share state.

    Returns
    -------
    norm_mapping : dict
        Variable name -> statistics. Entries loaded from JSON are stored as-is;
        fallback entries hold ``vmean`` and ``vstd``; numeric 1-D coordinates
        hold ``vmin``, ``vmax``, ``vmean`` and ``vstd``.
    steps : EasyDict
        For every 1-D coordinate ``name``: ``steps[name]`` is its length and
        ``steps["d_<name>"]`` the absolute difference between its first two
        values (only when it has at least two values).

    Notes
    -----
    - Coordinates that are not 1-D (including scalar coordinates) are skipped
      entirely; they have no meaningful length/spacing.
    - datetime64 coordinates contribute to ``steps`` but get no statistics.
    """
    logger.info("Starting statistics computation...")
    # Create the mapping per call: a mutable default argument would be shared
    # between calls and leak statistics from one dataset into the next.
    if norm_mapping is None:
        norm_mapping = {}
    steps = EasyDict()
    stats_loaded = False

    # Load stats from JSON if available
    if input_dir is not None:
        stats_path = os.path.join(input_dir, "statistics.json")
        if os.path.isfile(stats_path):
            logger.info(f"Loading normalization statistics from {stats_path}")
            with open(stats_path, "r", encoding="utf-8") as f:
                raw_stats = json.load(f)
            # Load all statistics from JSON into norm_mapping
            for key, values in raw_stats.items():
                norm_mapping[key] = EasyDict(values)
            stats_loaded = True

    # Use manual stats
    if not stats_loaded:
        logger.info("No statistics.json found, using manual constants")
        RAW_CONSTANTS = {
            "VAR_2T_fine": {"vmean": 286.3158874511719, "vstd": 12.632543563842773},
            "VAR_10U_fine": {"vmean": 0.36626559495925903, "vstd": 3.4527664184570312},
            "VAR_10V_fine": {"vmean": -0.05726669356226921, "vstd": 3.699695110321045},
            "VAR_TP_fine": {
                "vmean": 9.736002539284527e-05,
                "vstd": 0.0004453345318324864,
            },
            "VAR_D2M_fine": {"vmean": 280.1187744140625, "vstd": 12.384744644165039},
            "VAR_SSTK_fine": {"vmean": 294.448974609375, "vstd": 7.419081211090088},
            "VAR_SKT_fine": {"vmean": 286.8110656738281, "vstd": 13.723491668701172},
            "VAR_ST_fine": {"vmean": 286.93658447265625, "vstd": 13.66016674041748},
            "VAR_TCWV_fine": {"vmean": 19.793996810913086, "vstd": 13.685070037841797},
            "VAR_2T_coarse": {"vmean": 286.2744445800781, "vstd": 12.475735664367676},
            "VAR_10U_coarse": {"vmean": 0.3791687488555908, "vstd": 3.261948585510254},
            "VAR_10V_coarse": {
                "vmean": -0.05645933374762535,
                "vstd": 3.4837522506713867,
            },
            "VAR_TP_coarse": {
                "vmean": 9.832954674493521e-05,
                "vstd": 0.0003440176951698959,
            },
            "VAR_D2M_coarse": {"vmean": 280.09478759765625, "vstd": 12.215740203857422},
            "VAR_SSTK_coarse": {"vmean": 292.5169982910156, "vstd": 8.795150756835938},
            "VAR_SKT_coarse": {"vmean": 286.7804260253906, "vstd": 13.49194049835205},
            "VAR_ST_coarse": {"vmean": 286.9049377441406, "vstd": 13.402721405029297},
            "VAR_TCWV_coarse": {"vmean": 19.75491714477539, "vstd": 13.40636920928955},
        }
        for key, values in RAW_CONSTANTS.items():
            norm_mapping[key] = EasyDict()
            norm_mapping[key].vmean = values["vmean"]
            norm_mapping[key].vstd = values["vstd"]
    logger.info("------------------------------------------")

    for cname in ds.coords:
        cdata = ds[cname].values
        # Check dimensionality first: scalar coordinates have no len() and
        # multi-dimensional ones have no single spacing.
        if cdata.ndim != 1:
            logger.warning(f"[COORD] Skipping '{cname}' because it is not 1-D")
            continue

        steps[cname] = len(cdata)
        if len(cdata) > 1:
            steps[f"d_{cname}"] = abs(cdata[1] - cdata[0])
        else:
            # A spacing needs two points; leave d_<name> unset rather than crash.
            logger.warning(
                f"[COORD] '{cname}' has fewer than 2 values; no spacing recorded"
            )

        # Skip datetime unless desired
        if np.issubdtype(cdata.dtype, np.datetime64):
            logger.info(f"[COORD] Skipping time coordinate '{cname}' (datetime64)")
            continue
        if len(cdata) == 0:
            # min/max are undefined on an empty coordinate.
            continue

        norm_mapping[cname] = EasyDict()
        norm_mapping[cname].vmin = float(np.min(cdata))
        norm_mapping[cname].vmax = float(np.max(cdata))
        norm_mapping[cname].vmean = float(np.mean(cdata))
        norm_mapping[cname].vstd = float(np.std(cdata))

    logger.info("------ Coordinate / Dimension Sizes ------")
    for key, value in steps.items():
        logger.info(f" └── {key}: {value}")
    logger.info("------------------------------------------")
    return norm_mapping, steps
[docs]
def coarse_down_up(fine_filtered, fine_batch, input_shape=(16, 32), axis=0):
    """
    Downscale fine-resolution data to a coarse grid, then upscale it back.

    The field is resized to ``input_shape`` and then back to its original
    spatial size, both with antialiased bilinear interpolation, yielding a
    coarse approximation on the fine grid.

    Parameters
    ----------
    fine_filtered : torch.Tensor or np.ndarray
        Field to coarsen, shaped (Hf, Wf), (C, Hf, Wf) or already batched
        (B, C, Hf, Wf). The last two axes are treated as the spatial axes.
    fine_batch : torch.Tensor or np.ndarray
        Accepted for interface compatibility only; it does not influence the
        returned value.
    input_shape : tuple of int, optional
        Coarse spatial shape (Hc, Wc). Default is (16, 32).
    axis : int, optional
        Position at which the temporary batch axis is inserted for 3-D input.
        Default is 0.

    Returns
    -------
    coarse_up : torch.Tensor
        Coarse approximation with the same spatial size as ``fine_filtered``.
        2-D and 3-D inputs come back with their original rank. For 4-D input
        a leading axis of size 1 is squeezed (legacy behaviour).
    """
    if isinstance(fine_filtered, np.ndarray):
        fine_filtered = torch.from_numpy(fine_filtered)

    ndim_in = fine_filtered.dim()
    if ndim_in == 2:
        # Add batch and channel axes so the resize sees (1, 1, H, W).
        work = fine_filtered[None, None]
    elif ndim_in == 3:
        work = fine_filtered.unsqueeze(axis)
    else:
        work = fine_filtered

    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    # Downscale to coarse resolution
    coarsen_transform = torchvision.transforms.Resize(
        input_shape,
        interpolation=bilinear,
        antialias=True,
    )
    # Upscale back to the original fine grid
    out_shape = (work.shape[-2], work.shape[-1])
    interp_transform = torchvision.transforms.Resize(
        out_shape,
        interpolation=bilinear,
        antialias=True,
    )
    coarse_up = interp_transform(coarsen_transform(work))

    # Remove exactly the axes that were added above.
    if ndim_in == 2:
        return coarse_up[0, 0]
    if ndim_in == 3:
        # Squeeze the same axis that was inserted; squeeze(0) was only
        # correct for the default axis=0.
        return coarse_up.squeeze(axis)
    return coarse_up.squeeze(0)
[docs]
def gaussian_filter(
    image, dW, dH, cutoff_W_phys, cutoff_H_phys, epsilon=0.01, margin=8
):
    """
    Apply a Gaussian low-pass filter with controlled attenuation at the cutoff.

    The field is reflect-padded, transformed with a 2D FFT, multiplied by a
    separable Gaussian response and transformed back. The Gaussian widths are
    chosen so that the response equals ``epsilon`` at the cutoff frequency.

    Parameters
    ----------
    image : np.ndarray of shape (H, W)
        2D field; axis 0 is latitude (H), axis 1 is longitude (W).
    dW : float
        Grid spacing along longitude (axis 1), in degrees.
    dH : float
        Grid spacing along latitude (axis 0), in degrees.
    cutoff_W_phys : float
        Longitudinal cutoff frequency in cycles per degree. Must be > 0.
    cutoff_H_phys : float
        Latitudinal cutoff frequency in cycles per degree. Must be > 0.
    epsilon : float, optional
        Amplitude response at the cutoff frequency, strictly between 0 and 1
        (default: 0.01). Smaller values filter more strongly.
    margin : int, optional
        Number of reflected pixels padded on each side (default: 8). ``0``
        disables padding.

    Returns
    -------
    filtered : np.ndarray of shape (H, W)
        Real part of the filtered field, cropped back to the input size.

    Raises
    ------
    ValueError
        If ``epsilon`` is not in (0, 1) or a cutoff is not positive; either
        would otherwise silently produce NaN/inf output.

    Notes
    -----
    The widths follow from ``exp(-0.5 * (f_cutoff / sigma)^2) = epsilon``,
    i.e. ``sigma = f_cutoff / sqrt(-2 * log(epsilon))``.
    """
    if not 0.0 < epsilon < 1.0:
        raise ValueError(f"epsilon must lie strictly between 0 and 1, got {epsilon}")
    if cutoff_W_phys <= 0 or cutoff_H_phys <= 0:
        raise ValueError(
            "cutoff frequencies must be positive, got "
            f"cutoff_W_phys={cutoff_W_phys}, cutoff_H_phys={cutoff_H_phys}"
        )

    # H = number of latitude grid points, W = number of longitude grid points
    H, W = image.shape
    # Reflective padding reduces FFT edge artifacts
    img_pad = np.pad(image, margin, mode="reflect")
    fft = np.fft.fft2(img_pad)

    # Frequencies must follow the array axes: axis 0 is latitude (spacing dH),
    # axis 1 is longitude (spacing dW).
    fH = np.fft.fftfreq(img_pad.shape[0], d=dH)  # cycles per degree latitude
    fW = np.fft.fftfreq(img_pad.shape[1], d=dW)  # cycles per degree longitude
    FH, FW = np.meshgrid(fH, fW, indexing="ij")  # both shaped like img_pad

    # sigma such that exp(-0.5 * (f_cutoff / sigma)^2) = epsilon
    scale = np.sqrt(-2 * np.log(epsilon))
    sigma_W = cutoff_W_phys / scale
    sigma_H = cutoff_H_phys / scale
    H_filter = np.exp(-0.5 * ((FW / sigma_W) ** 2 + (FH / sigma_H) ** 2))

    img_filt_pad = np.fft.ifft2(fft * H_filter).real
    # Crop with explicit end indices: [margin:-margin] is empty when margin == 0.
    return img_filt_pad[margin : margin + H, margin : margin + W]
[docs]
class DataPreprocessor(Dataset):
"""
Dataset class for preprocessing weather and climate data for machine learning.
This class handles loading, preprocessing, and sampling of multi-year NetCDF
weather data with support for multi-scale processing, normalization, and
spatial-temporal sampling strategies.
Parameters
----------
years : list of int
Years of data to include.
loaded_dfs : xarray.Dataset
Pre-loaded dataset containing the weather variables.
constants_file_path : str
Path to NetCDF file containing constant variables (e.g., topography).
varnames_list : list of str
List of variable names to extract from the dataset.
units_list : list of str
Units for each variable in varnames_list.
in_shape : tuple of int, optional
Target shape (height, width) for coarse resolution. Default is (80, 128).
batch_size_lat : int, optional
Height of spatial batch in grid points. Default is 144.
batch_size_lon : int, optional
Width of spatial batch in grid points. Default is 144.
steps : EasyDict, optional
Dictionary containing grid dimension information. Should include:
- latitude/lat: number of latitude points
- longitude/lon: number of longitude points
- time: number of time steps
- d_latitude/d_lat: latitude spacing
- d_longitude/d_lon: longitude spacing
tbatch : int, optional
Number of time batches to sample. Default is 1.
sbatch : int, optional
Number of spatial batches to sample. Default is 8.
debug : bool, optional
Enable debug logging. Default is True.
mode : str, optional
Operation mode: "train" or "validation". Default is "train".
run_type : str, optional
Run type: "train", "validation", or "inference". Default is "train".
dynamic_covariates : list of str, optional
List of dynamic covariate variable names. Default is None.
dynamic_covariates_dir : str, optional
Directory containing dynamic covariate files. Default is None.
time_normalization : str, optional
Method for time normalization: "linear" or "cos_sin". Default is "linear".
norm_mapping : dict, optional
Dictionary containing normalization statistics for variables.
index_mapping : dict, optional
Dictionary mapping variable names to indices in the data array.
normalization_type : dict, optional
Dictionary specifying normalization type per variable.
constant_variables : list of str, optional
List of constant variable names to load. Default is None.
epsilon : float, optional
Small value for numerical stability in filtering. Default is 0.02.
margin : int, optional
Margin for filtering operations. Default is 8.
dtype : tuple, optional
Data types for torch and numpy (torch_dtype, np_dtype).
Default is (torch.float32, np.float32).
apply_filter : bool, optional
Whether to apply Gaussian filtering for multi-scale processing.
Default is False.
logger : logging.Logger, optional
Logger instance for logging messages. Default is None.
Attributes
----------
const_vars : np.ndarray or None
Array of constant variables with shape (n_constants, H, W).
time : xarray.DataArray
Time coordinate from dataset.
year : xarray.DataArray
Year component of time.
month : xarray.DataArray
Month component of time.
day : xarray.DataArray
Day component of time.
hour : xarray.DataArray
Hour component of time.
year_norm : torch.Tensor
Normalized year values.
doy_norm : torch.Tensor or None
Normalized day-of-year values (linear mode).
hour_norm : torch.Tensor or None
Normalized hour values (linear mode).
doy_sin, doy_cos : torch.Tensor or None
Sine and cosine of day-of-year (cos_sin mode).
hour_sin, hour_cos : torch.Tensor or None
Sine and cosine of hour (cos_sin mode).
time_batchs : np.ndarray
Array of time indices for current epoch.
eval_slices : list of tuple or None
List of spatial slices for evaluation mode.
random_centers : list of tuple or None
List of random spatial centers for training mode.
center_tracker : list
Tracks spatial centers for debugging.
tindex_tracker : list
Tracks temporal indices for debugging.
Methods
-------
new_epoch()
Reset time batches and random centers for new training epoch.
sample_time_steps_by_doy()
Sample time steps based on day-of-year (DOY) for multi-year continuity.
sample_random_time_indices()
Randomly sample time indices for training.
load_dynamic_covariates()
Load dynamic covariate data (not fully implemented).
generate_random_batch_centers(n_batches)
Generate random spatial centers for batch sampling.
generate_evaluation_slices()
Generate deterministic spatial slices for evaluation.
extract_batch(data, ilat, ilon)
Extract spatial batch centered at (ilat, ilon) with cyclic longitude.
filter_batch(fine_patch, fine_block)
Apply Gaussian low-pass filtering for multi-scale processing.
normalize(data, stats, norm_type, var_name=None, data_type=None)
Normalize data using specified statistics and method.
normalize_time(tindex)
Return normalized time features for given time index.
__len__()
Return total number of samples.
__getitem__(index)
Get a single sample with appropriate spatial-temporal sampling.
Notes
-----
- Supports both random (training) and deterministic (validation) sampling.
- Handles cyclic longitude wrapping for global datasets.
- Provides multi-scale processing through downscaling/upscaling.
- Includes time normalization with linear or trigonometric encoding.
- Can incorporate constant variables (e.g., topography, land-sea mask).
"""
[docs]
def __init__(
    self,
    years,
    loaded_dfs,
    constants_file_path,
    varnames_list,
    units_list,
    in_shape=(80, 128),
    batch_size_lat=144,
    batch_size_lon=144,
    steps=None,
    tbatch=1,
    sbatch=8,
    debug=True,
    mode="train",
    run_type="train",
    dynamic_covariates=None,
    dynamic_covariates_dir=None,
    time_normalization="linear",
    norm_mapping=None,
    index_mapping=None,
    normalization_type=None,
    constant_variables=None,
    epsilon=0.02,
    margin=8,
    dtype=(torch.float32, np.float32),
    apply_filter=False,
    region_center=None,  # (lat_value, lon_value)
    region_size=None,
    logger=None,
):
    """
    Initialize the DataPreprocessor.

    Parameters
    ----------
    years : list of int
        Years of data to include.
    loaded_dfs : xarray.Dataset
        Pre-loaded dataset containing the weather variables and a ``time``
        coordinate.
    constants_file_path : str
        Path to NetCDF file containing constant variables.
    varnames_list : list of str
        List of variable names to extract.
    units_list : list of str
        Units for each variable.
    in_shape : tuple of int, optional
        Target shape for coarse resolution. Default is (80, 128).
    batch_size_lat : int, optional
        Height of spatial batch.
    batch_size_lon : int, optional
        Width of spatial batch.
    steps : EasyDict
        Grid dimension information with attribute access. Must provide
        ``latitude``/``d_latitude`` (or ``lat``/``d_lat``),
        ``longitude``/``d_longitude`` (or ``lon``/``d_lon``) and ``time``.
        When the short names are used, the long-name aliases are added to
        this object in place.
    tbatch : int, optional
        Number of time batches.
    sbatch : int, optional
        Number of spatial batches. Overwritten with the number of generated
        slices in validation and regional modes.
    debug : bool, optional
        Enable debug logging.
    mode : str, optional
        ``"validation"`` selects deterministic sampling; any other value is
        handled as training.
    run_type : str, optional
        ``"train"``, ``"inference"``, ``"train_regional"`` or
        ``"inference_regional"``.
    dynamic_covariates : list of str, optional
        Dynamic covariate variable names.
    dynamic_covariates_dir : str, optional
        Directory for dynamic covariates.
    time_normalization : str, optional
        ``"linear"`` or ``"cos_sin"``.
    norm_mapping : dict, optional
        Normalization statistics.
    index_mapping : dict, optional
        Variable to index mapping.
    normalization_type : dict, optional
        Normalization type per variable.
    constant_variables : list of str, optional
        Constant variable names.
    epsilon : float, optional
        Stored as ``self.epsilon``.
    margin : int, optional
        Stored as ``self.margin``.
    dtype : tuple, optional
        Data types for torch and numpy, as (torch_dtype, np_dtype).
    apply_filter : bool, optional
        Apply Gaussian filtering.
    region_center : tuple of float or None
        Geographic center (lat, lon) required by the regional run types.
    region_size : tuple of int or None
        Region extent (n_lat, n_lon) in grid points required by the regional
        run types; each must be a multiple of the matching batch size.
    logger : logging.Logger
        Logger instance. Required despite the ``None`` default.

    Raises
    ------
    AssertionError
        If ``steps`` lacks a required coordinate, the logger is missing, or
        the regional configuration is incomplete.
    ValueError
        If a batch size exceeds the grid or ``time_normalization`` is unknown.
    """
    self.constants_file_path = constants_file_path
    self.constant_variables = constant_variables
    self.years = years
    self.varnames_list = varnames_list
    self.units_list = units_list
    self.in_shape = in_shape
    self.batch_size_lat = batch_size_lat
    self.batch_size_lon = batch_size_lon

    # No mutable default argument; an absent `steps` fails the checks below
    # with a readable message.
    if steps is None:
        steps = {}

    def _steps_keys():
        # Dict-like containers hold their entries as items, so vars() would
        # report nothing (or fail outright on a plain dict).
        if hasattr(steps, "keys"):
            return list(steps.keys())
        return list(vars(steps).keys())

    # Explicit raises instead of `assert False`, which is stripped under -O.
    # AssertionError is kept so the exception type is unchanged.
    if hasattr(steps, "latitude"):
        pass
    elif hasattr(steps, "lat"):
        steps.latitude = steps.lat
        steps.d_latitude = steps.d_lat
    else:
        raise AssertionError(
            f"Missing required latitude coordinate ('latitude' or 'lat'). "
            f"Available keys: {_steps_keys()}"
        )
    if hasattr(steps, "longitude"):
        pass
    elif hasattr(steps, "lon"):
        steps.longitude = steps.lon
        steps.d_longitude = steps.d_lon
    else:
        raise AssertionError(
            f"Missing required longitude coordinate ('longitude' or 'lon'). "
            f"Available keys: {_steps_keys()}"
        )
    if not hasattr(steps, "time"):
        raise AssertionError(
            f"Missing required 'time' coordinate. "
            f"Available keys: {_steps_keys()}"
        )

    self.H = steps.latitude
    self.dH = steps.d_latitude
    self.W = steps.longitude
    self.dW = steps.d_longitude
    self.region_center = region_center
    self.region_size = region_size
    self.tbatch = tbatch
    self.sbatch = sbatch
    self.stime = 0
    self.debug = debug
    self.mode = mode
    self.run_type = run_type
    self.dynamic_covariates = dynamic_covariates
    self.dynamic_covariates_dir = dynamic_covariates_dir
    self.time_normalization = time_normalization
    self.norm_mapping = norm_mapping
    self.index_mapping = index_mapping
    self.normalization_type = normalization_type
    self.epsilon = epsilon
    self.margin = margin
    self.torch_dtype = dtype[0]
    self.np_dtype = dtype[1]
    self.apply_filter = apply_filter
    self.logger = logger
    # Check the logger before its first use below.
    assert self.logger is not None, "Make sure the logger is set"

    if self.apply_filter:
        self.logger.info(f"Fine filtering enabled: {self.apply_filter}")

    # Validate batch sizes
    if self.batch_size_lat > self.H:
        raise ValueError(
            f"batch height {self.batch_size_lat} exceeds latitude dimension {self.H}"
        )
    if self.batch_size_lon > self.W:
        raise ValueError(
            f"batch width {self.batch_size_lon} exceeds longitude dimension {self.W}"
        )
    self.logger.info(f"Spatial dimensions: {self.H} x {self.W}")
    self.logger.info(f"batch size: {self.batch_size_lat} x {self.batch_size_lon}")

    # ----------------------------------------------------------
    # Constant variables (normalized, except "lsm")
    # ----------------------------------------------------------
    if self.constant_variables is not None and self.constants_file_path is not None:
        self.logger.info(f"Opening constant variables file: {constants_file_path}")
        # The context manager closes the file even if a variable is missing.
        # Data of the single time step is loaded into memory first.
        with xr.open_dataset(self.constants_file_path) as ds_file:
            ds_const = ds_file.isel(time=0).load()
        # Initialize constant variables tensor
        self.const_vars = np.zeros(
            (len(self.constant_variables), self.H, self.W)  # [channels, lat, lon]
        )
        for i, const_varname in enumerate(self.constant_variables):
            const_var = ds_const[const_varname]
            if const_varname != "lsm":
                self.logger.info(f"Normalizing {const_varname}")
                # cos(latitude)-weighted mean/std
                weighted_var = const_var.weighted(
                    np.cos(np.radians(ds_const.latitude))
                )
                mean_var = weighted_var.mean().values
                std_var = weighted_var.std().values
                self.logger.info(
                    f"{const_varname} - Weighted Mean: {mean_var:.4f}, Weighted Std: {std_var:.4f}"
                )
                self.const_vars[i] = ((const_var - mean_var) / std_var).values
            else:
                self.const_vars[i] = const_var.values
        self.logger.info(f"Loaded constant variables: {self.constant_variables}")
        self.logger.info(f"Constant variables shape: {self.const_vars.shape}")
    else:
        self.const_vars = None
        self.logger.info("No constant variables provided")

    # Dynamic covariates are not loaded here; see load_dynamic_covariates().

    # Cache for loaded data
    self.loaded_dfs = loaded_dfs
    self.etime = len(self.loaded_dfs["time"])

    # ----------------------------------------------------------
    # 1. Extract time components from xarray
    # ----------------------------------------------------------
    self.time = self.loaded_dfs.time
    self.year = self.time.dt.year
    self.month = self.time.dt.month
    self.day = self.time.dt.day
    self.hour = self.time.dt.hour

    # ----------------------------------------------------------
    # 2. Normalized year
    # ----------------------------------------------------------
    year_np = ((self.year.to_numpy() - 1940) / 100).astype(self.np_dtype)
    self.year_norm = torch.from_numpy(year_np).to(self.torch_dtype)

    # ----------------------------------------------------------
    # 3. DOY & hour normalization
    # ----------------------------------------------------------
    # Features of the encoding that is not selected stay None, so the
    # attributes exist whichever time_normalization is used.
    self.doy_norm = None
    self.hour_norm = None
    self.doy_sin = None
    self.doy_cos = None
    self.hour_sin = None
    self.hour_cos = None
    if self.time_normalization == "linear":
        # Approximate DOY: (month-1)*30 + (day-1)
        doy_np = (
            ((self.month - 1.0) * 30 + (self.day - 1.0))
            .to_numpy()
            .astype(self.np_dtype)
        )
        hour_np = (self.hour.to_numpy() / 24.0).astype(self.np_dtype)
        self.doy_norm = torch.from_numpy((doy_np / 360).astype(self.np_dtype)).to(
            self.torch_dtype
        )
        self.hour_norm = torch.from_numpy(hour_np).to(self.torch_dtype)
    elif self.time_normalization == "cos_sin":
        date = pd.to_datetime(dict(year=self.year, month=self.month, day=self.day))
        # Days elapsed since 2000-01-01 (continuous across years)
        doy = (date - datetime(2000, 1, 1)).dt.days
        self.doy_np = doy.to_numpy()
        self.unique_doys = np.unique(self.doy_np)
        self.doy_to_indices = {
            d: np.where(self.doy_np == d)[0] for d in self.unique_doys
        }
        doy_np = doy.to_numpy().astype(self.np_dtype)
        hour_np = self.hour.to_numpy().astype(self.np_dtype)
        doy_angle = 2 * np.pi * torch.from_numpy(doy_np).to(self.torch_dtype) / 365.25
        hour_angle = 2 * np.pi * torch.from_numpy(hour_np).to(self.torch_dtype) / 24.0
        self.doy_sin = torch.sin(doy_angle)
        self.doy_cos = torch.cos(doy_angle)
        self.hour_sin = torch.sin(hour_angle)
        self.hour_cos = torch.cos(hour_angle)
    else:
        raise ValueError("time_normalization must be 'linear' or 'cos_sin'")

    # ----------------------------------------------------------
    # 4. Spatial / temporal sampling plan
    # ----------------------------------------------------------
    self.eval_slices = None
    if self.mode == "validation":
        self.eval_slices = self.generate_evaluation_slices()
        # To Do: a key to add if all sbatch to taken or not
        # Set sbatch to exactly match number of slices
        self.sbatch = len(self.eval_slices)
        self.logger.info(f"Evaluation mode: Generated {self.sbatch} spatial slices")
        # To Do: set time batches for validation (if inference is active or not)
        if self.run_type == "inference":
            self.time_batchs = np.arange(self.stime, self.etime, dtype=int)
        elif self.run_type in ["inference_regional", "train_regional"]:
            # mode == "validation" combined with run_type == "train_regional"
            # is possible. Instead of tiling the full domain, the slices cover
            # one window centered on a user-defined (latitude, longitude).
            assert (
                self.region_center is not None
            ), "region_center must be provided for inference_regional mode"
            assert (
                self.region_size is not None
            ), "region_size must be provided when using inference_regional mode"
            lat_val, lon_val = self.region_center
            lat_idx, lon_idx = self.get_center_indices_from_latlon(lat_val, lon_val)
            region_size_lat, region_size_lon = self.region_size
            assert region_size_lat % self.batch_size_lat == 0
            assert region_size_lon % self.batch_size_lon == 0
            self.eval_slices = self.generate_region_slices(
                lat_idx, lon_idx, region_size_lat, region_size_lon
            )
            # Set sbatch to exactly match number of slices
            self.sbatch = len(self.eval_slices)
            if self.run_type == "train_regional":
                self.time_batchs = np.linspace(
                    self.etime // 3, self.etime * 2 // 3, self.tbatch, dtype=int
                )
            else:
                self.time_batchs = np.arange(self.stime, self.etime, dtype=int)
            self.logger.info(
                f"Inference region mode activated at lat={lat_val}, lon={lon_val}"
            )
        else:
            self.time_batchs = np.linspace(
                self.etime // 3, self.etime * 2 // 3, self.tbatch, dtype=int
            )
        self.logger.info(f"Validation time batches: {self.time_batchs}")
    else:
        # Training mode
        # To Do: set time batches for training (full or partial sampling)
        if self.run_type == "train_regional":
            # Instead of randomly selected windows, the slices cover one
            # window centered on a user-defined (latitude, longitude).
            assert (
                self.region_center is not None
            ), "region_center must be provided for train_regional mode"
            assert (
                self.region_size is not None
            ), "region_size must be provided when using train_regional mode"
            lat_val, lon_val = self.region_center
            lat_idx, lon_idx = self.get_center_indices_from_latlon(lat_val, lon_val)
            region_size_lat, region_size_lon = self.region_size
            assert region_size_lat % self.batch_size_lat == 0
            assert region_size_lon % self.batch_size_lon == 0
            self.train_slices = self.generate_region_slices(
                lat_idx, lon_idx, region_size_lat, region_size_lon
            )
            # Set sbatch to exactly match number of slices
            self.sbatch = len(self.train_slices)
            self.time_batchs = np.arange(self.stime, self.etime, dtype=int)
            self.logger.info(
                f"Train region mode activated at lat={lat_val}, lon={lon_val}"
            )
        else:  # train global
            self.time_batchs = np.arange(self.stime, self.etime, dtype=int)
            self.logger.info(
                f"Training (global) mode: sbatch={self.sbatch}, tbatch={self.tbatch}"
            )

    # Run in every mode, after sbatch is final, so the per-epoch state
    # (random_centers, last_tbatch_index) always exists.
    self.new_epoch()
    self.center_tracker = []  # Will store spatial indices
    self.tindex_tracker = []  # Will store temporal indices
[docs]
def new_epoch(self):
    """
    Reset the per-epoch sampling state.

    Rewinds the last-used time-batch marker to ``-1`` and replaces
    ``random_centers`` with ``sbatch`` empty (``None``) placeholders.
    """
    # NOTE: partial sampling by DOY (sample_time_steps_by_doy) is currently
    # disabled, so the time batches are left untouched here.
    self.last_tbatch_index = -1
    # One placeholder per spatial batch
    self.random_centers = [None for _ in range(self.sbatch)]
[docs]
def sample_time_steps_by_doy(self):
    """
    Draw ``tbatch`` time indices, each belonging to a different day.

    ``tbatch`` distinct entries are drawn without replacement from
    ``self.unique_doys``; for every drawn day one time index is then picked at
    random from ``self.doy_to_indices``. The result is stored in
    ``self.time_batchs`` as an integer array.

    Raises
    ------
    ValueError
        If ``tbatch`` is larger than the number of unique days.
    """
    requested = self.tbatch
    available = len(self.unique_doys)
    if requested > available:
        raise ValueError(
            f"Requested tbatch={requested}, but only {available} unique days available."
        )

    # Distinct days first ...
    chosen_doys = np.random.choice(self.unique_doys, size=requested, replace=False)
    # ... then one random occurrence of each day, in the drawn order.
    picked = []
    for doy in chosen_doys:
        picked.append(np.random.choice(self.doy_to_indices[doy]))
    self.time_batchs = np.array(picked, dtype=int)

    if not self.debug:
        return
    self.logger.info(
        f"[DOY-batch] Selected DOYs: {chosen_doys.tolist()} → indices: {self.time_batchs.tolist()}"
    )
[docs]
def sample_random_time_indices(self):
    """
    Draw ``tbatch`` random time indices for training.

    Indices are sampled uniformly (with replacement) from
    ``[self.stime, self.etime)`` and stored in ``self.time_batchs``.
    """
    # np.random.randint excludes its upper bound, so pass etime itself:
    # `etime - 1` never produced the last time step and raised for a dataset
    # with a single time step.
    self.time_batchs = np.random.randint(self.stime, self.etime, size=self.tbatch)
    if self.debug:
        self.logger.info(f"Generated new training time batches: {self.time_batchs}")
[docs]
def load_dynamic_covariates(self):
    """
    Load dynamic covariate data into ``self.dynamic_covariate_data``.

    For every name in ``self.dynamic_covariates`` and every year in
    ``self.years`` the file
    ``<dynamic_covariates_dir>/samples_<year>_<covariate>.nc`` is looked up.
    The files found are concatenated along ``time`` and stored under the
    covariate name. Covariates without any file are left out of the mapping.
    """
    self.dynamic_covariate_data = {}
    # `dynamic_covariates` defaults to None; treat that as "nothing to load".
    for covariate in self.dynamic_covariates or []:
        covariate_files = []
        for year in self.years:
            file_pattern = (
                f"{self.dynamic_covariates_dir}/samples_{year}_{covariate}.nc"
            )
            covariate_files.extend(glob.glob(file_pattern))
        if not covariate_files:
            continue

        datasets = []
        try:
            for path in covariate_files:
                datasets.append(xr.open_dataset(path))
            # Materialise the result before the source files are closed, so it
            # never refers to a closed file.
            combined_ds = xr.concat(datasets, dim="time").load()
            self.dynamic_covariate_data[covariate] = combined_ds
        finally:
            # Close every file that was opened, even if open/concat failed.
            for ds in datasets:
                ds.close()
[docs]
def get_center_indices_from_latlon(self, lat_value, lon_value):
    """
    Convert geographic coordinates to the nearest grid indices.

    Parameters
    ----------
    lat_value : float
        Latitude in degrees.
    lon_value : float
        Longitude in degrees. Any convention (e.g. -180..180 or 0..360) is
        accepted, independent of the convention used by the grid.

    Returns
    -------
    lat_idx : int
        Index of the closest latitude grid point.
    lon_idx : int
        Index of the closest longitude grid point, measured on the circle.

    Notes
    -----
    The latitude and longitude arrays are read from ``self.loaded_dfs``
    (``latitude`` / ``longitude``). Longitude distance is taken modulo 360
    degrees, so e.g. -10 on a 0..360 grid maps to the point nearest 350
    rather than to 0.
    """
    lat_array = self.loaded_dfs.latitude.values
    lon_array = self.loaded_dfs.longitude.values

    lat_idx = int(np.abs(lat_array - lat_value).argmin())

    # Shortest angular separation: going either way around the circle.
    lon_offset = np.abs(lon_array - lon_value) % 360.0
    lon_distance = np.minimum(lon_offset, 360.0 - lon_offset)
    lon_idx = int(lon_distance.argmin())
    return lat_idx, lon_idx
[docs]
def generate_random_batch_centers(self, n_batches):
    """
    Generate random (latitude, longitude) centers for batch sampling.

    Parameters
    ----------
    n_batches : int
        Number of random centers to generate.

    Returns
    -------
    centers : list of tuple
        List of (lat_center, lon_center) index tuples.

    Notes
    -----
    - Latitude centers are drawn from ``[half, H - half)`` with
      ``half = batch_size_lat // 2``, keeping the centers away from the
      latitude edges.
    - When the batch spans the whole latitude axis that range is empty; the
      single remaining center ``half`` is used instead of failing.
    - Longitude centers are drawn from ``[0, W)``.
    """
    centers = []
    half_lat = self.batch_size_lat // 2
    # Keep the draw range non-empty when batch_size_lat == H (randint raises
    # for low >= high).
    # NOTE(review): assumes a center of `half_lat` yields the full-height block
    # in extract_batch — confirm against that method.
    lat_high = max(self.H - half_lat, half_lat + 1)
    try:
        for _ in range(n_batches):
            # Latitude: avoid poles (non-cyclic)
            lat_center = np.random.randint(half_lat, lat_high)
            # Longitude: any (cyclic)
            lon_center = np.random.randint(0, self.W)
            centers.append((lat_center, lon_center))
        if self.debug:
            self.logger.info(
                f" [RandomBlockSampler]Generated {n_batches} random centers: {centers}"
            )
        return centers
    except Exception as e:
        self.logger.exception(
            f"[RandomBlockSampler] Error while generating random centers: {e}"
        )
        raise
[docs]
def generate_evaluation_slices(self):
    """
    Build the deterministic tiling of spatial blocks used in evaluation mode.

    Returns
    -------
    slices : list of tuple
        List of (lat_start, lat_end, lon_start, lon_end) tuples describing
        non-overlapping blocks of size (batch_size_lat, batch_size_lon),
        ordered row by row (latitude block first, longitude block second).

    Notes
    -----
    The number of blocks per axis is ``H // batch_size_lat`` and
    ``W // batch_size_lon``; grid points beyond the last complete block
    are not included in any slice.
    """
    rows = self.H // self.batch_size_lat
    cols = self.W // self.batch_size_lon

    # Row-major enumeration: each latitude start is repeated for every
    # longitude block, while the longitude starts cycle within each row.
    lat_starts = np.repeat(np.arange(rows) * self.batch_size_lat, cols)
    lon_starts = np.tile(np.arange(cols) * self.batch_size_lon, rows)

    slices = list(
        zip(
            lat_starts,
            lat_starts + self.batch_size_lat,
            lon_starts,
            lon_starts + self.batch_size_lon,
        )
    )
    self.logger.info(
        f"Generated {len(slices)} evaluation blocks "
        f"({rows} x {cols} grid)"
    )
    return slices
[docs]
def generate_region_slices(
    self, lat_center, lon_center, region_size_lat, region_size_lon
):
    """
    Tile a region centered on (lat_center, lon_center) into fixed-size blocks.

    The region is split into non-overlapping blocks of size
    (batch_size_lat, batch_size_lon), enumerated row by row.

    Parameters
    ----------
    lat_center : int
        Latitude index of the region center in the global grid.
    lon_center : int
        Longitude index of the region center in the global grid.
    region_size_lat : int
        Height of the region (in grid points).
    region_size_lon : int
        Width of the region (in grid points).

    Returns
    -------
    slices : list of tuple
        List of (lat_start, lat_end, lon_start, lon_end) tuples, one per
        block; ``region_size // batch_size`` blocks are produced per axis.

    Notes
    -----
    The first latitude index is clamped from below at 0 and from above at
    ``H - region_size_lat``, so a region requested too close to a latitude
    edge is shifted and is then no longer centered on ``lat_center``.
    Longitude indices are NOT clamped here and may fall outside ``[0, W)``;
    according to the original notes, wrapping is handled later during
    patch extraction (in extract_batch).
    """
    block_h = self.batch_size_lat
    block_w = self.batch_size_lon

    # Top-left corner of the region, with latitude kept inside the grid
    top = lat_center - region_size_lat // 2
    left = lon_center - region_size_lon // 2
    top = min(max(0, top), self.H - region_size_lat)

    lat_origins = [top + i * block_h for i in range(region_size_lat // block_h)]
    lon_origins = [left + j * block_w for j in range(region_size_lon // block_w)]

    return [
        (lat0, lat0 + block_h, lon0, lon0 + block_w)
        for lat0 in lat_origins
        for lon0 in lon_origins
    ]
[docs]
def build_fine_coarse_blocks(self, npfeatures_full, lat_center, lon_center):
    """
    Extract the fine, optionally filtered, and coarse blocks around a center.

    Parameters
    ----------
    npfeatures_full : np.ndarray
        Full-domain input features with shape (C, H, W).
    lat_center : int
        Latitude center index.
    lon_center : int
        Longitude center index.

    Returns
    -------
    fine_block : np.ndarray
        Block extracted from ``npfeatures_full`` by ``extract_batch``.
    fine_filtered_block : np.ndarray or None
        Block extracted from the filtered full field; ``None`` when
        ``apply_filter`` is False.
    coarse_block : np.ndarray
        Block extracted from the output of ``coarse_down_up``.
    fine_indices : tuple
        Indices returned by ``extract_batch`` for the fine block.

    Raises
    ------
    AssertionError
        If the indices returned for the filtered or coarse extraction
        differ from those of the fine extraction, or if the coarse and
        fine block shapes differ.

    Notes
    -----
    - All three extractions go through ``extract_batch`` with the same
      center.
    - ``coarse_down_up`` is applied to the full-domain field (the filtered
      one when ``apply_filter`` is True, the raw one otherwise), not to
      the local block.
    """
    center = (lat_center, lon_center)
    fine_block, fine_indices = self.extract_batch(npfeatures_full, *center)

    if self.apply_filter:
        # Filter the whole domain, then cut the same window out of it
        hr_field = self.filter_batch(npfeatures_full, fine_block)
        fine_filtered_block, filtered_indices = self.extract_batch(
            hr_field, *center
        )
        assert filtered_indices == fine_indices, (
            f"Indices mismatch after filtering:\n"
            f" original indices: {fine_indices}\n"
            f" filtered indices: {filtered_indices}"
        )
    else:
        # No filtering: the raw field is coarsened directly
        hr_field = npfeatures_full
        fine_filtered_block = None

    coarse_full = coarse_down_up(
        hr_field, npfeatures_full, input_shape=self.in_shape
    )
    coarse_block, coarse_indices = self.extract_batch(coarse_full, *center)

    assert coarse_indices == fine_indices, (
        f"Indices mismatch after coarsening:\n"
        f" original indices: {fine_indices}\n"
        f" coarse indices: {coarse_indices}"
    )
    assert coarse_block.shape == fine_block.shape, (
        f"Shape mismatch between fine and coarse blocks:\n"
        f" fine shape: {fine_block.shape}\n"
        f" coarse shape: {coarse_block.shape}"
    )
    return fine_block, fine_filtered_block, coarse_block, fine_indices
[docs]
def filter_batch(self, fine_patch, fine_block):
    """
    Apply Gaussian low-pass filtering for multi-scale processing.

    Parameters
    ----------
    fine_patch : np.ndarray
        Fine-resolution data of shape (C, H, W).
    fine_block : np.ndarray
        Reference block; only its last two dimensions are used, to derive
        the fine-to-coarse scaling factors relative to ``self.in_shape``.

    Returns
    -------
    fine_filtered : np.ndarray
        Filtered data with the same shape and dtype as ``fine_patch``.

    Raises
    ------
    Exception
        Any exception raised by ``gaussian_filter`` is logged together
        with the failing channel index and re-raised unchanged.

    Notes
    -----
    - The cut-off frequencies passed to ``gaussian_filter`` are the
      Nyquist frequencies of the coarse grid, ``1 / (2 * d_coarse)``.
    - Each channel is processed independently.
    """
    # Output container
    fine_filtered = np.zeros_like(fine_patch)

    # Scaling between fine and coarse grids
    scale_factor_H = fine_block.shape[-2] / self.in_shape[0]
    scale_factor_W = fine_block.shape[-1] / self.in_shape[1]

    # Coarse-grid spacing and the corresponding Nyquist limits
    dW_coarse = self.dW * scale_factor_W
    dH_coarse = self.dH * scale_factor_H
    f_nW_coarse = 1.0 / (2.0 * dW_coarse)
    f_nH_coarse = 1.0 / (2.0 * dH_coarse)

    # Apply filtering per channel
    for c in range(fine_patch.shape[0]):
        try:
            fine_filtered[c] = gaussian_filter(
                fine_patch[c],
                self.dW,
                self.dH,
                f_nW_coarse,
                f_nH_coarse,
                self.epsilon,
                self.margin,
            )
        except Exception as e:
            # The error is embedded in the message itself: passing it as an
            # extra positional argument makes logging attempt "%"-formatting
            # on a message without placeholders, which breaks the record.
            self.logger.exception(f"Filtering failed at channel {c}: {e}")
            raise
    return fine_filtered
[docs]
def normalize(self, data, stats, norm_type, var_name=None, data_type=None):
    """
    Normalize a tensor with the requested method and statistics.

    Parameters
    ----------
    data : torch.Tensor
        Input data to normalize.
    stats : object
        Statistics object; depending on ``norm_type`` the attributes
        ``vmin``/``vmax``, ``vmean``/``vstd`` or ``median``/``iqr`` are read.
    norm_type : str
        One of "minmax", "minmax_11", "standard", "robust",
        "log1p_minmax", "log1p_standard".
    var_name : str, optional
        Variable name, used only in the debug log message.
    data_type : str, optional
        Data type description, used only in the debug log message.

    Returns
    -------
    torch.Tensor
        Normalized data. When the scale (range, std or iqr) is zero, a
        warning is logged and a tensor of zeros is returned instead.

    Raises
    ------
    ValueError
        If norm_type is not supported.

    Notes
    -----
    For the ``log1p_*`` types, ``torch.log1p`` is applied to ``data`` first
    and the statistics are then applied to the transformed values.
    """
    # Context string for the debug message
    context = "" if var_name is None else f" for {var_name}"
    if data_type is not None:
        context += f" ({data_type})"

    if self.debug:
        stat_names = ("vmin", "vmax", "vmean", "vstd", "median", "iqr", "q1", "q3")
        stat_report = "\n".join(
            f" └── {name}: {getattr(stats, name, None)}" for name in stat_names
        )
        self.logger.info(
            f"Normalizing{context} with type '{norm_type}'\n"
            f" └── Normalization stats:\n"
            f"{stat_report}"
        )

    def _like(value, ref):
        # Statistic as a tensor matching the dtype/device of the data
        return torch.tensor(value, dtype=ref.dtype, device=ref.device)

    # ------------------ range-based: (x - min) / (max - min) ------------------
    if norm_type in ("minmax", "minmax_11", "log1p_minmax"):
        log_space = norm_type == "log1p_minmax"
        if log_space:
            data = torch.log1p(data)
        lower = _like(stats.vmin, data)
        upper = _like(stats.vmax, data)
        span = upper - lower
        if span == 0:
            self.logger.warning(
                "log_max == log_min, returning zeros."
                if log_space
                else "vmax == vmin, returning zeros."
            )
            return torch.zeros_like(data)
        if norm_type == "minmax_11":
            # Rescaled to [-1, 1]
            return 2 * (data - lower) / span - 1
        return (data - lower) / span

    # ------------------ center/scale: (x - center) / scale ------------------
    if norm_type in ("standard", "robust", "log1p_standard"):
        if norm_type == "robust":
            center_attr, scale_attr = "median", "iqr"
            zero_msg = "iqr == 0, returning zeros."
        elif norm_type == "standard":
            center_attr, scale_attr = "vmean", "vstd"
            zero_msg = "vstd == 0, returning zeros."
        else:
            data = torch.log1p(data)
            center_attr, scale_attr = "vmean", "vstd"
            zero_msg = "log_std == 0, returning zeros."
        center = _like(getattr(stats, center_attr), data)
        scale = _like(getattr(stats, scale_attr), data)
        if scale == 0:
            self.logger.warning(zero_msg)
            return torch.zeros_like(data)
        return (data - center) / scale

    self.logger.error(f"Unsupported norm_type '{norm_type}'")
    raise ValueError(f"Unsupported norm_type '{norm_type}'")
[docs]
def normalize_time(self, tindex):
    """
    Look up the normalized time features for one time index.

    Parameters
    ----------
    tindex : int
        Time index.

    Returns
    -------
    dict
        Normalized time features; ``year_norm`` is always present, the
        remaining keys depend on ``self.time_normalization``:

        - "linear": doy_norm, hour_norm
        - "cos_sin": doy_sin, doy_cos, hour_sin, hour_cos

    Raises
    ------
    NotImplementedError
        If ``self.time_normalization`` is neither "linear" nor "cos_sin".
    """
    features = {"year_norm": self.year_norm[tindex]}

    scheme = self.time_normalization
    if scheme == "linear":
        names = ("doy_norm", "hour_norm")
    elif scheme == "cos_sin":
        names = ("doy_sin", "doy_cos", "hour_sin", "hour_cos")
    else:
        raise NotImplementedError(
            f"Time normalization '{self.time_normalization}' not implemented!"
        )

    # Each feature is stored on the instance under the same name as its key
    for name in names:
        features[name] = getattr(self, name)[tindex]
    return features
[docs]
def __len__(self):
    """
    Return the total number of samples in the dataset.

    Returns
    -------
    int
        Number of time batches (``len(self.time_batchs)``) multiplied by
        the number of spatial batches per time batch (``self.sbatch``).
    """
    n_time_batches = len(self.time_batchs)
    return n_time_batches * self.sbatch
[docs]
def __getitem__(self, index):
    """
    Get a single data sample.

    Parameters
    ----------
    index : int
        Flat sample index. The spatial batch index is ``index % sbatch``
        and the time batch index is ``index // sbatch``.

    Returns
    -------
    sample : dict
        Dictionary containing the time features returned by
        ``normalize_time`` plus:

        - "inputs": normalized coarse data, normalized latitude/longitude
          grids and, if configured, constant variables, concatenated along
          the channel axis.
        - "targets": residual (normalized fine minus normalized coarse).
        - "fine": fine-resolution block (not normalized).
        - "coarse": coarse block (not normalized).
        - "corrdinates": dict with "lat" and "lon" coordinate vectors of
          the block (the key spelling is kept as-is for compatibility).

    Raises
    ------
    IndexError
        If the derived time batch index is out of range.
    TypeError
        If the coarse block is neither a NumPy array nor a torch tensor.
    AssertionError
        If slices are not initialized, or if index/shape consistency
        checks fail.

    Notes
    -----
    - ``mode == "validation"``: deterministic blocks from ``eval_slices``.
      For regional run types the block is extracted with ``extract_batch``
      around the slice center; otherwise it is sliced directly.
    - ``run_type == "train_regional"`` (non-validation): deterministic
      blocks from ``train_slices``, extracted with ``extract_batch``.
    - Otherwise: random centers, regenerated whenever the time batch
      index changes.
    - When ``apply_filter`` is True the target is built from the filtered
      fine field.
    """
    if self.debug:
        self.logger.info("------------------- GET ITEM INFO -------------------")

    # Calculate spatial and temporal indices
    sindex = index % self.sbatch
    tbatch_index = index // self.sbatch
    if tbatch_index >= len(self.time_batchs):
        raise IndexError(f"Time batch index {tbatch_index} out of range")
    tindex = self.time_batchs[tbatch_index]
    self.tindex_tracker.append(tindex)
    if self.debug:
        self.logger.info(
            f"\nTorch batch index: {index}\n"
            f"Time block index: {tbatch_index}\n"
            f"Final time index (tindex): {tindex}\n"
            f"Spatial batch index (sindex): {sindex}\n"
        )

    # Load data
    full_data_org = self.loaded_dfs.isel(time=tindex)
    lat = full_data_org.latitude.values.copy()
    lon = full_data_org.longitude.values.copy()
    # Normalize to range [-1, 1] for better neural network input stability
    lat_norm = 2 * ((lat - lat.min()) / (lat.max() - lat.min())) - 1
    lon_norm = 2 * ((lon - lon.min()) / (lon.max() - lon.min())) - 1
    # 2D meshgrids of normalized latitude and longitude (shape: H x W)
    lat_grid, lon_grid = np.meshgrid(lat_norm, lon_norm, indexing="ij")
    sample = self.normalize_time(tindex)

    is_validation = self.mode == "validation"
    # mode == "validation" and run_type == "train_regional" can occur together
    is_regional = self.run_type in ("inference_regional", "train_regional")
    # Only global validation/inference slices the arrays directly; every
    # other path goes through extract_batch around a center.
    direct_slicing = is_validation and not is_regional

    # ------------------------------------------------------------------
    # 1) Choose the block (slice and/or center) for this sample
    # ------------------------------------------------------------------
    if is_validation:
        assert (
            hasattr(self, "eval_slices") and self.eval_slices is not None
        ), "eval_slices not initialized for validation mode"
        assert len(self.eval_slices) > 0, "eval_slices is empty for validation mode"
        # Deterministic spatial sampling for evaluation
        lat_start, lat_end, lon_start, lon_end = self.eval_slices[sindex]
        lat_center = lat_start + self.batch_size_lat // 2
        lon_center = lon_start + self.batch_size_lon // 2
    elif self.run_type == "train_regional":
        assert (
            hasattr(self, "train_slices") and self.train_slices is not None
        ), "train_slices not initialized for train_regional mode"
        assert (
            len(self.train_slices) > 0
        ), "train_slices is empty for train_regional mode"
        # Deterministic spatial sampling for regional training
        lat_start, lat_end, lon_start, lon_end = self.train_slices[sindex]
        lat_center = lat_start + self.batch_size_lat // 2
        lon_center = lon_start + self.batch_size_lon // 2
    else:  # train (global)
        # Random spatial sampling, redrawn once per time batch
        if self.last_tbatch_index != tbatch_index:
            self.random_centers = self.generate_random_batch_centers(self.sbatch)
            self.last_tbatch_index = tbatch_index
        assert (
            self.random_centers[sindex] is not None
        ), f"Random center at index {sindex} has not been generated yet."
        lat_center, lon_center = self.random_centers[sindex]
        if self.debug:
            self.logger.info(f" Training mode: center ({lat_center},{lon_center})")
    self.center_tracker.append((lat_center, lon_center))

    # ------------------------------------------------------------------
    # 2) Extract coordinates, fine block and coarse block
    # ------------------------------------------------------------------
    if direct_slicing:
        # Extract normalized coordinates for the batch
        lat_batch = lat_grid[lat_start:lat_end, lon_start:lon_end]
        lon_batch = lon_grid[lat_start:lat_end, lon_start:lon_end]
        # Channels are placed via index_mapping, exactly like the other
        # paths and like the normalization loop below.
        npfeatures_full = self._stack_variable_fields(full_data_org)
        # Extract the target batch for scaling determination
        fine_block = npfeatures_full[:, lat_start:lat_end, lon_start:lon_end]
        if self.apply_filter:
            # Spatial filtering on the full domain before coarsening
            fine_filtered_full = self.filter_batch(npfeatures_full, fine_block)
            assert fine_filtered_full.shape == npfeatures_full.shape, (
                f"Mismatch in shapes: fine_filtered has shape {fine_filtered_full.shape} "
                f"but npfeatures has shape {npfeatures_full.shape}."
            )
            fine_filtered_block = fine_filtered_full[
                :, lat_start:lat_end, lon_start:lon_end
            ]
            assert fine_filtered_block.shape == fine_block.shape, (
                f"Mismatch in shapes: fine_filtered_block has shape {fine_filtered_block.shape} "
                f"but fine_block has shape {fine_block.shape}."
            )
            # Apply coarsening to the full domain of the filtered HR field
            coarse_full = coarse_down_up(
                fine_filtered_full, npfeatures_full, input_shape=self.in_shape
            )
        else:
            # Apply coarsening directly to the full domain of the raw HR field
            coarse_full = coarse_down_up(
                npfeatures_full, npfeatures_full, input_shape=self.in_shape
            )
        coarse = coarse_full[:, lat_start:lat_end, lon_start:lon_end]
    else:
        # Coordinates and data share the same extraction logic
        lat_batch, lon_batch, lat_indices = self._extract_latlon_batches(
            lat_grid, lon_grid, lat_center, lon_center
        )
        lat_start, lat_end, lon_start, lon_end = lat_indices
        npfeatures_full = self._stack_variable_fields(full_data_org)
        fine_block, fine_filtered_block, coarse, fine_indices = (
            self.build_fine_coarse_blocks(npfeatures_full, lat_center, lon_center)
        )
        assert fine_indices == lat_indices, (
            f"Indices mismatch between data and coordinate extractions:\n"
            f" lat/lon indices: {lat_indices}\n"
            f" data indices: {fine_indices}"
        )

    if is_validation and self.debug:
        self.logger.info(
            f" Evaluation mode: slice lat[{lat_start}:{lat_end}], lon[{lon_start}:{lon_end}]\n"
            f" Full domain shape: {npfeatures_full.shape}\n"
            f" Batch shape: {fine_block.shape}"
        )

    # ------------------------------------------------------------------
    # 3) Convert to tensors and normalize per variable
    # ------------------------------------------------------------------
    fine_block = torch.from_numpy(fine_block).to(self.torch_dtype)
    fine_block_norm = torch.zeros_like(fine_block)
    if isinstance(coarse, np.ndarray):
        coarse = torch.from_numpy(coarse)
    elif not torch.is_tensor(coarse):
        raise TypeError(f"Unexpected type for coarse: {type(coarse)}")
    coarse = coarse.to(self.torch_dtype)
    coarse_norm = torch.zeros_like(coarse)
    if self.apply_filter:
        fine_filtered_block = torch.from_numpy(fine_filtered_block).to(
            self.torch_dtype
        )
    # With filtering enabled, the target is built from the filtered field
    fine_source = fine_filtered_block if self.apply_filter else fine_block

    # Fine and coarse fields use separate statistics; log-based
    # normalizations read the "*_log" entries of norm_mapping.
    for var_name in self.varnames_list:
        iv = self.index_mapping[var_name]
        norm_type = self.normalization_type[var_name]
        suffix = "_log" if norm_type.startswith("log1p") else ""
        stats_fine = self.norm_mapping[f"{var_name}_fine{suffix}"]
        stats_coarse = self.norm_mapping[f"{var_name}_coarse{suffix}"]
        coarse_norm[iv] = self.normalize(
            coarse[iv],
            stats_coarse,
            norm_type,
            var_name=var_name,
            data_type="coarse",
        )
        fine_block_norm[iv] = self.normalize(
            fine_source[iv],
            stats_fine,
            norm_type,
            var_name=var_name,
            data_type="fine",
        )

    # Residual is defined in normalized space. fine_block_norm and
    # coarse_norm are already torch tensors of self.torch_dtype, so the
    # residual needs no further conversion.
    residual = fine_block_norm - coarse_norm
    expected_shape = (
        len(self.varnames_list),
        self.batch_size_lat,
        self.batch_size_lon,
    )
    assert (
        coarse.shape == expected_shape
    ), f"coarse shape {coarse.shape} != expected {expected_shape}"
    assert (
        residual.shape == expected_shape
    ), f"residual shape {residual.shape} != expected {expected_shape}"

    # ------------------------------------------------------------------
    # 4) Assemble the model input
    # ------------------------------------------------------------------
    if isinstance(lat_batch, np.ndarray):
        lat_batch = torch.from_numpy(lat_batch)
    if isinstance(lon_batch, np.ndarray):
        lon_batch = torch.from_numpy(lon_batch)
    lat_batch = lat_batch.to(self.torch_dtype)
    lon_batch = lon_batch.to(self.torch_dtype)
    feature = torch.cat(
        [coarse_norm, lat_batch.unsqueeze(0), lon_batch.unsqueeze(0)], dim=0
    )
    if self.debug:
        self.logger.info(f" Feature composition before constants: {feature.shape}")

    if self.constant_variables is not None:
        assert self.const_vars is not None, (
            f"Constant variables {self.constant_variables} were specified "
            f"but const_vars could not be loaded. Please check the file path and variable names."
        )
        if direct_slicing:
            const_batch = self.const_vars[:, lat_start:lat_end, lon_start:lon_end]
        else:
            const_batch, const_indices = self.extract_batch(
                self.const_vars, lat_center, lon_center
            )
            assert const_indices == lat_indices, (
                f"Indices mismatch for constant variables:\n"
                f" coordinate indices: {lat_indices}\n"
                f" constant var indices: {const_indices}"
            )
        if isinstance(const_batch, np.ndarray):
            const_batch = torch.from_numpy(const_batch)
        const_batch = const_batch.to(self.torch_dtype)
        if self.debug:
            self.logger.info(f" Constant batch shape: {const_batch.shape}")
        feature = torch.cat([feature, const_batch], dim=0)
        if self.debug:
            self.logger.info(
                f" Feature shape after adding constants: {feature.shape}"
            )

    # ------------------------------------------------------------------
    # 5) Physical coordinates of the block
    # ------------------------------------------------------------------
    lat_vals = lat[lat_start:lat_end]
    if is_regional:
        # Recompute the same shift used inside extract_batch so that the
        # window indices address the rolled longitude axis.
        shift = self.W // 2 - lon_center
        lon_vals = np.roll(lon, shift=shift)[lon_start:lon_end]
        # Ensure longitude continuity when the window crosses the dateline
        if np.any(np.diff(lon_vals) < -180):
            lon_vals = np.rad2deg(np.unwrap(np.deg2rad(lon_vals)))
    else:
        # NOTE(review): for random global training the indices also come
        # from extract_batch; confirm they address the un-rolled axis.
        lon_vals = lon[lon_start:lon_end]

    sample.update(
        {
            "inputs": feature,  # model inputs (coarse_norm + coordinates + constants)
            "targets": residual,  # residual in normalized space (model target)
            "fine": fine_block,  # fine-resolution physical data (for diagnostics)
            "coarse": coarse,  # coarse physical data (baseline for diagnostics)
            # Key spelling kept as-is: consumers may read "corrdinates"
            "corrdinates": {
                "lat": torch.from_numpy(lat_vals).to(self.torch_dtype),
                "lon": torch.from_numpy(lon_vals).to(self.torch_dtype),
            },
        }
    )
    if self.debug:
        self.logger.info("------------------------------------------------------")
    return sample


def _stack_variable_fields(self, full_data_org):
    """Stack every variable of one time step into a (C, H, W) array, placing
    each variable at the channel given by ``self.index_mapping``."""
    stacked = np.zeros([len(self.varnames_list), self.H, self.W])
    for var_name in self.varnames_list:
        stacked[self.index_mapping[var_name], :, :] = full_data_org[var_name].values
    return stacked


def _extract_latlon_batches(self, lat_grid, lon_grid, lat_center, lon_center):
    """Extract the latitude and longitude grids around one center with
    ``extract_batch`` and check that both return the same indices."""
    lat_batch, lat_indices = self.extract_batch(lat_grid, lat_center, lon_center)
    lon_batch, lon_indices = self.extract_batch(lon_grid, lat_center, lon_center)
    assert lat_indices == lon_indices, (
        f"Indices mismatch for same center (lat_center={lat_center}, lon_center={lon_center}):\n"
        f" lat_indices: {lat_indices}\n"
        f" lon_indices: {lon_indices}"
    )
    return lat_batch, lon_batch, lat_indices