Skip to content

Commit

Permalink
Richer printing for some artifact objects (flyteorg#2624)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Signed-off-by: mao3267 <chenvincent610@gmail.com>
  • Loading branch information
wild-endeavor authored and mao3267 committed Aug 2, 2024
1 parent e6606ff commit 20c8250
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 4 deletions.
43 changes: 42 additions & 1 deletion flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,24 @@ def __init__(
self.reference_artifact: Optional[Artifact] = None
self.granularity = granularity

def __rich_repr__(self):
if self.value:
if isinstance(self.value, art_id.LabelValue):
if self.value.HasField("time_value"):
yield "Time Partition", str(self.value.time_value.ToDatetime())
elif self.value.HasField("input_binding"):
yield "Time Partition (bound to)", self.value.input_binding.var
else:
yield "Time Partition", "unspecified"
else:
yield "Time Partition", "unspecified"

def _repr_html_(self):
"""
Jupyter notebook rendering.
"""
return "".join([str(x) for x in self.__rich_repr__()])

def __add__(self, other: timedelta) -> TimePartition:
tp = TimePartition(self.value, op=Op.PLUS, other=other, granularity=self.granularity)
tp.reference_artifact = self.reference_artifact
Expand All @@ -293,6 +311,15 @@ def __init__(self, value: Optional[art_id.LabelValue], name: str):
self.value = value
self.reference_artifact: Optional[Artifact] = None

def __rich_repr__(self):
yield self.name, self.value

def _repr_html_(self):
"""
Jupyter notebook rendering.
"""
return "".join([f"{x[0]}: {x[1]}" for x in self.__rich_repr__()])


class Partitions(object):
def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.InputBindingData, Partition]]]):
Expand All @@ -307,6 +334,19 @@ def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.In
self._partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k)
self.reference_artifact: Optional[Artifact] = None

def __rich_repr__(self):
if self.partitions:
ps = [str(next(v.__rich_repr__())) for k, v in self.partitions.items()]
yield "Partitions", ", ".join(ps)
else:
yield ""

def _repr_html_(self):
"""
Jupyter notebook rendering.
"""
return ", ".join([str(x) for x in self.__rich_repr__()])

@property
def partitions(self) -> Optional[typing.Dict[str, Partition]]:
return self._partitions
Expand Down Expand Up @@ -562,7 +602,8 @@ def embed_as_query(
op: Optional[Op] = None,
) -> art_id.ArtifactQuery:
"""
This should only be called in the context of a Trigger
This should only be called in the context of a Trigger. The type of query this returns is different from the
query() function. This type of query is used to reference the triggering artifact, rather than running a query.
:param partition: Can embed a time partition
:param bind_to_time_partition: Set to true if you want to bind to a time partition
:param expr: Only valid if there's a time partition.
Expand Down
2 changes: 0 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,8 +2031,6 @@ class LiteralsResolver(collections.UserDict):
LiteralsResolver is a helper class meant primarily for use with the FlyteRemote experience or any other situation
where you might be working with LiteralMaps. This object allows the caller to specify the Python type that should
correspond to an element of the map.
TODO: Consider inheriting from collections.UserDict instead of manually having the _native_values cache
"""

def __init__(
Expand Down
28 changes: 27 additions & 1 deletion flytekit/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,38 @@ def __init__(
self._enum_type = enum_type
self._union_type = union_type
self._structured_dataset_type = structured_dataset_type
self._metadata = metadata
self._structure = structure
self._structured_dataset_type = structured_dataset_type
self._metadata = metadata
self._annotation = annotation

def __rich_repr__(self):
if self.simple:
yield "Simple"
elif self.schema:
yield "Schema"
elif self.collection_type:
sub = next(self.collection_type.__rich_repr__())
yield f"List[{sub}]"
elif self.map_value_type:
sub = next(self.map_value_type.__rich_repr__())
yield f"Dict[str, {sub}]"
elif self.blob:
if self.blob.dimensionality == _types_pb2.BlobType.BlobDimensionality.SINGLE:
yield "File"
elif self.blob.dimensionality == _types_pb2.BlobType.BlobDimensionality.MULTIPART:
yield "Directory"
else:
yield "Unknown Blob Type"
elif self.enum_type:
yield "Enum"
elif self.union_type:
yield "Union"
elif self.structured_dataset_type:
yield f"StructuredDataset(format={self.structured_dataset_type.format})"
else:
yield "Unknown Type"

@property
def simple(self) -> SimpleType:
return self._simple
Expand Down
21 changes: 21 additions & 0 deletions tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,27 @@ def test_tp_math():
assert tp2 is not tp


def test_tp_printing():
d = datetime.datetime(2063, 4, 5, 0, 0)
pt = Timestamp()
pt.FromDatetime(d)
tp = TimePartition(value=art_id.LabelValue(time_value=pt), granularity=Granularity.HOUR)
txt = "".join([str(x) for x in tp.__rich_repr__()])
# should show something like ('Time Partition', '2063-04-05 00:00:00')
# just check that we don't accidentally fail to evaluate a generator
assert "generator" not in txt


def test_partition_printing():
a1_b = Artifact(name="my_data", partition_keys=["b"])
spec = a1_b(b="my_b_value")
ps = spec.partitions
txt = "".join([str(x) for x in ps.__rich_repr__()])
# should look something like ('Partitions', '(\'b\', static_value: "my_b_value"\n)')
# just check that we don't accidentally fail to evaluate a generator
assert "generator" not in txt


def test_lims():
# test an artifact with 11 partition keys
with pytest.raises(ValueError):
Expand Down

0 comments on commit 20c8250

Please sign in to comment.