# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
import random
import re
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import xarray as xr
from torch.utils.data import Dataset
class DataPreprocessor(Dataset):
    """
    Dataset class for preprocessing LSM (Land Surface Model) data.

    This class handles loading and preprocessing of NetCDF files containing
    climate data, with support for multiple years, spatial and temporal
    batching, and various normalization techniques.

    Files are grouped by the four-digit year at the end of their name
    (``*_YYYY.nc``); every year must provide the same number of
    spatial-batch files. Each year's time axis is split into blocks of
    ``tbatch`` steps starting at ``stime``; one sample is one
    (temporal block, spatial batch) pair.

    Parameters
    ----------
    logger : object
        Logger instance for logging messages (only ``info`` is called).
    dfs : List[str]
        List of file paths to NetCDF files.
    stime : int
        Start time index, applied as an offset inside every year's file.
    tbatch : int
        Temporal batch size (time steps per temporal block).
    training : bool, optional
        If True, draw a random subset of spatial batches per time index and
        a random time offset inside each block. Default is True.
    sblock_perc : float, optional
        Fraction of spatial batches used per time index in training mode.
        Default is 0.6.
    norm_mapping : Dict, optional
        Dictionary containing normalization statistics for each variable.
        Default is an empty dict.
    normalization_type : Dict, optional
        Dictionary specifying normalization type for each variable.
        Default is an empty dict (every variable uses ``log1p_minmax``).
    debug : bool, optional
        If True, log detailed per-sample information. Default is False.

    Attributes
    ----------
    logger : object
        Logger instance.
    stime : int
        Start time index within each year's file.
    tstep : int
        Time steps per file (minimum ``time`` size across all files).
    tbatch : int
        Temporal batch size.
    training : bool
        Whether random spatial subsets / random time offsets are used.
    norm_mapping : Dict
        Normalization statistics.
    normalization_type : Dict
        Normalization types per variable.
    total_sbatch : int
        Number of spatial-batch files per year.
    sbatch : int
        Number of spatial batches served per temporal block.
    years : List[int]
        Sorted list of years in the dataset.
    year_to_index : Dict[int, int]
        Position of each year in ``years``.
    etime : int
        ``tstep * len(years)``.
    dfs : List[Tuple[int, int, str]]
        List of (year, spatial_index, file_path) tuples, ordered by year and
        then by sorted file path.
    time_blocks : np.ndarray
        Indices of all temporal blocks, in order (not shuffled).
    min_dims : Dict[str, int]
        Minimum size of ``time`` and ``dim_1`` ... ``dim_4`` across files.
    cosz : List[str]
        Cosine of solar zenith angle variable names.
    lai : List[str]
        Leaf area index variable names.
    ssa : List[str]
        Single scattering albedo variable names.
    rs : List[str]
        Surface reflectance variable names.
    ov : List[str]
        Output variable names.
    tindex_tracker, sindex_tracker : List[int]
        Temporal block and spatial index of every sample served so far.

    Examples
    --------
    >>> from rtnn.logger import Logger  # doctest: +SKIP
    >>> dataset = DataPreprocessor(  # doctest: +SKIP
    ...     logger=Logger(),
    ...     dfs=["data_1995.nc", "data_1996.nc"],
    ...     stime=0,
    ...     tbatch=24,
    ... )
    >>> features, targets = dataset[0]  # doctest: +SKIP

    ``features`` has shape ``(schunk, 6, seq_length)`` (1 cosz + 2 lai +
    2 ssa + 1 rs channels) and ``targets`` has shape
    ``(schunk, 4, seq_length)``, where ``schunk = dim_1 * dim_3 * dim_4``
    and ``seq_length = dim_2`` (minimum sizes across files).
    """

    def __init__(
        self,
        logger: Any,
        dfs: List[str],
        stime: int,
        tbatch: int,
        training: bool = True,
        sblock_perc: float = 0.6,
        norm_mapping: Optional[Dict] = None,
        normalization_type: Optional[Dict] = None,
        debug: bool = False,
    ) -> None:
        """
        Initialize the DataPreprocessor.

        Parameters
        ----------
        logger : Any
            Logger instance for logging messages.
        dfs : List[str]
            List of file paths to NetCDF files named ``*_YYYY.nc``.
        stime : int
            Start time index inside each year's file.
        tbatch : int
            Temporal batch size.
        training : bool, optional
            If True, serve ``sblock_perc`` of the spatial batches (randomly
            chosen per time index). If False, serve all of them in order.
        sblock_perc : float, optional
            Fraction of spatial batches used in training mode. The result is
            clamped to ``[1, total_sbatch]``.
        norm_mapping : Dict, optional
            Dictionary containing normalization statistics for each variable.
        normalization_type : Dict, optional
            Dictionary specifying normalization type for each variable.
        debug : bool, optional
            If True, log debug information.

        Raises
        ------
        ValueError
            If no file matches the ``_YYYY.nc`` pattern, if years have
            different numbers of files, or if a required dimension is absent
            from every file.
        """
        self.logger = logger
        self.stime = stime
        self.tbatch = tbatch
        self.training = training
        # Fresh dicts per instance instead of shared mutable defaults.
        self.norm_mapping = {} if norm_mapping is None else norm_mapping
        self.normalization_type = (
            {} if normalization_type is None else normalization_type
        )
        self.debug = debug
        self.sblock_perc = sblock_perc

        # Group files by the four-digit year at the end of the file name.
        year_pattern = re.compile(r"_(\d{4})\.nc$")
        self.train_sbatch_files_by_year = defaultdict(list)
        for f in dfs:
            match = year_pattern.search(f)
            if match:
                self.train_sbatch_files_by_year[int(match.group(1))].append(f)
            else:
                self.logger.info(f"Skipping file without '_YYYY.nc' suffix: {f}")
        if not self.train_sbatch_files_by_year:
            raise ValueError(
                "No input file matches the '_YYYY.nc' naming pattern; "
                "cannot build the dataset."
            )

        # Every year must have the same number of spatial-batch files: file
        # lookup in __getitem__ strides through self.dfs by that count.
        files_per_year = {
            year: len(files)
            for year, files in self.train_sbatch_files_by_year.items()
        }
        if len(set(files_per_year.values())) != 1:
            raise ValueError(
                "Every year must have the same number of spatial-batch files; "
                f"got {files_per_year}"
            )
        first_key = next(iter(self.train_sbatch_files_by_year))
        self.total_sbatch = files_per_year[first_key]

        if self.training:
            # Random subset per time index: at least one batch, and never more
            # than exist (random.sample would raise otherwise).
            self.sbatch = min(
                self.total_sbatch,
                max(1, int(self.total_sbatch * self.sblock_perc)),
            )
            # Tracking for the random spatial mapping used in __getitem__.
            self.last_tindex = -1
            self.current_spatial_mapping = None
        else:
            # Validation/Testing: use 100% of spatial batches.
            self.sbatch = self.total_sbatch

        self.years = sorted(self.train_sbatch_files_by_year)
        self.year_to_index = {y: i for i, y in enumerate(self.years)}
        # (year, spatial_index, path) for all files, year-major.
        self.dfs = [
            (year, sindex, path)
            for year in self.years
            for sindex, path in enumerate(sorted(self.train_sbatch_files_by_year[year]))
        ]

        # Find minimum dimensions across all files.
        self.min_dims = {
            "time": np.inf,
            "dim_1": np.inf,
            "dim_2": np.inf,
            "dim_3": np.inf,
            "dim_4": np.inf,
        }
        for _, _, file_path in self.dfs:
            with xr.open_dataset(file_path) as ds:
                for dim in self.min_dims:
                    if dim in ds.sizes:
                        self.min_dims[dim] = min(self.min_dims[dim], ds.sizes[dim])
        for dim, size in self.min_dims.items():
            self.logger.info(f"Minimum {dim} across files: {size}")
        missing = [dim for dim, size in self.min_dims.items() if size == np.inf]
        if missing:
            # __getitem__ needs every one of these sizes to allocate arrays.
            raise ValueError(f"Dimension(s) {missing} not found in any input file")

        self.tstep = self.min_dims["time"]
        self.etime = self.tstep * len(self.years)
        blocks_per_year = self._blocks_per_year()
        if blocks_per_year <= 0:
            self.logger.info(
                f"tstep - stime ({self.tstep} - {self.stime}) is smaller than "
                f"tbatch ({self.tbatch}); the dataset is empty."
            )
        self.time_blocks = np.arange(max(0, blocks_per_year) * len(self.years))

        # Define variable groups.
        self.cosz = ["coszang"]  # Cosine of solar zenith angle
        self.lai = ["laieff_collim", "laieff_isotrop"]  # Leaf area index
        self.ssa = ["leaf_ssa", "leaf_psd"]  # Single scattering albedo
        self.rs = ["rs_surface_emu"]  # Surface reflectance
        self.ov = [
            "collim_alb",
            "collim_tran",
            "isotrop_alb",
            "isotrop_tran",
        ]  # Output variables
        self.sindex_tracker = []  # Spatial index of every sample served
        self.tindex_tracker = []  # Temporal block of every sample served

        self.logger.info(f"Time range: {self.stime} ... {self.etime}")
        self.logger.info(f"Time steps per file: {self.tstep}")
        self.logger.info(f"Temporal batch size: {self.tbatch}")
        self.logger.info(f"Spatial batche size: {self.sbatch}")
        self.logger.info(f"Time blocks: {self.time_blocks}")
        self.logger.info(f"Years: {self.years}")
        self.logger.info(f"Year to index: {self.year_to_index}")
        self.logger.info(
            f"Variable groups: {self.cosz}, {self.lai}, {self.ssa}, {self.rs}, {self.ov}"
        )
        self.logger.info(
            "The list of file info:\n"
            + "\n".join(f"{year}, {sindex}, {path}" for year, sindex, path in self.dfs)
        )
        # NOTE: this reseeds the process-wide `random` module (used by
        # _get_random_spatial_mapping), not a per-instance generator.
        random.seed(42)

    def _blocks_per_year(self) -> int:
        """Return the number of whole ``tbatch`` blocks in ``[stime, tstep)``.

        ``stime`` is applied inside every year's file, so this is the single
        blocking used by ``__len__`` and ``__getitem__``.
        """
        return (self.tstep - self.stime) // self.tbatch

    def _get_random_spatial_mapping(self) -> List[int]:
        """
        Generate a random spatial mapping for training.

        Returns
        -------
        List[int]
            ``self.sbatch`` distinct spatial-batch indices drawn from
            ``range(self.total_sbatch)`` with the global ``random`` module.
        """
        return random.sample(range(self.total_sbatch), self.sbatch)

    def normalize(self, data: np.ndarray, var_name: str) -> np.ndarray:
        """
        Normalize data using the specified normalization method.

        Parameters
        ----------
        data : np.ndarray
            Input data array to normalize.
        var_name : str
            Name of the variable for which to retrieve normalization statistics.

        Returns
        -------
        np.ndarray
            Normalized data array.

        Raises
        ------
        KeyError
            If ``var_name`` (or a required statistic) is missing from
            ``norm_mapping``.
        ValueError
            If the normalization type is not supported.

        Notes
        -----
        The type defaults to ``log1p_minmax`` when ``var_name`` is absent from
        ``normalization_type``. Supported normalization types:

        - minmax: (x - min) / (max - min)
        - standard: (x - mean) / std
        - robust: (x - median) / IQR
        - log1p_minmax: log1p(x) normalized
        - log1p_standard: log1p(x) standardized
        - log1p_robust: log1p(x) robust normalized
        - sqrt_minmax: sqrt(x) normalized (negatives clipped to 0 first)
        - sqrt_standard: sqrt(x) standardized (negatives clipped to 0 first)
        - sqrt_robust: sqrt(x) robust normalized (negatives clipped to 0 first)
        """
        norm_type = self.normalization_type.get(var_name, "log1p_minmax")
        stats = self.norm_mapping[var_name]
        if self.debug:
            self.logger.info(
                f"Normalizing variable '{var_name}' using method '{norm_type}' with stats: {stats}"
            )
        if norm_type == "minmax":
            vmin = stats["vmin"]
            vmax = stats["vmax"]
            return (data - vmin) / (vmax - vmin)
        elif norm_type == "standard":
            mean = stats["vmean"]
            std = stats["vstd"]
            return (data - mean) / std
        elif norm_type == "robust":
            median = stats["median"]
            iqr = stats["iqr"]
            return (data - median) / iqr
        elif norm_type == "log1p_minmax":
            data = np.log1p(data)
            log_min = stats["log_min"]
            log_max = stats["log_max"]
            return (data - log_min) / (log_max - log_min)
        elif norm_type == "log1p_standard":
            data = np.log1p(data)
            mean = stats["log_mean"]
            std = stats["log_std"]
            return (data - mean) / std
        elif norm_type == "log1p_robust":
            data = np.log1p(data)
            median = stats["log_median"]
            iqr = stats["log_iqr"]
            return (data - median) / iqr
        elif norm_type == "sqrt_minmax":
            data = np.sqrt(np.clip(data, a_min=0, a_max=None))
            sqrt_min = stats["sqrt_min"]
            sqrt_max = stats["sqrt_max"]
            return (data - sqrt_min) / (sqrt_max - sqrt_min)
        elif norm_type == "sqrt_standard":
            data = np.sqrt(np.clip(data, a_min=0, a_max=None))
            mean = stats["sqrt_mean"]
            std = stats["sqrt_std"]
            return (data - mean) / std
        elif norm_type == "sqrt_robust":
            data = np.sqrt(np.clip(data, a_min=0, a_max=None))
            median = stats["sqrt_median"]
            iqr = stats["sqrt_iqr"]
            return (data - median) / iqr
        else:
            raise ValueError(
                f"Unsupported normalization type '{norm_type}' for variable '{var_name}'"
            )

    def __len__(self) -> int:
        """
        Get the total number of samples in the dataset.

        Returns
        -------
        int
            ``blocks_per_year * len(years) * sbatch`` (0 if a year is shorter
            than one temporal block).
        """
        return max(0, self._blocks_per_year()) * len(self.years) * self.sbatch

    def _locate(self, index: int) -> Tuple[int, str]:
        """Map a flat sample index to ``(time_index, file_path)``.

        Side effects: in training mode it may draw a new random spatial
        mapping and always draws a random in-block time offset; it appends
        to ``tindex_tracker`` and ``sindex_tracker``.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        position = index + n if index < 0 else index
        if not 0 <= position < n:
            raise IndexError(f"Index {index} out of range [0, {n})")

        # Index layout: spatial batch varies fastest, then temporal block.
        index_spatial_mapping = position % self.sbatch
        tblock = position // self.sbatch
        # n > 0 here, so _blocks_per_year() > 0.
        year_index, local_tblock = divmod(tblock, self._blocks_per_year())
        if year_index >= len(self.years):
            raise IndexError(
                f"Year index {year_index} out of range [0, {len(self.years)})"
            )

        tindex = local_tblock * self.tbatch + self.stime
        if self.training:
            # Regenerate the spatial mapping whenever the block start changes.
            if self.last_tindex != tindex:
                self.current_spatial_mapping = self._get_random_spatial_mapping()
                self.last_tindex = tindex
                if self.debug:
                    self.logger.info(
                        f"New spatial mapping for tindex {tindex}: {self.current_spatial_mapping}"
                    )
            # Map the spatial index to an actual processor rank.
            sindex = self.current_spatial_mapping[index_spatial_mapping]
            # Random offset inside the block; stays below tstep because
            # blocks_per_year = (tstep - stime) // tbatch.
            tindex += np.random.randint(self.tbatch)
        else:
            # For validation/testing: direct mapping.
            sindex = index_spatial_mapping

        self.tindex_tracker.append(tblock)
        self.sindex_tracker.append(sindex)

        # self.dfs is year-major with total_sbatch (not sbatch) files per year.
        _, _, path = self.dfs[year_index * self.total_sbatch + sindex]
        if self.debug:
            self.logger.info("------------------- GET ITEM INFO -------------------")
            self.logger.info(
                f"\nTorch batch index: {index}\n"
                f"Spatial index before mapping: {index_spatial_mapping}, and Spatial index after mapping: {sindex}\n"
                f"Temporal block index: {tblock}\n"
                f"Year index: {year_index}\n"
                f"Local time block: {local_tblock}\n"
                f"Time index: {tindex}\n"
                f"Loading file: {path}"
            )
        return tindex, path

    def _load_sample(self, path: str, tindex: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Read time step ``tindex`` of ``path`` and build (features, targets).

        The file is opened only for the duration of the read. Also sets
        ``self.schunk = dim_1 * dim_3 * dim_4``.
        """
        sequence_length_dim = self.min_dims["dim_2"]
        dim_1 = self.min_dims["dim_1"]
        dim_3 = self.min_dims["dim_3"]
        dim_4 = self.min_dims["dim_4"]
        self.schunk = dim_1 * dim_3 * dim_4

        # One (schunk, n_vars, seq) array per variable group.
        npcosz = np.zeros([self.schunk, len(self.cosz), sequence_length_dim])
        nplai = np.zeros([self.schunk, len(self.lai), sequence_length_dim])
        npssa = np.zeros([self.schunk, len(self.ssa), sequence_length_dim])
        npov = np.zeros([self.schunk, len(self.ov), sequence_length_dim])
        nprs = np.zeros([self.schunk, len(self.rs), sequence_length_dim])
        if self.debug:
            self.logger.info(
                f"Dimensions for processing:\n"
                f" |- sequence_length_dim: {sequence_length_dim}\n"
                f" |- dim_1: {dim_1}\n"
                f" |- dim_3: {dim_3}\n"
                f" |- dim_4: {dim_4}\n"
                f" |- schunk (total spatial chunk size): {self.schunk}"
            )
            self.logger.info(
                f"Initialized numpy arrays for variable groups with shapes:\n"
                f" |- npcosz: {npcosz.shape}\n"
                f" |- nplai: {nplai.shape}\n"
                f" |- npssa: {npssa.shape}\n"
                f" |- npov: {npov.shape}\n"
                f" |- nprs: {nprs.shape}"
            )

        with xr.open_dataset(path) as ds:
            # NOTE(review): np.tile below repeats whole vectors, so for cosz
            # row r takes element r % dim_1, and LAI rows put dim_4 slowest.
            # The output rows are flattened with dim_1 slowest (assuming the
            # outputs are stored as (dim_1, dim_2, dim_3, dim_4)). If dim_1 > 1
            # these orderings may not line up — confirm against the file layout.

            # Cosine of solar zenith angle: broadcast over space and sequence.
            for variable_index, variable_name in enumerate(self.cosz):
                temp = ds[variable_name].isel(time=tindex, dim_1=slice(0, dim_1)).values
                temp = self.normalize(temp, variable_name)
                temp = np.tile(temp, dim_3 * dim_4)
                npcosz[:, variable_index, :] = np.tile(
                    temp[:, np.newaxis], (1, sequence_length_dim)
                )

            # Leaf area index: already carries the sequence axis.
            for variable_index, variable_name in enumerate(self.lai):
                temp = ds[variable_name].isel(time=tindex, dim_1=slice(0, dim_1)).values
                temp = self.normalize(temp, variable_name)
                temp = temp.transpose(0, 2, 1).reshape(dim_3 * dim_1, sequence_length_dim)
                nplai[:, variable_index, :] = np.tile(temp, (dim_4, 1))

            # Single scattering albedo: not sliced on dim_1; repeated dim_1 times.
            for variable_index, variable_name in enumerate(self.ssa):
                temp = ds[variable_name].isel(time=tindex).values
                temp = self.normalize(temp, variable_name).reshape(-1, 1)
                temp = np.tile(temp, (dim_1, 1))
                npssa[:, variable_index, :] = np.tile(temp, (1, sequence_length_dim))

            # Surface reflectance: one value per spatial row.
            for variable_index, variable_name in enumerate(self.rs):
                temp = ds[variable_name].isel(time=tindex, dim_1=slice(0, dim_1)).values
                temp = self.normalize(temp, variable_name).reshape(-1, 1)
                nprs[:, variable_index, :] = np.tile(temp, (1, sequence_length_dim))

            # Output variables: move axis 1 (sequence) last, flatten the rest.
            for variable_index, variable_name in enumerate(self.ov):
                temp = ds[variable_name].isel(time=tindex, dim_1=slice(0, dim_1)).values
                temp = self.normalize(temp, variable_name)
                npov[:, variable_index, :] = temp.transpose(0, 2, 3, 1).reshape(
                    -1, sequence_length_dim
                )

        tcosz = torch.tensor(npcosz, dtype=torch.float32)
        tlai = torch.tensor(nplai, dtype=torch.float32)
        tssa = torch.tensor(npssa, dtype=torch.float32)
        trs = torch.tensor(nprs, dtype=torch.float32)
        tov = torch.tensor(npov, dtype=torch.float32)
        # Features are concatenated along the channel axis.
        feature = torch.cat([tcosz, tlai, tssa, trs], dim=1)
        return (feature, tov)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a sample from the dataset.

        Parameters
        ----------
        index : int
            Index of the sample to retrieve (negative values count from the end).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - features: float32 tensor of shape (schunk, 6, seq_length)
            - targets: float32 tensor of shape (schunk, 4, seq_length)

        Raises
        ------
        IndexError
            If ``index`` is out of range.
        """
        tindex, path = self._locate(index)
        return self._load_sample(path, tindex)