Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Richer printing for some artifact objects #2624

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,24 @@ def __init__(
self.reference_artifact: Optional[Artifact] = None
self.granularity = granularity

def __rich_repr__(self):
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved
if self.value:
if isinstance(self.value, art_id.LabelValue):
if self.value.HasField("time_value"):
yield "Time Partition", str(self.value.time_value.ToDatetime())
elif self.value.HasField("input_binding"):
yield "Time Partition (bound to)", self.value.input_binding.var
else:
yield "Time Partition", "unspecified"
else:
yield "Time Partition", "unspecified"

def _repr_html_(self):
"""
Jupyter notebook rendering.
"""
return "".join([str(x) for x in self.__rich_repr__()])

def __add__(self, other: timedelta) -> TimePartition:
tp = TimePartition(self.value, op=Op.PLUS, other=other, granularity=self.granularity)
tp.reference_artifact = self.reference_artifact
Expand All @@ -293,6 +311,15 @@ def __init__(self, value: Optional[art_id.LabelValue], name: str):
self.value = value
self.reference_artifact: Optional[Artifact] = None

def __rich_repr__(self):
yield self.name, self.value

def _repr_html_(self):
"""
Jupyter notebook rendering.
"""
return "".join([f"{x[0]}: {x[1]}" for x in self.__rich_repr__()])


class Partitions(object):
def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.InputBindingData, Partition]]]):
Expand All @@ -307,6 +334,19 @@ def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.In
self._partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k)
self.reference_artifact: Optional[Artifact] = None

def __rich_repr__(self):
if self.partitions:
ps = [str(next(v.__rich_repr__())) for k, v in self.partitions.items()]
yield "Partitions", ", ".join(ps)
else:
yield ""

def _repr_html_(self):
"""
Jupyter notebook rendering.
"""
return ", ".join([str(x) for x in self.__rich_repr__()])

@property
def partitions(self) -> Optional[typing.Dict[str, Partition]]:
return self._partitions
Expand Down Expand Up @@ -562,7 +602,8 @@ def embed_as_query(
op: Optional[Op] = None,
) -> art_id.ArtifactQuery:
"""
This should only be called in the context of a Trigger
This should only be called in the context of a Trigger. The type of query this returns is different from the
query() function. This type of query is used to reference the triggering artifact, rather than running a query.
:param partition: Can embed a time partition
:param bind_to_time_partition: Set to true if you want to bind to a time partition
:param expr: Only valid if there's a time partition.
Expand Down
2 changes: 0 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2023,8 +2023,6 @@ class LiteralsResolver(collections.UserDict):
LiteralsResolver is a helper class meant primarily for use with the FlyteRemote experience or any other situation
where you might be working with LiteralMaps. This object allows the caller to specify the Python type that should
correspond to an element of the map.

TODO: Consider inheriting from collections.UserDict instead of manually having the _native_values cache
"""

def __init__(
Expand Down
28 changes: 27 additions & 1 deletion flytekit/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,38 @@ def __init__(
self._enum_type = enum_type
self._union_type = union_type
self._structured_dataset_type = structured_dataset_type
self._metadata = metadata
self._structure = structure
self._structured_dataset_type = structured_dataset_type
self._metadata = metadata
self._annotation = annotation

def __rich_repr__(self):
if self.simple:
yield "Simple"
elif self.schema:
yield "Schema"
elif self.collection_type:
sub = next(self.collection_type.__rich_repr__())
yield f"List[{sub}]"
elif self.map_value_type:
sub = next(self.map_value_type.__rich_repr__())
yield f"Dict[str, {sub}]"
elif self.blob:
if self.blob.dimensionality == _types_pb2.BlobType.BlobDimensionality.SINGLE:
yield "File"
elif self.blob.dimensionality == _types_pb2.BlobType.BlobDimensionality.MULTIPART:
yield "Directory"
else:
yield "Unknown Blob Type"
elif self.enum_type:
yield "Enum"
elif self.union_type:
yield "Union"
elif self.structured_dataset_type:
yield f"StructuredDataset(format={self.structured_dataset_type.format})"
else:
yield "Unknown Type"

@property
def simple(self) -> SimpleType:
return self._simple
Expand Down
21 changes: 21 additions & 0 deletions tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,27 @@ def test_tp_math():
assert tp2 is not tp


def test_tp_printing():
d = datetime.datetime(2063, 4, 5, 0, 0)
pt = Timestamp()
pt.FromDatetime(d)
tp = TimePartition(value=art_id.LabelValue(time_value=pt), granularity=Granularity.HOUR)
txt = "".join([str(x) for x in tp.__rich_repr__()])
# should show something like ('Time Partition', '2063-04-05 00:00:00')
# just check that we don't accidentally fail to evaluate a generator
assert "generator" not in txt


def test_partition_printing():
a1_b = Artifact(name="my_data", partition_keys=["b"])
spec = a1_b(b="my_b_value")
ps = spec.partitions
txt = "".join([str(x) for x in ps.__rich_repr__()])
# should look something like ('Partitions', '(\'b\', static_value: "my_b_value"\n)')
# just check that we don't accidentally fail to evaluate a generator
assert "generator" not in txt


def test_lims():
# test an artifact with 11 partition keys
with pytest.raises(ValueError):
Expand Down
Loading