Skip to content

Commit 95cb644

Browse files
authored
Fix handling of named Avro schemas (#1928)
1 parent 790bc87 commit 95cb644

File tree

2 files changed

+56
-7
lines changed

2 files changed

+56
-7
lines changed

src/confluent_kafka/schema_registry/avro.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@
1818
import decimal
1919
import re
2020
from collections import defaultdict
21+
from copy import deepcopy
2122
from io import BytesIO
2223
from json import loads
2324
from struct import pack, unpack
2425
from typing import Dict, Union, Optional, Set, Callable
2526

26-
from fastavro import (parse_schema,
27-
schemaless_reader,
27+
from fastavro import (schemaless_reader,
2828
schemaless_writer,
29+
repository,
2930
validate)
31+
from fastavro.schema import load_schema
3032

3133
from . import (_MAGIC_BYTE,
3234
Schema,
@@ -104,7 +106,8 @@ def _resolve_named_schema(
104106
for ref in schema.references:
105107
referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True)
106108
ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client)
107-
parsed_schema = parse_schema(loads(referenced_schema.schema.schema_str), named_schemas=ref_named_schemas)
109+
parsed_schema = parse_schema_with_repo(
110+
referenced_schema.schema.schema_str, named_schemas=ref_named_schemas)
108111
named_schemas.update(ref_named_schemas)
109112
named_schemas[ref.name] = parsed_schema
110113
return named_schemas
@@ -378,8 +381,8 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema:
378381

379382
named_schemas = _resolve_named_schema(schema, self._registry)
380383
prepared_schema = _schema_loads(schema.schema_str)
381-
parsed_schema = parse_schema(
382-
loads(prepared_schema.schema_str), named_schemas=named_schemas, expand=True)
384+
parsed_schema = parse_schema_with_repo(
385+
prepared_schema.schema_str, named_schemas=named_schemas)
383386

384387
self._parsed_schemas.set(schema, parsed_schema)
385388
return parsed_schema
@@ -606,13 +609,28 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema:
606609

607610
named_schemas = _resolve_named_schema(schema, self._registry)
608611
prepared_schema = _schema_loads(schema.schema_str)
609-
parsed_schema = parse_schema(
610-
loads(prepared_schema.schema_str), named_schemas=named_schemas, expand=True)
612+
parsed_schema = parse_schema_with_repo(
613+
prepared_schema.schema_str, named_schemas=named_schemas)
611614

612615
self._parsed_schemas.set(schema, parsed_schema)
613616
return parsed_schema
614617

615618

619+
class LocalSchemaRepository(repository.AbstractSchemaRepository):
620+
def __init__(self, schemas):
621+
self.schemas = schemas
622+
623+
def load(self, subject):
624+
return self.schemas.get(subject)
625+
626+
627+
def parse_schema_with_repo(schema_str: str, named_schemas: Dict[str, AvroSchema]) -> AvroSchema:
628+
copy = deepcopy(named_schemas)
629+
copy["$root"] = loads(schema_str)
630+
repo = LocalSchemaRepository(copy)
631+
return load_schema("$root", repo=repo)
632+
633+
616634
def transform(
617635
ctx: RuleContext, schema: AvroSchema, message: AvroMessage,
618636
field_transform: FieldTransform

tests/schema_registry/test_avro_serdes.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,37 @@ def test_avro_serialize_references():
216216
assert obj == obj2
217217

218218

219+
def test_avro_serialize_union():
220+
conf = {'url': _BASE_URL}
221+
client = SchemaRegistryClient.new_client(conf)
222+
ser_conf = {'auto.register.schemas': False, 'use.latest.version': True}
223+
224+
obj = {
225+
'First': {'stringField': 'hi'},
226+
'Second': {'stringField': 'hi'},
227+
}
228+
schema = ['null', {
229+
'type': 'record',
230+
'name': 'A',
231+
'namespace': 'test',
232+
'fields': [
233+
{'name': 'First', 'type': {'type': 'record', 'name': 'B', 'fields': [
234+
{'name': 'stringField', 'type': 'string'},
235+
]}},
236+
{'name': 'Second', 'type': 'B'}
237+
]
238+
}]
239+
client.register_schema(_SUBJECT, Schema(json.dumps(schema), 'AVRO'))
240+
241+
ser = AvroSerializer(client, schema_str=None, conf=ser_conf)
242+
ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
243+
obj_bytes = ser(obj, ser_ctx)
244+
245+
deser = AvroDeserializer(client)
246+
obj2 = deser(obj_bytes, ser_ctx)
247+
assert obj == obj2
248+
249+
219250
def test_avro_serialize_union_with_references():
220251
conf = {'url': _BASE_URL}
221252
client = SchemaRegistryClient.new_client(conf)

0 commit comments

Comments
 (0)