Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor dmod.modeldata serialization and deserialization types #257

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6b73d55
format modeldata's setup.py
aaraney Jan 26, 2023
e770a4d
add missing dep, gitpython to modeldata's setup.py
aaraney Jan 26, 2023
982f7bc
add pydantic dep
aaraney Jan 26, 2023
fb9d8cb
refactor SubsetDefinition
aaraney Jan 26, 2023
d7dbdc2
refactor Partition
aaraney Jan 26, 2023
3bfcb61
refactor PartitionConfig
aaraney Jan 26, 2023
0b7ab68
add kwargs argument to SubsetDefinition initializer
aaraney Jan 26, 2023
2db82c2
refactor HydrofabricSubset
aaraney Jan 26, 2023
076c5da
refactor SimpleHydrofabricSubset
aaraney Jan 26, 2023
cd32278
move HydrofabricSubset initializer
aaraney Jan 26, 2023
986cd9f
add private attr class level declarations to SimpleHydrofabricSubset
aaraney Jan 26, 2023
ab1f2b0
update Tuple type hints to be variadic
aaraney Jan 26, 2023
8956e40
fix reference to non-existent field
aaraney Feb 7, 2023
20a6fa0
add missing PrivateAttr field to Partition
aaraney Feb 7, 2023
50abfe7
typo in Partition's fields alias map
aaraney Feb 7, 2023
63cabe9
serialize Partition fields as list and fix its initializer fn to allo…
aaraney Feb 7, 2023
2d3fa59
add Partition unit tests
aaraney Feb 7, 2023
c9738d8
fix invalid reference
aaraney Feb 7, 2023
fcc488f
add PartitionConfig unit tests; format with black
aaraney Feb 7, 2023
e45a603
add PartitionConfig field_serializer
aaraney Feb 7, 2023
93d0474
override PartitionConfig dict method.
aaraney Feb 7, 2023
f4fe15b
replace private reference to Dataset field name
aaraney Feb 10, 2023
8973a89
replace private reference to Dataset field name in test
aaraney Feb 10, 2023
e29c64b
no longer use a double-wrapped classmethod property; it is not supported…
aaraney Feb 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ def reload(self, reload_from: str, serialized_item: Optional[str] = None) -> Dat
response_obj.release_conn()

# If we can safely infer it, make sure the "type" key is set in cases when it is missing
if len(self.supported_dataset_types) == 1 and Dataset._KEY_TYPE not in response_data:
response_data[Dataset._KEY_TYPE] = list(self.supported_dataset_types)[0].name
if len(self.supported_dataset_types) == 1 and "type" not in response_data:
response_data["type"] = list(self.supported_dataset_types)[0].name

dataset = Dataset.factory_init_from_deserialized_json(response_data)
dataset.manager = self
Expand Down
296 changes: 138 additions & 158 deletions python/lib/modeldata/dmod/modeldata/hydrofabric/partition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from numbers import Number
from typing import Collection, Dict, FrozenSet, List, Union
from typing import Collection, FrozenSet, List, Optional, TYPE_CHECKING, Union
from pydantic import Field, PrivateAttr, validator
from dmod.core.serializable import Serializable

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny


class Partition(Serializable):
"""
Expand All @@ -13,56 +16,84 @@ class Partition(Serializable):
in the context of the related hydrofabric.
"""

__slots__ = ["_catchment_ids", "_hash_val", "_nexus_ids", "_partition_id", "_remote_downstream_nexus_ids",
"_remote_upstream_nexus_ids"]
partition_id: int
catchment_ids: FrozenSet[str]
nexus_ids: FrozenSet[str]
"""
Note that, at the time this is committed, partition ids should always be integers. This is so they can easily
correspond to MPI ranks. However, because of how the expected
"""
remote_upstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset)
remote_downstream_nexus_ids: FrozenSet[str] = Field(default_factory=frozenset)

_hash_val: Optional[int] = PrivateAttr(None)

class Config:
fields = {
"catchment_ids": {"alias": "cat-ids"},
"partition_id": {"alias": "id"},
"nexus_ids": {"alias": "nex-ids"},
"remote_upstream_nexus_ids": {"alias": "remote-up"},
"remote_downstream_nexus_ids": {"alias": "remote-down"},
}

def _serialize_frozenset(value: FrozenSet[str]) -> List[str]:
return list(value)

_KEY_CATCHMENT_IDS = 'cat-ids'
_KEY_PARTITION_ID = 'id'
# Note that these need to be included in the JSON, but initially aren't actually used at the JSON level
_KEY_NEXUS_IDS = 'nex-ids'
_KEY_REMOTE_UPSTREAM_NEXUS_IDS = 'remote-up'
_KEY_REMOTE_DOWNSTREAM_NEXUS_IDS = 'remote-down'
field_serializers = {
"catchment_ids": _serialize_frozenset,
"nexus_ids": _serialize_frozenset,
"remote_upstream_nexus_ids": _serialize_frozenset,
"remote_downstream_nexus_ids": _serialize_frozenset,
}

@classmethod
def factory_init_from_deserialized_json(cls, json_obj: dict):
    """
    Create a new instance from a deserialized JSON object.

    Expects the serialized keys ``id``, ``cat-ids``, ``nex-ids`` and, optionally,
    ``remote-up`` and ``remote-down``.

    Parameters
    ----------
    json_obj : dict
        The deserialized JSON representation of a partition.

    Returns
    -------
    Partition or None
        A new instance, or ``None`` if ``json_obj`` was not a valid serialized partition
        (preserving the existing return-None-on-failure contract).
    """
    try:
        # TODO: later these may be required, but for now, keep optional
        remote_up = json_obj.get(cls._KEY_REMOTE_UPSTREAM_NEXUS_IDS, [])
        # Fix: previously this mistakenly read the *upstream* key when populating the
        # downstream nexus ids, so 'remote-down' data was silently replaced by 'remote-up'.
        remote_down = json_obj.get(cls._KEY_REMOTE_DOWNSTREAM_NEXUS_IDS, [])
        return Partition(catchment_ids=json_obj[cls._KEY_CATCHMENT_IDS], nexus_ids=json_obj[cls._KEY_NEXUS_IDS],
                         remote_up_nexuses=remote_up, remote_down_nexuses=remote_down,
                         partition_id=int(json_obj[cls._KEY_PARTITION_ID]))
    except (KeyError, TypeError, ValueError):
        # Narrowed from a bare `except:` so genuine programming errors are no longer hidden;
        # malformed input still yields None per the established contract.
        return None

def __init__(self, partition_id: int, catchment_ids: Collection[str], nexus_ids: Collection[str],
             remote_up_nexuses: Collection[str] = tuple(), remote_down_nexuses: Collection[str] = tuple()):
    """
    Initialize a partition from its id and its catchment/nexus id collections.

    Parameters
    ----------
    partition_id : int
        Unique id for this partition.
    catchment_ids : Collection[str]
        Ids of catchments belonging to this partition; stored as a frozenset.
    nexus_ids : Collection[str]
        Ids of nexuses belonging to this partition; stored as a frozenset.
    remote_up_nexuses : Collection[str]
        Ids of remote upstream nexuses (optional; defaults to empty).
    remote_down_nexuses : Collection[str]
        Ids of remote downstream nexuses (optional; defaults to empty).
    """
    self._partition_id = partition_id
    # All id collections are frozen to keep the instance hashable/immutable.
    self._catchment_ids = frozenset(catchment_ids)
    self._nexus_ids = frozenset(nexus_ids)
    self._remote_upstream_nexus_ids = frozenset(remote_up_nexuses)
    self._remote_downstream_nexus_ids = frozenset(remote_down_nexuses)

    # Hash is computed lazily and cached on first __hash__ call.
    self._hash_val = None

def __eq__(self, other):
def __init__(
    self,
    # required, but for backwards compatibility, None
    partition_id: int = None,
    catchment_ids: Collection[str] = None,
    nexus_ids: Collection[str] = None,
    # non-required fields
    remote_up_nexuses: Collection[str] = None,
    remote_down_nexuses: Collection[str] = None,
    **data
):
    """
    Initialize, supporting both the legacy keyword names and alias-keyed data.

    If any alias-keyed ``data`` is supplied, it is assumed fields were specified
    using their aliases and the legacy keyword parameters are ignored entirely
    (no backwards compatibility in that path).

    Parameters
    ----------
    partition_id : int
        Unique id of the partition (required in practice; None only for compatibility).
    catchment_ids : Collection[str]
        Catchment ids for the partition.
    nexus_ids : Collection[str]
        Nexus ids for the partition.
    remote_up_nexuses : Collection[str], optional
        Remote upstream nexus ids; omitted values fall back to the field default.
    remote_down_nexuses : Collection[str], optional
        Remote downstream nexus ids; omitted values fall back to the field default.
    """
    # if data exists, assume fields specified using their alias; no backwards compatibility.
    if data:
        super().__init__(**data)
        return

    kwargs = dict(
        partition_id=partition_id,
        catchment_ids=catchment_ids,
        nexus_ids=nexus_ids,
    )
    # Fix: previously, if exactly ONE of the remote collections was provided, BOTH were
    # silently dropped (the early-return branch ignored the provided value). Each value
    # is now forwarded independently; missing ones use the pydantic field defaults.
    if remote_up_nexuses is not None:
        kwargs["remote_upstream_nexus_ids"] = remote_up_nexuses
    if remote_down_nexuses is not None:
        kwargs["remote_downstream_nexus_ids"] = remote_down_nexuses

    super().__init__(**kwargs)


def __eq__(self, other: object):
    """Equal iff ``other`` is the same type with matching partition id and hash."""
    if not isinstance(other, self.__class__):
        return False
    if other.partition_id != self.partition_id:
        return False
    return hash(other) == hash(self)

def __lt__(self, other):
def __lt__(self, other: "Partition"):
# Go first by id, so this is clearly true
if self._partition_id < other._partition_id:
if self.partition_id < other.partition_id:
return True
# Again, going by id first, having greater id is also clear
elif self._partition_id > other._partition_id:
elif self.partition_id > other.partition_id:
return False
# Also can't be (strictly) less-than AND equal-to
elif self == other:
Expand All @@ -79,116 +110,37 @@ def __hash__(self):
self._hash_val = hash(','.join(cat_id_list))
return self._hash_val

@property
def catchment_ids(self) -> FrozenSet[str]:
"""
Get the frozen set of ids for all catchments in this partition.

Returns
-------
Set[str]
The frozen set of string ids for all catchments in this partition.
"""
return self._catchment_ids

@property
def nexus_ids(self) -> FrozenSet[str]:
"""
Get the frozen set of ids for all nexuses in this partition.

Returns
-------
Set[str]
The frozen set of string ids for all nexuses in this partition.
"""
return self._nexus_ids

@property
def partition_id(self) -> int:
"""
Get the id of this partition.

Note that, at the time this is committed, partition ids should always be integers. This is so they can easily
correspond to MPI ranks. However, because of how the expected

Returns
-------
str
The id of this partition, as a string.
"""
return self._partition_id

@property
def remote_downstream_nexus_ids(self) -> FrozenSet[str]:
"""
Get the frozen set of ids for all remote downstream nexuses in this partition.

Returns
-------
Set[str]
The frozen set of string ids for all remote downstream nexuses in this partition.
"""
return self._remote_downstream_nexus_ids

@property
def remote_upstream_nexus_ids(self) -> FrozenSet[str]:
"""
Get the frozen set of ids for all remote upstream nexuses in this partition.

Returns
-------
Set[str]
The frozen set of string ids for all remote upstream nexuses in this partition.
"""
return self._remote_upstream_nexus_ids

def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]:
"""
Get the instance represented as a dict (i.e., a JSON-like object).

Note that, as described in the main docstring for the class, there are extra keys in the dict/JSON currently
that don't correspond to any attributes of the instance. This is for consistency with other tools.

Returns
-------
dict
The instance as a dict
"""
return {
self._KEY_PARTITION_ID: str(self.partition_id),
self._KEY_CATCHMENT_IDS: list(self.catchment_ids),
self._KEY_NEXUS_IDS: list(self.nexus_ids),
self._KEY_REMOTE_UPSTREAM_NEXUS_IDS: list(self.remote_upstream_nexus_ids),
self._KEY_REMOTE_DOWNSTREAM_NEXUS_IDS: list(self.remote_downstream_nexus_ids)
}


class PartitionConfig(Serializable):
"""
A type to easily encapsulate the JSON object that is output from the NextGen partitioner.
"""

_KEY_PARTITIONS = 'partitions'
partitions: FrozenSet[Partition]

@classmethod
def factory_init_from_deserialized_json(cls, json_obj: dict):
try:
return PartitionConfig([Partition.factory_init_from_deserialized_json(serial_p) for serial_p in json_obj[cls._KEY_PARTITIONS]])
except:
return None
@validator("partitions")
def _sort_partitions(cls, value: FrozenSet[Partition]) -> FrozenSet[Partition]:
    # NOTE(review): frozenset is unordered, so sorting before construction cannot affect
    # the resulting set's contents or iteration guarantees; the sort does require the
    # elements be mutually comparable (Partition defines __lt__) — confirm whether that
    # comparability check is the intended purpose here.
    return frozenset(sorted(value))

class Config:
    # Serialize the frozenset of partitions as a JSON-friendly list.
    def _serialize_frozenset(value: FrozenSet[Partition]) -> List[Partition]:
        return list(value)

    # presumably consumed by the project's Serializable base class to select per-field
    # serializer functions (not a standard pydantic v1 Config key) — verify against
    # dmod.core.serializable
    field_serializers = {
        "partitions": _serialize_frozenset
    }

@classmethod
def get_serial_property_key_partitions(cls) -> str:
    """
    Get the key under which the partitions list appears in the serialized (dict) form.

    Returns
    -------
    str
        The serialized-form key for the partitions list.
    """
    # Fix: the span previously contained both the removed `return cls._KEY_PARTITIONS`
    # (referencing a deleted class constant) and the replacement literal; only the
    # literal form is valid now.
    return "partitions"

def __init__(self, partitions: Collection[Partition], **data):
    """
    Initialize from a collection of partitions.

    Parameters
    ----------
    partitions : Collection[Partition]
        The partitions composing this config; validated/coerced by the base initializer.
    **data
        Any additional keyword args passed through to the Serializable base initializer.
    """
    # Fix: the span previously retained the removed legacy initializer, which bypassed
    # pydantic validation by assigning `self._partitions` directly; only the
    # super().__init__ delegation is valid now.
    super().__init__(partitions=partitions, **data)

def __eq__(self, other):
def __eq__(self, other: object):
if not isinstance(other, PartitionConfig):
return False
other_partitions_dict = dict()
for other_p in other._partitions:
for other_p in other.partitions:
other_partitions_dict[other_p.partition_id] = other_p

other_pids = set([p2.partition_id for p2 in other.partitions])
Expand All @@ -197,7 +149,7 @@ def __eq__(self, other):
return False
return True

def __hash__(self):
def __hash__(self) -> int:
"""
Get the unique hash for this instance.

Expand All @@ -206,22 +158,50 @@ def __hash__(self):

Returns
-------

"""
#
return hash(','.join([str(p.__hash__()) for p in sorted(self._partitions)]))

@property
def partitions(self) -> List[Partition]:
"""
Get the (sorted) list of partitions for this config.

Returns
-------
List[Partition]
The (sorted) list of partitions for this config.
"""
return sorted(self._partitions)

def to_dict(self) -> Dict[str, Union[str, Number, dict, list]]:
return {self._KEY_PARTITIONS: [p.to_dict() for p in self.partitions]}
int
Hash of instance
"""
return hash(",".join([str(p.__hash__()) for p in sorted(self.partitions)]))

def dict(
    self,
    *,
    include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
    exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False
) -> "DictStrAny":
    """
    Serialize to a dict, working around pydantic's handling of hashable containers.

    Why dict is overridden here: pydantic serializes from inner types outward,
    serializing each type as a dictionary, list, or primitive and replacing its
    previous type with the new "serialized" type. Consequently, hashable container
    types like tuples and frozensets that contain values that "serialize" to a
    non-hashable type (non-primitive, in this case) raise
    `TypeError: unhashable type: 'dict'`. In the case of PartitionConfig,
    FrozenSet[Partition] "serializes" inner Partition types as dictionaries, which
    are not hashable. To get around this, we momentarily swap the `partitions`
    field for a non-hashable container type, serialize using `.dict()`, and swap
    back in the original `partitions` container.
    """
    # 1. take a reference to partitions: FrozenSet[Partition]
    partitions = self.partitions

    # 2. cast and set partitions to a list, a non-hashable container type
    self.partitions = list(partitions)

    # 3. serialize; the restore is in a finally block (fix: previously, an exception
    # raised during serialization would leave `partitions` as a list, corrupting the
    # instance's state and its hashability guarantees)
    try:
        return super().dict(
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            skip_defaults=skip_defaults,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )
    finally:
        # 4. always swap the original hashable container back in
        self.partitions = partitions
Loading