WIP
Fokko committed Jul 6, 2023
1 parent 5701bf2 commit 6d730a0
Showing 17 changed files with 279 additions and 214 deletions.
3 changes: 1 addition & 2 deletions python/pyiceberg/avro/reader.py
@@ -43,7 +43,6 @@
 from pyiceberg.avro.decoder import BinaryDecoder
 from pyiceberg.typedef import StructProtocol
 from pyiceberg.types import StructType
-from pyiceberg.utils.singleton import Singleton


 def _skip_map_array(decoder: BinaryDecoder, skip_entry: Callable[[], None]) -> None:
@@ -82,7 +81,7 @@ def _skip_map_array(decoder: BinaryDecoder, skip_entry: Callable[[], None]) -> None:
     block_count = decoder.read_int()


-class Reader(Singleton):
+class Reader:
     @abstractmethod
     def read(self, decoder: BinaryDecoder) -> Any:
         ...
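Note: pyiceberg's Singleton mixin caches one instance per concrete class, so removing it from the abstract Reader lets reader objects be constructed normally again. A minimal sketch of the pattern being dropped (the cache-per-class shape is an assumption, not the exact pyiceberg implementation):

from typing import Any, ClassVar, Dict


class Singleton:
    # One shared instance per concrete subclass (hypothetical sketch).
    _instances: ClassVar[Dict[type, "Singleton"]] = {}

    def __new__(cls, *args: Any, **kwargs: Any) -> "Singleton":
        if cls not in cls._instances:
            cls._instances[cls] = super().__new__(cls)
        return cls._instances[cls]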
17 changes: 14 additions & 3 deletions python/pyiceberg/partitioning.py
@@ -16,17 +16,23 @@
 # under the License.
 from functools import cached_property
 from typing import (
+    Annotated,
     Any,
     Dict,
     List,
     Optional,
     Tuple,
 )

-from pydantic import Field, SerializeAsAny
+from pydantic import (
+    BeforeValidator,
+    Field,
+    PlainSerializer,
+    WithJsonSchema,
+)

 from pyiceberg.schema import Schema
-from pyiceberg.transforms import Transform
+from pyiceberg.transforms import Transform, _deserialize_transform
 from pyiceberg.typedef import IcebergBaseModel
 from pyiceberg.types import NestedField, StructType
@@ -46,7 +52,12 @@ class PartitionField(IcebergBaseModel):

     source_id: int = Field(alias="source-id")
     field_id: int = Field(alias="field-id")
-    transform: Transform[Any, Any] = Field()
+    transform: Annotated[
+        Transform,
+        BeforeValidator(_deserialize_transform),
+        PlainSerializer(lambda c: str(c), return_type=str),
+        WithJsonSchema({"type": "string"}, mode="serialization"),
+    ] = Field()
     name: str = Field()

     def __init__(
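Note: this Annotated stack is the pydantic v2 idiom for attaching custom (de)serialization to one field: BeforeValidator runs _deserialize_transform on the raw value (e.g. the string "identity"), PlainSerializer renders the transform back to a string on model_dump, and WithJsonSchema declares the field as a plain string in the serialization schema. A self-contained sketch of the same pattern with a toy Color type (not pyiceberg code):

from typing import Annotated, Any

from pydantic import BaseModel, BeforeValidator, PlainSerializer, WithJsonSchema


class Color:
    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name


def _parse_color(value: Any) -> Any:
    # Turn a raw JSON string into the rich type; pass anything else through.
    return Color(value) if isinstance(value, str) else value


class Paint(BaseModel, arbitrary_types_allowed=True):
    color: Annotated[
        Color,
        BeforeValidator(_parse_color),
        PlainSerializer(lambda c: str(c), return_type=str),
        WithJsonSchema({"type": "string"}, mode="serialization"),
    ]


print(Paint.model_validate({"color": "red"}).model_dump())  # {'color': 'red'}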
8 changes: 6 additions & 2 deletions python/pyiceberg/serializers.py
@@ -24,6 +24,7 @@

 from pyiceberg.io import InputFile, InputStream, OutputFile
 from pyiceberg.table.metadata import TableMetadata, TableMetadataUtil
+from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel

 GZIP = "gzip"

@@ -86,12 +87,15 @@ def table_metadata(
             encoding (default "utf-8"): The byte encoder to use for the reader.
             compression: Optional compression method
         """
+
+        class VO(IcebergRootModel):
+            root: TableMetadata
+
         with compression.stream_decompressor(byte_stream) as byte_stream:
             reader = codecs.getreader(encoding)
             json_bytes = reader(byte_stream)
-            metadata = json.load(json_bytes)
-
-        return TableMetadataUtil.parse_obj(metadata)
+            return VO.model_validate_json(json_bytes.read())


 class FromInputFile:
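Note: the throwaway VO root model is needed because TableMetadata becomes a plain Annotated union (see metadata.py below), and a bare union has no model_validate_json of its own; wrapping it in a root model gives it a parsing entry point. A generic sketch with plain pydantic v2:

from typing import List

from pydantic import RootModel


class Ints(RootModel):
    root: List[int]


print(Ints.model_validate_json("[1, 2, 3]").root)  # [1, 2, 3]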
32 changes: 19 additions & 13 deletions python/pyiceberg/table/metadata.py
@@ -24,7 +24,7 @@
     List,
     Literal,
     Optional,
-    Union,
+    Union, Annotated,
 )

 from pydantic import Field, model_validator
@@ -80,7 +80,7 @@ def check_partition_specs(values: "TableMetadataV2") -> "TableMetadataV2":
         if spec.spec_id == default_spec_id:
             return values

-    raise ValidationError(f"default-spec-id {default_spec_id} can't be found")
+    # raise ValidationError(f"default-spec-id {default_spec_id} can't be found")


 def check_sort_orders(values: "TableMetadataV2") -> "TableMetadataV2":
@@ -150,7 +150,7 @@ def construct_refs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
     default_spec_id: int = Field(alias="default-spec-id", default=INITIAL_SPEC_ID)
     """ID of the “current” spec that writers should use by default."""

-    last_partition_id: Optional[int] = Field(alias="last-partition-id")
+    last_partition_id: Optional[int] = Field(alias="last-partition-id", default=None)
     """An integer; the highest assigned partition field ID across all
     partition specs for the table. This is used to ensure partition fields
     are always assigned an unused ID when evolving specs."""
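Note: the added default=None reflects a pydantic v2 behavior change: Optional[...] no longer implies a default, so a field without one becomes required even though it accepts None. A quick demonstration:

from typing import Optional

from pydantic import BaseModel, ValidationError


class WithoutDefault(BaseModel):
    last_partition_id: Optional[int]  # required in pydantic v2; may be None


class WithDefault(BaseModel):
    last_partition_id: Optional[int] = None  # truly optional


try:
    WithoutDefault()
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # 'missing'

print(WithDefault().last_partition_id)  # None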
@@ -202,6 +202,8 @@ def construct_refs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
     current-snapshot-id even if the refs map is null."""

+
+
 class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel):
     """Represents version 1 of the Table Metadata.
@@ -216,7 +218,7 @@ class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel):
     # because bumping the version should be an explicit operation that is up
     # to the owner of the table.

-    @model_validator(mode="after")
+    @model_validator(mode="before")
     def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:
         """Sets default values to be compatible with the format v2.
@@ -232,7 +234,7 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:

         return data

-    @model_validator(mode="after")
+    @model_validator(mode="before")
     def construct_schemas(cls, data: Dict[str, Any]) -> Dict[str, Any]:
         """Converts the schema into schemas.
@@ -247,13 +249,11 @@ def construct_schemas(cls, data: Dict[str, Any]) -> Dict[str, Any]:
             The TableMetadata with the schemas set, if not provided.
         """
         if not data.get("schemas"):
-            schema = data["schema_"]
+            schema = data["schema"]
             data["schemas"] = [schema]
         else:
             check_schemas(data)
         return data

-    @model_validator(mode="after")
+    @model_validator(mode="before")
     def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
         """Converts the partition_spec into partition_specs.
@@ -268,10 +268,13 @@ def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
             The TableMetadata with the partition_specs set, if not provided.
         """
         if not data.get(PARTITION_SPECS):
-            fields = data[PARTITION_SPEC]
-            migrated_spec = PartitionSpec(*fields)
-            data[PARTITION_SPECS] = [migrated_spec]
-            data[DEFAULT_SPEC_ID] = migrated_spec.spec_id
+            if PARTITION_SPEC in data:
+                fields = data[PARTITION_SPEC]
+                migrated_spec = PartitionSpec(*fields)
+                data[PARTITION_SPECS] = [migrated_spec]
+                data[DEFAULT_SPEC_ID] = migrated_spec.spec_id
+            else:
+                data[PARTITION_SPECS] = []
         else:
             check_partition_specs(data)
@@ -398,3 +401,6 @@ def parse_obj(data: Dict[str, Any]) -> TableMetadata:
             return TableMetadataV2(**data)
         else:
             raise ValidationError(f"Unknown format version: {format_version}")
+
+
+TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator='format_version')]
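Note: the new TableMetadata alias is a pydantic v2 tagged (discriminated) union: the format_version literal on each model class selects the concrete type at validation time. A stand-alone sketch of the mechanism with toy models:

from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class MetadataV1(BaseModel):
    format_version: Literal[1]
    table_schema: str


class MetadataV2(BaseModel):
    format_version: Literal[2]
    schemas: List[str]


Metadata = Annotated[Union[MetadataV1, MetadataV2], Field(discriminator="format_version")]

doc = TypeAdapter(Metadata).validate_python({"format_version": 2, "schemas": ["s1"]})
print(type(doc).__name__)  # MetadataV2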
16 changes: 8 additions & 8 deletions python/pyiceberg/table/snapshots.py
@@ -59,16 +59,14 @@ class Summary(IcebergBaseModel):
     like snapshot expiration, to skip processing certain snapshots.
     """

-    root: Dict[str, Union[str, Operation]]
+    operation: Operation = Field()
     _additional_properties: Dict[str, str] = PrivateAttr()

     def __init__(
-        self, operation: Optional[Operation] = None, root: Optional[Dict[str, Union[str, Operation]]] = None, **data: Any
+        self, operation: Optional[Operation] = None, **data: Any
     ) -> None:
-        super().__init__(root={"operation": operation, **data} if not root else root)
-        self._additional_properties = {
-            k: v for k, v in self.root.items() if k != OPERATION  # type: ignore # We know that they are all string, and we don't want to check
-        }
+        super().__init__(operation=operation, **data)
+        self._additional_properties = data

     @property
     def operation(self) -> Operation:
@@ -94,8 +92,10 @@ class Snapshot(IcebergBaseModel):
     parent_snapshot_id: Optional[int] = Field(alias="parent-snapshot-id", default=None)
     sequence_number: Optional[int] = Field(alias="sequence-number", default=None)
     timestamp_ms: int = Field(alias="timestamp-ms")
-    manifest_list: Optional[str] = Field(alias="manifest-list", description="Location of the snapshot's manifest list file")
-    summary: Optional[Summary] = Field()
+    manifest_list: Optional[str] = Field(
+        alias="manifest-list", description="Location of the snapshot's manifest list file", default=None
+    )
+    summary: Optional[Summary] = Field(default=None)
     schema_id: Optional[int] = Field(alias="schema-id", default=None)

     def __str__(self) -> str:
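Note: Summary changes from a root-model wrapper around a dict to a declared operation field, with the remaining free-form summary keys stashed in a PrivateAttr. A reduced sketch of that shape, assuming extra keys are allowed on the model (pyiceberg's base model config may differ):

from typing import Any, Dict

from pydantic import BaseModel, PrivateAttr


class Summary(BaseModel, extra="allow"):
    operation: str
    _additional_properties: Dict[str, str] = PrivateAttr()

    def __init__(self, operation: str, **data: Any) -> None:
        super().__init__(operation=operation, **data)
        # Keep the free-form keys (e.g. "added-data-files") separately.
        self._additional_properties = data


s = Summary(operation="append", **{"added-data-files": "3"})
print(s.operation, s._additional_properties)  # append {'added-data-files': '3'}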
50 changes: 42 additions & 8 deletions python/pyiceberg/table/sorting.py
@@ -16,13 +16,28 @@
 # under the License.
 # pylint: disable=keyword-arg-before-vararg
 from enum import Enum
-from typing import Any, Dict, List
+from typing import (
+    Annotated,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Union,
+)

-from pydantic import Field, model_validator
+from pydantic import (
+    BeforeValidator,
+    Field,
+    PlainSerializer,
+    WithJsonSchema,
+    model_validator,
+)

 from pyiceberg.schema import Schema
-from pyiceberg.transforms import IdentityTransform, Transform, Transforms
+from pyiceberg.transforms import IdentityTransform, Transform, _deserialize_transform
 from pyiceberg.typedef import IcebergBaseModel
+from pyiceberg.types import IcebergType


 class SortDirection(Enum):
@@ -62,19 +77,38 @@ class SortField(IcebergBaseModel):
         null_order (NullOrder): Null order that describes the order of null values when sorted. Can only be either nulls-first or nulls-last.
     """

-    def __init__(self, *args, **data):
-        super().__init__(*args, **data)
+    def __init__(
+        self,
+        source_id: Optional[int] = None,
+        transform: Optional[Union[Transform[Any, Any], Callable[[IcebergType], Transform[Any, Any]]]] = None,
+        direction: Optional[SortDirection] = None,
+        null_order: Optional[NullOrder] = None,
+        **data: Any,
+    ):
+        if source_id is not None:
+            data["source-id"] = source_id
+        if transform is not None:
+            data["transform"] = transform
+        if direction is not None:
+            data["direction"] = direction
+        if null_order is not None:
+            data["null-order"] = null_order
+        super().__init__(**data)

     @model_validator(mode="before")
     def set_null_order(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         values["direction"] = values["direction"] if values.get("direction") else SortDirection.ASC
         if not values.get("null-order"):
             values["null-order"] = NullOrder.NULLS_FIRST if values["direction"] == SortDirection.ASC else NullOrder.NULLS_LAST
-        values["transform"] = Transform.validate(values["transform"])
         return values

     source_id: int = Field(alias="source-id")
-    transform: Transforms = Field()
+    transform: Annotated[
+        Transform,
+        BeforeValidator(_deserialize_transform),
+        PlainSerializer(lambda c: str(c), return_type=str),
+        WithJsonSchema({"type": "string"}, mode="serialization"),
+    ] = Field()
     direction: SortDirection = Field()
     null_order: NullOrder = Field(alias="null-order")
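Note: with the explicit __init__, keyword construction and raw-dict validation both funnel into the same before-validator, which fills in the ASC / nulls-first defaults; the transform field then round-trips through _deserialize_transform and str(). A usage sketch against this WIP code (behavior assumed from the hunks above, untested):

from pyiceberg.table.sorting import SortField
from pyiceberg.transforms import IdentityTransform

# Keyword construction; direction and null order fall back to the defaults
# applied by set_null_order.
field = SortField(source_id=1, transform=IdentityTransform())
print(field.direction, field.null_order)

# Raw-dict validation: the BeforeValidator turns the serialized "identity"
# string back into a Transform instance.
parsed = SortField.model_validate({"source-id": 1, "transform": "identity"})
print(str(parsed.transform))  # identity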