Skip to content

Commit

Permalink
Fix/311: Add enum and map type for avro schema (#324)
Browse files Browse the repository at this point in the history
* Add enum and map type for avro schema

* Add use of enum via avroType config

---------

Co-authored-by: aniket-kapdule <aniket.kapdule@deliveryhero.com>
Co-authored-by: jochen <jochen.christ@innoq.com>
  • Loading branch information
3 people authored Jul 18, 2024
1 parent bb9c8aa commit f817893
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 18 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Add support for dbt manifest file (#104)
- Adds support for referencing fields within a definition (#322)
- Fix import of pyspark for type-checking when pyspark isn't required as a module (#312)
- Adds support for referencing fields within a definition (#322)
- Add `map` and `enum` type for Avro schema import (#311)

### Fixed
- Fix import of pyspark for type-checking when pyspark isn't required as a module (#312)
- `datacontract import --format spark`: Import from Spark tables (#326)
- Fix an issue where specifying `glue_table` as parameter did not filter the tables and instead returned all tables from `source` database (#333)


## [0.10.9] - 2024-07-03

### Added
Expand Down
2 changes: 1 addition & 1 deletion datacontract/export/avro_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def to_avro_type(field: Field, field_name: str) -> str | dict:
if field.config["avroLogicalType"] in ["time-millis", "date"]:
return {"type": "int", "logicalType": field.config["avroLogicalType"]}
if "avroType" in field.config:
return field.config["avroLogicalType"]
return field.config["avroType"]

if field.type is None:
return "null"
Expand Down
150 changes: 142 additions & 8 deletions datacontract/imports/avro_importer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict, List

import avro.schema

from datacontract.imports.importer import Importer
Expand All @@ -6,13 +8,39 @@


class AvroImporter(Importer):
    """Importer that turns an Avro schema file into data contract models."""

    def import_source(
        self, data_contract_specification: DataContractSpecification, source: str, import_args: dict
    ) -> DataContractSpecification:
        """
        Import an Avro schema from a source file.

        The work is delegated to the module-level ``import_avro`` function;
        ``import_args`` is not read by this implementation.

        Args:
            data_contract_specification: The data contract specification to update.
            source: The path to the Avro schema file.
            import_args: Additional import arguments (currently unused here).

        Returns:
            The data contract specification returned by ``import_avro``.
        """
        updated_specification = import_avro(data_contract_specification, source)
        return updated_specification


def import_avro(data_contract_specification: DataContractSpecification, source: str) -> DataContractSpecification:
"""
Import an Avro schema from a file and update the data contract specification.
Args:
data_contract_specification: The data contract specification to update.
source: The path to the Avro schema file.
Returns:
DataContractSpecification: The updated data contract specification.
Raises:
DataContractException: If there's an error parsing the Avro schema.
"""
if data_contract_specification.models is None:
data_contract_specification.models = {}

Expand Down Expand Up @@ -45,7 +73,14 @@ def import_avro(data_contract_specification: DataContractSpecification, source:
return data_contract_specification


def handle_config_avro_custom_properties(field, imported_field):
def handle_config_avro_custom_properties(field: avro.schema.Field, imported_field: Field) -> None:
"""
Handle custom Avro properties and add them to the imported field's config.
Args:
field: The Avro field.
imported_field: The imported field to update.
"""
if field.get_prop("logicalType") is not None:
if imported_field.config is None:
imported_field.config = {}
Expand All @@ -57,7 +92,16 @@ def handle_config_avro_custom_properties(field, imported_field):
imported_field.config["avroDefault"] = field.default


def import_record_fields(record_fields):
def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Field]:
"""
Import Avro record fields and convert them to data contract fields.
Args:
record_fields: List of Avro record fields.
Returns:
A dictionary of imported fields.
"""
imported_fields = {}
for field in record_fields:
imported_field = Field()
Expand All @@ -83,6 +127,15 @@ def import_record_fields(record_fields):
elif field.type.type == "array":
imported_field.type = "array"
imported_field.items = import_avro_array_items(field.type)
elif field.type.type == "map":
imported_field.type = "map"
imported_field.values = import_avro_map_values(field.type)
elif field.type.type == "enum":
imported_field.type = "string"
imported_field.enum = field.type.symbols
if not imported_field.config:
imported_field.config = {}
imported_field.config["avroType"] = "enum"
else: # primitive type
imported_field.type = map_type_from_avro(field.type.type)

Expand All @@ -91,7 +144,16 @@ def import_record_fields(record_fields):
return imported_fields


def import_avro_array_items(array_schema):
def import_avro_array_items(array_schema: avro.schema.ArraySchema) -> Field:
"""
Import Avro array items and convert them to a data contract field.
Args:
array_schema: The Avro array schema.
Returns:
Field: The imported field representing the array items.
"""
items = Field()
for prop in array_schema.other_props:
items.__setattr__(prop, array_schema.other_props[prop])
Expand All @@ -108,7 +170,45 @@ def import_avro_array_items(array_schema):
return items


def import_type_of_optional_field(field):
def import_avro_map_values(map_schema: avro.schema.MapSchema) -> Field:
    """
    Import Avro map values and convert them to a data contract field.

    Every entry of ``map_schema.other_props`` is copied onto the resulting
    field as an attribute. The value schema is then translated the same way
    ``import_record_fields`` translates a record field of that type:
    records become objects, arrays and nested maps are imported recursively,
    enums become strings that keep their symbols, and anything else is mapped
    through ``map_type_from_avro``.

    Args:
        map_schema: The Avro map schema.

    Returns:
        Field: The imported field representing the map values.

    Raises:
        DataContractException: Propagated from ``map_type_from_avro`` when the
            value type is not supported.
    """
    values = Field()
    for prop, prop_value in map_schema.other_props.items():
        setattr(values, prop, prop_value)

    values_schema = map_schema.values
    if values_schema.type == "record":
        values.type = "object"
        values.fields = import_record_fields(values_schema.fields)
    elif values_schema.type == "array":
        values.type = "array"
        values.items = import_avro_array_items(values_schema)
    elif values_schema.type == "map":
        # Nested map: recurse so the inner value schema is preserved instead
        # of collapsing to a bare "map" type with no values description.
        values.type = "map"
        values.values = import_avro_map_values(values_schema)
    elif values_schema.type == "enum":
        # Same representation import_record_fields uses for enum fields:
        # a string with the allowed symbols, tagged so an export can
        # restore the Avro enum type.
        values.type = "string"
        values.enum = values_schema.symbols
        if not values.config:
            values.config = {}
        values.config["avroType"] = "enum"
    else:  # primitive type
        values.type = map_type_from_avro(values_schema.type)

    return values


def import_type_of_optional_field(field: avro.schema.Field) -> str:
"""
Determine the type of optional field in an Avro union.
Args:
field: The Avro field with a union type.
Returns:
str: The mapped type of the non-null field in the union.
Raises:
DataContractException: If no non-null type is found in the union.
"""
for field_type in field.type.schemas:
if field_type.type != "null":
return map_type_from_avro(field_type.type)
Expand All @@ -121,21 +221,51 @@ def import_type_of_optional_field(field):
)


def get_record_from_union_field(field: avro.schema.Field) -> avro.schema.RecordSchema | None:
    """
    Find the first record schema among the member schemas of a union field.

    Args:
        field: The Avro field with a union type.

    Returns:
        The first member schema whose type is "record", or None if the union
        contains no record.
    """
    record_members = (member for member in field.type.schemas if member.type == "record")
    return next(record_members, None)


def get_array_from_union_field(field: avro.schema.Field) -> avro.schema.ArraySchema | None:
    """
    Find the first array schema among the member schemas of a union field.

    Args:
        field: The Avro field with a union type.

    Returns:
        The first member schema whose type is "array", or None if the union
        contains no array.
    """
    array_members = (member for member in field.type.schemas if member.type == "array")
    return next(array_members, None)


def map_type_from_avro(avro_type_str: str):
def map_type_from_avro(avro_type_str: str) -> str:
"""
Map Avro type strings to data contract type strings.
Args:
avro_type_str (str): The Avro type string.
Returns:
str: The corresponding data contract type string.
Raises:
DataContractException: If the Avro type is unsupported.
"""
# TODO: ambiguous mapping in the export
if avro_type_str == "null":
return "null"
Expand All @@ -155,6 +285,10 @@ def map_type_from_avro(avro_type_str: str):
return "record"
elif avro_type_str == "array":
return "array"
elif avro_type_str == "map":
return "map"
elif avro_type_str == "enum":
return "string"
else:
raise DataContractException(
type="schema",
Expand Down
28 changes: 24 additions & 4 deletions datacontract/lint/schema.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
import json
import os
from typing import Dict, Any

import requests

from datacontract.model.exceptions import DataContractException


def fetch_schema(location: str = None):
def fetch_schema(location: str = None) -> Dict[str, Any]:
"""
Fetch and return a JSON schema from a given location.
This function retrieves a JSON schema either from a URL or a local file path.
If no location is provided, it defaults to the DataContract schema URL.
Args:
location: The URL or file path of the schema.
Returns:
The JSON schema as a dictionary.
Raises:
DataContractException: If the specified local file does not exist.
requests.RequestException: If there's an error fetching the schema from a URL.
json.JSONDecodeError: If there's an error decoding the JSON schema.
"""
if location is None:
location = "https://datacontract.com/datacontract.schema.json"

if location.startswith("http://") or location.startswith("https://"):
response = requests.get(location)
return response.json()
schema = response.json()
else:
if not os.path.exists(location):
raise DataContractException(
Expand All @@ -23,5 +42,6 @@ def fetch_schema(location: str = None):
result="error",
)
with open(location, "r") as file:
file_content = file.read()
return json.loads(file_content)
schema = json.load(file)

return schema
2 changes: 2 additions & 0 deletions datacontract/model/data_contract_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class Field(pyd.BaseModel):
links: Dict[str, str] = {}
fields: Dict[str, "Field"] = {}
items: "Field" = None
keys: "Field" = None
values: "Field" = None
precision: int = None
scale: int = None
example: str = None
Expand Down
30 changes: 29 additions & 1 deletion tests/fixtures/avro/data/orders.avsc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,38 @@
"name": "address",
"type": "record"
}
},
{
"name": "status",
"doc": "order status",
"type": {
"type": "enum",
"name": "Status",
"symbols": ["PLACED", "SHIPPED", "DELIVERED", "CANCELLED"]
}
},
{
"name": "metadata",
"doc": "Additional metadata about the order",
"type": {
"type": "map",
"values": {
"type": "record",
"name": "MetadataValue",
"fields": [
{"name": "value", "type": "string"},
{"name": "type", "type": {"type": "enum", "name": "MetadataType", "symbols": ["STRING", "LONG", "DOUBLE"]}},
{"name": "timestamp", "type": "long"},
{"name": "source", "type": "string"}
]
},
"default": {}
}
}

],
"name": "orders",
"doc": "My Model",
"type": "record",
"namespace": "com.sample.schema"
}
}
Loading

0 comments on commit f817893

Please sign in to comment.