Skip to content

Commit

Permalink
read xarray in cache dir for caravan
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Oct 19, 2023
1 parent f0b8657 commit c0dbab6
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions hydrodataset/caravan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import glob
import re
import warnings
from tqdm import tqdm
Expand Down Expand Up @@ -774,7 +775,7 @@ def _check_data(self, regions):
def read_attr_xrdataset(self, gage_id_lst=None, var_lst=None, **kwargs):
# Define the path to the attributes file
file_path = os.path.join(
self.data_source_description["ATTR_DIR"],
CACHE_DIR,
"caravan_attributes.nc",
)

Expand All @@ -791,22 +792,17 @@ def read_attr_xrdataset(self, gage_id_lst=None, var_lst=None, **kwargs):
return ds

def read_ts_xrdataset(self, gage_id_lst, t_range, var_lst, **kwargs):
# TS_dir is same as flow_dir/forcing_dir
base_path = self.data_source_description["FLOW_DIR"]
file_paths = []

# Generate list of file paths based on sites_id
for site in gage_id_lst:
# Split the site string to get the region and site_code
region, site_code = site.split("_")

file_path = os.path.join(base_path, region, f"{region}_{site_code}.nc")
file_paths.append(file_path)
file_paths = sorted(
glob.glob(os.path.join(CACHE_DIR, "*caravan*timeseries*.nc"))
)

# Open the dataset in a lazy manner using dask
parallel = kwargs.get("parallel", False)
combined_ds = xr.open_mfdataset(
file_paths, combine="nested", concat_dim="basin", parallel=parallel
file_paths,
combine="nested",
concat_dim="gauge_id",
parallel=parallel,
)

def extract_unit(variable_name, units_string):
Expand All @@ -829,6 +825,8 @@ def extract_unit(variable_name, units_string):
combined_ds = combined_ds[var_lst]
if t_range:
combined_ds = combined_ds.sel(date=slice(*t_range))
if gage_id_lst:
combined_ds = combined_ds.sel(gauge_id=gage_id_lst)

# some units are not recognized by pint_xarray, hence we manually set them
unit_mapping = {"W/m2": "watt / meter ** 2", "m3/m3": "meter^3/meter^3"}
Expand All @@ -840,9 +838,8 @@ def extract_unit(variable_name, units_string):
unit = unit_mapping.get(unit, unit)
combined_ds[var].attrs["units"] = unit

# Assign basin names as coordinate
combined_ds = combined_ds.assign_coords(basin=gage_id_lst)
combined_ds = combined_ds.rename({"date": "time"})
combined_ds = combined_ds.rename({"gauge_id": "basin"})
return combined_ds

@property
Expand Down

0 comments on commit c0dbab6

Please sign in to comment.