Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/anemoi/datasets/create/sources/xarray_support/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,21 @@ def __init__(self, owner: Any, selection: Any) -> None:
# Copy the metadata from the owner
self._md = owner._metadata.copy()

aliases = {}
for coord_name, coord_value in self.selection.coords.items():
if is_scalar(coord_value):
# Extract the single value from the scalar dimension
# and store it in the metadata
coordinate = owner.by_name[coord_name]
self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value))
normalised = coordinate.normalise(extract_single_value(coord_value))
self._md[coord_name] = normalised
for alias in coordinate.mars_names:
aliases[alias] = normalised

# Add metadata aliases (e.g. levelist == level) only if they are not already present
for alias, value in aliases.items():
if alias not in self._md:
self._md[alias] = value

# By now, the only dimensions should be latitude and longitude
self._shape = tuple(list(self.selection.shape)[-2:])
Expand Down Expand Up @@ -180,7 +189,7 @@ def forecast_reference_time(self) -> datetime.datetime:

def __repr__(self) -> str:
"""Return a string representation of the field."""
return repr(self._metadata)
return f"XArrayField({self._metadata})"

def _values(self, dtype: type | None = None) -> Any:
"""Return the values of the selection.
Expand Down
123 changes: 8 additions & 115 deletions src/anemoi/datasets/create/sources/xarray_support/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,87 +21,6 @@
LOG = logging.getLogger(__name__)


class _MDMapping:
"""A class to handle metadata mapping for variables.

Attributes
----------
variable : Any
The variable to map.
time : Any
The time associated with the variable.
mapping : Dict[str, str]
A dictionary mapping keys to variable names.
"""

def __init__(self, variable: Any) -> None:
"""Initialize the _MDMapping class.

Parameters
----------
variable : Any
The variable to map.
"""
self.variable = variable
self.time = variable.time
self.mapping = dict()
# Aliases

def _from_user(self, key: str) -> str:
"""Get the internal key corresponding to a user-provided key.

Parameters
----------
key : str
The user-provided key.

Returns
-------
str
The internal key corresponding to the user-provided key.
"""
return self.mapping.get(key, key)

def from_user(self, kwargs: dict[str, Any]) -> dict[str, Any]:
"""Convert user-provided keys to internal keys.

Parameters
----------
kwargs : Dict[str, Any]
A dictionary of user-provided keys and values.

Returns
-------
Dict[str, Any]
A dictionary with internal keys and original values.
"""
return {self._from_user(k): v for k, v in kwargs.items()}

def __repr__(self) -> str:
"""Return a string representation of the _MDMapping object.

Returns
-------
str
String representation of the _MDMapping object.
"""
return f"MDMapping({self.mapping})"

def fill_time_metadata(self, field: Any, md: dict[str, Any]) -> None:
"""Fill the time metadata for a field.

Parameters
----------
field : Any
The field to fill metadata for.
md : Dict[str, Any]
The metadata dictionary to update.
"""
valid_datetime = self.variable.time.fill_time_metadata(field._md, md)
if valid_datetime is not None:
md["valid_datetime"] = as_datetime(valid_datetime).isoformat()


class XArrayMetadata(RawMetadata):
"""A class to handle metadata for XArray fields.

Expand All @@ -127,10 +46,16 @@ def __init__(self, field: Any) -> None:
field : Any
The field to extract metadata from.
"""
from .field import XArrayField

assert isinstance(field, XArrayField), type(field)
self._field = field
md = field._md.copy()
self._mapping = _MDMapping(field.owner)
self._mapping.fill_time_metadata(field, md)

valid_datetime = field.owner.time.fill_time_metadata(field._md, md)
if valid_datetime is not None:
md["valid_datetime"] = as_datetime(valid_datetime).isoformat()

super().__init__(md)

@cached_property
Expand Down Expand Up @@ -190,38 +115,6 @@ def _valid_datetime(self) -> datetime.datetime | None:
"""
return self._get("valid_datetime")

def get(self, key: str, astype: type | None = None, **kwargs: Any) -> Any:
"""Get a metadata value by key.

Parameters
----------
key : str
The key to get the value for.
astype : Optional[type]
The type to cast the value to.
**kwargs : Any
Additional keyword arguments.

Returns
-------
Any
The value for the specified key, optionally cast to the specified type.
"""

if key == "levelist":
# Special case for levelist, for compatibility with GRIB
if key not in self._d and "level" in self._d:
key = "level"

if key in self._d:
if astype is not None:
return astype(self._d[key])
return self._d[key]

key = self._mapping._from_user(key)

return super().get(key, astype=astype, **kwargs)


class XArrayFieldGeography(Geography):
"""A class to handle geography information for XArray fields.
Expand Down
Loading