Skip to content

refac: more control on the output of dicts from Model.as_dict() and experiment.as_dict() #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 29 additions & 33 deletions floatcsep/experiment.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import datetime
import json
import logging
import os
import shutil
import warnings
from os.path import join, abspath, relpath, dirname, isfile, split, exists
from typing import Union, List, Dict, Callable, Sequence
from typing import Union, List, Dict, Sequence

import csep
import numpy
import yaml
from cartopy import crs as ccrs
Expand All @@ -23,13 +21,11 @@
from floatcsep.repository import ResultsRepository, CatalogRepository
from floatcsep.utils import (
NoAliasLoader,
parse_csep_func,
read_time_cfg,
read_region_cfg,
Task,
TaskGraph,
timewindow2str,
str2timewindow,
magnitude_vs_time,
parse_nested_dicts,
)
Expand Down Expand Up @@ -184,7 +180,7 @@ def __init__(
self.postproc_config = postproc_config if postproc_config else {}
self.default_test_kwargs = default_test_kwargs

self.catalog_repo.set_catalog(catalog, self.time_config, self.region_config)
self.catalog_repo.set_main_catalog(catalog, self.time_config, self.region_config)

self.models = self.set_models(
models or kwargs.get("model_config"), kwargs.get("order", None)
Expand Down Expand Up @@ -717,7 +713,7 @@ def make_repr(self):

# Copy the region file into the results folder, if the region is given as a file
region_path = self.region_config.get("path", False)
if region_path:
if isinstance(region_path, str):
if isfile(region_path) and region_path:
new_path = join(self.registry.rundir, self.region_config["path"])
shutil.copy2(region_path, new_path)
Expand All @@ -726,10 +722,10 @@ def make_repr(self):

# Copy the catalog file into the results folder
target_cat = join(
self.registry.workdir, self.registry.rundir, split(self.catalog_repo._catpath)[-1]
self.registry.workdir, self.registry.rundir, split(self.catalog_repo.cat_path)[-1]
)
if not exists(target_cat):
shutil.copy2(self.registry.abs(self.catalog_repo._catpath), target_cat)
shutil.copy2(self.registry.abs(self.catalog_repo.cat_path), target_cat)
self._catpath = self.registry.rel(target_cat)

relative_path = os.path.relpath(
Expand All @@ -738,41 +734,41 @@ def make_repr(self):
self.registry.workdir = relative_path
self.to_yml(repr_config, extended=True)

def as_dict(
self,
exclude: Sequence = (
"magnitudes",
"depths",
"timewindows",
"filetree",
"task_graph",
"tasks",
"models",
"tests",
"results_repo",
"catalog_repo",
),
extended: bool = False,
) -> dict:
def as_dict(self, extra: Sequence = (), extended=False) -> dict:
"""
Converts an Experiment instance into a dictionary.

Args:
exclude (tuple, list): Attributes, or attribute keys, to ignore
extended (bool): Verbose representation of pycsep objects
extra: additional key-value pairs merged into the resulting dictionary (passed to dict.update).
extended: if True, also include the explicit "timewindows", "magnitudes" and "depths" entries

Returns:
A dictionary with serialized instance's attributes, which are
floatCSEP readable
"""

listwalk = [(i, j) for i, j in self.__dict__.items() if not i.startswith("_") and j]
listwalk.insert(6, ("catalog", self.catalog_repo._catpath))

dictwalk = {i: j for i, j in listwalk}
dictwalk["path"] = dictwalk.pop("registry").workdir
dict_walk = {
"name": self.name,
"config_file": self.config_file,
"path": self.registry.workdir,
"run_dir": self.registry.rundir,
"time_config": {
i: j
for i, j in self.time_config.items()
if (i not in ("timewindows",) or extended)
},
"region_config": {
i: j
for i, j in self.region_config.items()
if (i not in ("magnitudes", "depths") or extended)
},
"catalog": self.catalog_repo.cat_path,
"models": [i.as_dict() for i in self.models],
"tests": [i.as_dict() for i in self.tests],
}
dict_walk.update(extra)

return parse_nested_dicts(dictwalk, excluded=exclude, extended=extended)
return parse_nested_dicts(dict_walk)

def to_yml(self, filename: str, **kwargs) -> None:
"""
Expand Down
7 changes: 3 additions & 4 deletions floatcsep/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import os
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Callable, Union, Mapping, Sequence
from typing import List, Callable, Union, Sequence

import git
import numpy
from csep.core.forecasts import GriddedForecast, CatalogForecast

from floatcsep.accessors import from_zenodo, from_git
Expand Down Expand Up @@ -133,10 +132,10 @@ def as_dict(self, excluded=("name", "repository", "workdir")):
(i, j) for i, j in sorted(self.__dict__.items()) if not i.startswith("_") and j
]

dict_walk = {i: j for i, j in list_walk}
dict_walk = {i: j for i, j in list_walk if i not in excluded}
dict_walk["path"] = dict_walk.pop("registry").path

return {self.name: parse_nested_dicts(dict_walk, excluded=excluded)}
return {self.name: parse_nested_dicts(dict_walk)}

@classmethod
def from_dict(cls, record: dict, **kwargs):
Expand Down
25 changes: 13 additions & 12 deletions floatcsep/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ class ResultsRepository:

def __init__(self, registry: ExperimentRegistry):
self.registry = registry
self.a = 1

def _load_result(
self,
Expand Down Expand Up @@ -236,6 +235,8 @@ def default(self, obj):
class CatalogRepository:

def __init__(self, registry: ExperimentRegistry):
self.cat_path = None
self._catalog = None
self.registry = registry
self.time_config = {}
self.region_config = {}
Expand Down Expand Up @@ -270,7 +271,7 @@ def __getattr__(self, item: str) -> object:
def as_dict(self):
return

def set_catalog(
def set_main_catalog(
self, catalog: Union[str, Callable, CSEPCatalog], time_config: dict, region_config: dict
):
"""
Expand All @@ -291,11 +292,11 @@ def catalog(self) -> CSEPCatalog:
Returns a CSEP catalog loaded from the given query function or a stored file if it
exists.
"""
cat_path = self.registry.abs(self._catpath)
cat_path = self.registry.abs(self.cat_path)

if callable(self._catalog):
if isfile(self._catpath):
return CSEPCatalog.load_json(self._catpath)
if isfile(self.cat_path):
return CSEPCatalog.load_json(self.cat_path)
bounds = {
"start_time": min([item for sublist in self.timewindows for item in sublist]),
"end_time": max([item for sublist in self.timewindows for item in sublist]),
Expand All @@ -318,7 +319,7 @@ def catalog(self) -> CSEPCatalog:
if self.region:
catalog.filter_spatial(region=self.region, in_place=True)
catalog.region = None
catalog.write_json(self._catpath)
catalog.write_json(self.cat_path)

return catalog

Expand All @@ -333,19 +334,19 @@ def catalog(self, cat: Union[Callable, CSEPCatalog, str]) -> None:

if cat is None:
self._catalog = None
self._catpath = None
self.cat_path = None

elif isfile(self.registry.abs(cat)):
log.info(f"\tCatalog: '{cat}'")
self._catalog = self.registry.rel(cat)
self._catpath = self.registry.rel(cat)
self.cat_path = self.registry.rel(cat)

else:
# catalog can be a function
self._catalog = parse_csep_func(cat)
self._catpath = self.registry.abs("catalog.json")
if isfile(self._catpath):
log.info(f"\tCatalog: stored " f"'{self._catpath}' " f"from '{cat}'")
self.cat_path = self.registry.abs("catalog.json")
if isfile(self.cat_path):
log.info(f"\tCatalog: stored " f"'{self.cat_path}' " f"from '{cat}'")
else:
log.info(f"\tCatalog: '{cat}'")

Expand All @@ -363,7 +364,7 @@ def get_test_cat(self, tstring: str = None) -> CSEPCatalog:
else:
start = self.start_date
end = self.end_date
print(self.catalog)

sub_cat = self.catalog.filter(
[
f"origin_time < {end.timestamp() * 1000}",
Expand Down
12 changes: 3 additions & 9 deletions floatcsep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,9 @@ def timewindows_td(
# return timewindows


def parse_nested_dicts(
nested_dict: dict, excluded: Sequence = (), extended: bool = False
) -> dict:
def parse_nested_dicts(nested_dict: dict) -> dict:
"""
Parses nested dictionaries to flatten them
Recursively walks nested dictionaries and sequences, parsing each element
"""

def _get_value(x):
Expand All @@ -435,11 +433,7 @@ def _get_value(x):
def iter_attr(val):
# recursive iter through nested dicts/lists
if isinstance(val, Mapping):
return {
item: iter_attr(val_)
for item, val_ in val.items()
if ((item not in excluded) and val_) or extended
}
return {item: iter_attr(val_) for item, val_ in val.items()}
elif isinstance(val, Sequence) and not isinstance(val, str):
return [iter_attr(i) for i in val]
else:
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def test_to_dict(self):
"name": "test",
"path": os.getcwd(),
"run_dir": "results",
"config_file": None,
"models": [],
"tests": [],
"time_config": {
"exp_class": "ti",
"start_date": datetime(2020, 1, 1),
Expand Down Expand Up @@ -109,7 +112,7 @@ def test_to_yml(self):
self.assertEqualExperiment(exp_a, exp_b)

file_ = tempfile.mkstemp()[1]
exp_a.to_yml(file_, extended=True)
exp_a.to_yml(file_)
exp_c = Experiment.from_yml(file_)
self.assertEqualExperiment(exp_a, exp_c)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ def test_set_catalog(self, mock_isfile):
# Mock the registry's rel method to return the same path for simplicity
self.mock_registry.rel.return_value = "catalog_path"

self.catalog_repo.set_catalog("catalog_path", {}, {})
self.catalog_repo.set_main_catalog("catalog_path", {}, {})

# Check if cat_path is set correctly
self.assertEqual(self.catalog_repo._catpath, "catalog_path")
self.assertEqual(self.catalog_repo.cat_path, "catalog_path")

# Check if _catalog is set correctly
self.assertEqual(self.catalog_repo._catalog, "catalog_path")
Expand Down
Loading