diff --git a/CHANGELOG.md b/CHANGELOG.md index deaa2a8a..bde3385b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add support for dbt manifest file (#104) -- Adds support for referencing fields within a definition (#322) +- Fix import of pyspark for type-checking when pyspark isn't required as a module (#312) +- Adds support for referencing fields within a definition (#322) +- Add `map` and `enum` type for Avro schema import (#311) ### Fixed - Fix import of pyspark for type-checking when pyspark isn't required as a module (#312)- `datacontract import --format spark`: Import from Spark tables (#326) - Fix an issue where specifying `glue_table` as parameter did not filter the tables and instead returned all tables from `source` database (#333) + ## [0.10.9] - 2024-07-03 ### Added diff --git a/datacontract/export/avro_converter.py b/datacontract/export/avro_converter.py index 7c253d54..3715f622 100644 --- a/datacontract/export/avro_converter.py +++ b/datacontract/export/avro_converter.py @@ -65,7 +65,7 @@ def to_avro_type(field: Field, field_name: str) -> str | dict: if field.config["avroLogicalType"] in ["time-millis", "date"]: return {"type": "int", "logicalType": field.config["avroLogicalType"]} if "avroType" in field.config: - return field.config["avroLogicalType"] + return field.config["avroType"] if field.type is None: return "null" diff --git a/datacontract/imports/avro_importer.py b/datacontract/imports/avro_importer.py index a40d5353..eff8a88c 100644 --- a/datacontract/imports/avro_importer.py +++ b/datacontract/imports/avro_importer.py @@ -1,3 +1,5 @@ +from typing import Dict, List + import avro.schema from datacontract.imports.importer import Importer @@ -6,13 +8,39 @@ class AvroImporter(Importer): + """Class to import Avro Schema file""" + def import_source( self, data_contract_specification: DataContractSpecification, source: str, import_args: dict - ) -> dict: + ) -> DataContractSpecification: + """ + Import Avro schema from a source file. + + Args: + data_contract_specification: The data contract specification to update. + source: The path to the Avro schema file. + import_args: Additional import arguments. + + Returns: + The updated data contract specification. + """ return import_avro(data_contract_specification, source) def import_avro(data_contract_specification: DataContractSpecification, source: str) -> DataContractSpecification: + """ + Import an Avro schema from a file and update the data contract specification. + + Args: + data_contract_specification: The data contract specification to update. + source: The path to the Avro schema file. + + Returns: + DataContractSpecification: The updated data contract specification. + + Raises: + DataContractException: If there's an error parsing the Avro schema. + """ if data_contract_specification.models is None: data_contract_specification.models = {} @@ -45,7 +73,14 @@ def import_avro(data_contract_specification: DataContractSpecification, source: return data_contract_specification -def handle_config_avro_custom_properties(field, imported_field): +def handle_config_avro_custom_properties(field: avro.schema.Field, imported_field: Field) -> None: + """ + Handle custom Avro properties and add them to the imported field's config. + + Args: + field: The Avro field. + imported_field: The imported field to update. + """ if field.get_prop("logicalType") is not None: if imported_field.config is None: imported_field.config = {} @@ -57,7 +92,16 @@ def handle_config_avro_custom_properties(field, imported_field): imported_field.config["avroDefault"] = field.default -def import_record_fields(record_fields): +def import_record_fields(record_fields: List[avro.schema.Field]) -> Dict[str, Field]: + """ + Import Avro record fields and convert them to data contract fields. + + Args: + record_fields: List of Avro record fields. + + Returns: + A dictionary of imported fields. + """ imported_fields = {} for field in record_fields: imported_field = Field() @@ -83,6 +127,15 @@ def import_record_fields(record_fields): elif field.type.type == "array": imported_field.type = "array" imported_field.items = import_avro_array_items(field.type) + elif field.type.type == "map": + imported_field.type = "map" + imported_field.values = import_avro_map_values(field.type) + elif field.type.type == "enum": + imported_field.type = "string" + imported_field.enum = field.type.symbols + if not imported_field.config: + imported_field.config = {} + imported_field.config["avroType"] = "enum" else: # primitive type imported_field.type = map_type_from_avro(field.type.type) @@ -91,7 +144,16 @@ def import_record_fields(record_fields): return imported_fields -def import_avro_array_items(array_schema): +def import_avro_array_items(array_schema: avro.schema.ArraySchema) -> Field: + """ + Import Avro array items and convert them to a data contract field. + + Args: + array_schema: The Avro array schema. + + Returns: + Field: The imported field representing the array items. + """ items = Field() for prop in array_schema.other_props: items.__setattr__(prop, array_schema.other_props[prop]) @@ -108,7 +170,45 @@ def import_avro_array_items(array_schema): return items -def import_type_of_optional_field(field): +def import_avro_map_values(map_schema: avro.schema.MapSchema) -> Field: + """ + Import Avro map values and convert them to a data contract field. + + Args: + map_schema: The Avro map schema. + + Returns: + Field: The imported field representing the map values. + """ + values = Field() + for prop in map_schema.other_props: + values.__setattr__(prop, map_schema.other_props[prop]) + + if map_schema.values.type == "record": + values.type = "object" + values.fields = import_record_fields(map_schema.values.fields) + elif map_schema.values.type == "array": + values.type = "array" + values.items = import_avro_array_items(map_schema.values) + else: # primitive type + values.type = map_type_from_avro(map_schema.values.type) + + return values + + +def import_type_of_optional_field(field: avro.schema.Field) -> str: + """ + Determine the type of optional field in an Avro union. + + Args: + field: The Avro field with a union type. + + Returns: + str: The mapped type of the non-null field in the union. + + Raises: + DataContractException: If no non-null type is found in the union. + """ for field_type in field.type.schemas: if field_type.type != "null": return map_type_from_avro(field_type.type) @@ -121,21 +221,51 @@ def import_type_of_optional_field(field): ) -def get_record_from_union_field(field): +def get_record_from_union_field(field: avro.schema.Field) -> avro.schema.RecordSchema | None: + """ + Get the record schema from a union field. + + Args: + field: The Avro field with a union type. + + Returns: + The record schema if found, None otherwise. + """ for field_type in field.type.schemas: if field_type.type == "record": return field_type return None -def get_array_from_union_field(field): +def get_array_from_union_field(field: avro.schema.Field) -> avro.schema.ArraySchema | None: + """ + Get the array schema from a union field. + + Args: + field: The Avro field with a union type. + + Returns: + The array schema if found, None otherwise. + """ for field_type in field.type.schemas: if field_type.type == "array": return field_type return None -def map_type_from_avro(avro_type_str: str): +def map_type_from_avro(avro_type_str: str) -> str: + """ + Map Avro type strings to data contract type strings. + + Args: + avro_type_str (str): The Avro type string. + + Returns: + str: The corresponding data contract type string. + + Raises: + DataContractException: If the Avro type is unsupported. + """ # TODO: ambiguous mapping in the export if avro_type_str == "null": return "null" @@ -155,6 +285,10 @@ def map_type_from_avro(avro_type_str: str): return "record" elif avro_type_str == "array": return "array" + elif avro_type_str == "map": + return "map" + elif avro_type_str == "enum": + return "string" else: raise DataContractException( type="schema", diff --git a/datacontract/lint/schema.py b/datacontract/lint/schema.py index ab64fc7b..a73258e4 100644 --- a/datacontract/lint/schema.py +++ b/datacontract/lint/schema.py @@ -1,18 +1,37 @@ import json import os +from typing import Dict, Any import requests from datacontract.model.exceptions import DataContractException -def fetch_schema(location: str = None): +def fetch_schema(location: str = None) -> Dict[str, Any]: + """ + Fetch and return a JSON schema from a given location. + + This function retrieves a JSON schema either from a URL or a local file path. + If no location is provided, it defaults to the DataContract schema URL. + + Args: + location: The URL or file path of the schema. + + Returns: + The JSON schema as a dictionary. + + Raises: + DataContractException: If the specified local file does not exist. + requests.RequestException: If there's an error fetching the schema from a URL. + json.JSONDecodeError: If there's an error decoding the JSON schema. + + """ if location is None: location = "https://datacontract.com/datacontract.schema.json" if location.startswith("http://") or location.startswith("https://"): response = requests.get(location) - return response.json() + schema = response.json() else: if not os.path.exists(location): raise DataContractException( @@ -23,5 +42,6 @@ def fetch_schema(location: str = None): result="error", ) with open(location, "r") as file: - file_content = file.read() - return json.loads(file_content) + schema = json.load(file) + + return schema diff --git a/datacontract/model/data_contract_specification.py b/datacontract/model/data_contract_specification.py index 90c9b7ee..d7144060 100644 --- a/datacontract/model/data_contract_specification.py +++ b/datacontract/model/data_contract_specification.py @@ -108,6 +108,8 @@ class Field(pyd.BaseModel): links: Dict[str, str] = {} fields: Dict[str, "Field"] = {} items: "Field" = None + keys: "Field" = None + values: "Field" = None precision: int = None scale: int = None example: str = None diff --git a/tests/fixtures/avro/data/orders.avsc b/tests/fixtures/avro/data/orders.avsc index fdfc3527..ac074864 100644 --- a/tests/fixtures/avro/data/orders.avsc +++ b/tests/fixtures/avro/data/orders.avsc @@ -58,10 +58,38 @@ "name": "address", "type": "record" } + }, + { + "name": "status", + "doc": "order status", + "type": { + "type": "enum", + "name": "Status", + "symbols": ["PLACED", "SHIPPED", "DELIVERED", "CANCELLED"] + } + }, + { + "name": "metadata", + "doc": "Additional metadata about the order", + "type": { + "type": "map", + "values": { + "type": "record", + "name": "MetadataValue", + "fields": [ + {"name": "value", "type": "string"}, + {"name": "type", "type": {"type": "enum", "name": "MetadataType", "symbols": ["STRING", "LONG", "DOUBLE"]}}, + {"name": "timestamp", "type": "long"}, + {"name": "source", "type": "string"} + ] + }, + "default": {} + } } + ], "name": "orders", "doc": "My Model", "type": "record", "namespace": "com.sample.schema" - } \ No newline at end of file +} diff --git a/tests/test_schema.py b/tests/test_export_complex_data_contract.py similarity index 57% rename from tests/test_schema.py rename to tests/test_export_complex_data_contract.py index b3bc338f..a6ed9306 100644 --- a/tests/test_schema.py +++ b/tests/test_export_complex_data_contract.py @@ -5,13 +5,13 @@ logging.basicConfig(level=logging.INFO, force=True) -def test_schema(): +def test_export_complex_data_contract(): """ - A schema in a data contract would not do anything, but should also raise no errors. + Use a complex data contract, and smoke test that it can be exported to various formats without exception. """ data_contract = DataContract( data_contract_str=""" -dataContractSpecification: 0.9.2 +dataContractSpecification: 0.9.3 id: urn:datacontract:checkout:orders-latest info: title: Orders Latest @@ -31,6 +31,43 @@ def test_schema(): order_total: type: integer description: Total amount of the order in the smallest monetary unit (e.g., cents). + status: + type: string + required: true + description: order status + enum: + - PLACED + - SHIPPED + - DELIVERED + - CANCELLED + config: + avroType: enum + metadata: + type: map + required: true + description: Additional metadata about the order + values: + type: object + fields: + value: + type: string + required: true + type: + type: string + required: true + enum: + - STRING + - LONG + - DOUBLE + config: + avroType: enum + timestamp: + type: long + required: true + source: + type: string + required: true + default: {} line_items: type: object fields: @@ -47,6 +84,7 @@ def test_schema(): data_contract.lint() data_contract.test() + data_contract.export(export_format="avro", model="orders") data_contract.export(export_format="odcs") data_contract.export(export_format="dbt") data_contract.export(export_format="dbt-sources") diff --git a/tests/test_import_avro.py b/tests/test_import_avro.py index 30fd0234..33c84113 100644 --- a/tests/test_import_avro.py +++ b/tests/test_import_avro.py @@ -76,6 +76,43 @@ def test_import_avro_schema(): zipcode: type: long required: true + status: + type: string + required: true + description: order status + enum: + - PLACED + - SHIPPED + - DELIVERED + - CANCELLED + config: + avroType: enum + metadata: + type: map + required: true + description: Additional metadata about the order + values: + type: object + fields: + value: + type: string + required: true + type: + type: string + required: true + enum: + - STRING + - LONG + - DOUBLE + config: + avroType: enum + timestamp: + type: long + required: true + source: + type: string + required: true + default: {} """ print("Result:\n", result.to_yaml()) assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected)