Skip to content

Fixed fastavro not being used if schemas aren't manually specified… #601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions confluent_kafka/avro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from confluent_kafka import Producer, Consumer
from confluent_kafka.avro.error import ClientError
from confluent_kafka.avro.load import load, loads # noqa
from confluent_kafka.avro.load import load, loads, loads_fast # noqa
from confluent_kafka.avro.cached_schema_registry_client import CachedSchemaRegistryClient
from confluent_kafka.avro.serializer import (SerializerError, # noqa
KeySerializerError,
Expand Down Expand Up @@ -101,10 +101,12 @@ class AvroConsumer(Consumer):
and the standard Kafka client configuration (``bootstrap.servers`` et al.)
:param schema reader_key_schema: a reader schema for the message key
:param schema reader_value_schema: a reader schema for the message value
:param bool decode_key: optionally disable key decoding (i.e. only decode values)
:raises ValueError: For invalid configurations
"""

def __init__(self, config, schema_registry=None, reader_key_schema=None, reader_value_schema=None):
def __init__(self, config, schema_registry=None, reader_key_schema=None,
reader_value_schema=None, decode_key=True):

sr_conf = {key.replace("schema.registry.", ""): value
for key, value in config.items() if key.startswith("schema.registry")}
Expand All @@ -125,6 +127,8 @@ def __init__(self, config, schema_registry=None, reader_key_schema=None, reader_
super(AvroConsumer, self).__init__(ap_conf)
self._serializer = MessageSerializer(schema_registry, reader_key_schema, reader_value_schema)

self._decode_key = decode_key

def poll(self, timeout=None):
"""
This is an overridden method from the confluent_kafka.Consumer class. It handles message
Expand All @@ -145,7 +149,7 @@ def poll(self, timeout=None):
if message.value() is not None:
decoded_value = self._serializer.decode_message(message.value(), is_key=False)
message.set_value(decoded_value)
if message.key() is not None:
if self._decode_key and message.key() is not None:
decoded_key = self._serializer.decode_message(message.key(), is_key=True)
message.set_key(decoded_key)
except SerializerError as e:
Expand Down
39 changes: 30 additions & 9 deletions confluent_kafka/avro/cached_schema_registry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from requests import Session, utils

from .error import ClientError
from . import loads
from . import loads, loads_fast

# Python 2 considers int an instance of str
try:
Expand Down Expand Up @@ -97,6 +97,8 @@ def __init__(self, url, max_schemas_per_subject=1000, ca_location=None, cert_loc
self.subject_to_schema_ids = defaultdict(dict)
# id => avro_schema
self.id_to_schema = defaultdict(dict)
# id => fastavro schema
self._id_to_fast_schema = dict()
# subj => { schema => version }
self.subject_to_schema_versions = defaultdict(dict)

Expand Down Expand Up @@ -175,13 +177,18 @@ def _add_to_cache(cache, subject, schema, value):
sub_cache = cache[subject]
sub_cache[schema] = value

def _cache_schema(self, schema, schema_id, subject=None, version=None):
def _cache_schema(self, schema, schema_id, subject=None, version=None, schema_str=None):
# don't overwrite anything
if schema_id in self.id_to_schema:
schema = self.id_to_schema[schema_id]
else:
self.id_to_schema[schema_id] = schema

if schema_str is not None and schema_id not in self._id_to_fast_schema:
fast_schema = loads_fast(schema_str)
if fast_schema is not None:
self._id_to_fast_schema[schema_id] = fast_schema

if subject:
self._add_to_cache(self.subject_to_schema_ids,
subject, schema, schema_id)
Expand Down Expand Up @@ -213,7 +220,8 @@ def register(self, subject, avro_schema):
url = '/'.join([self.url, 'subjects', subject, 'versions'])
# body is { schema : json_string }

body = {'schema': json.dumps(avro_schema.to_json())}
schema_str = json.dumps(avro_schema.to_json())
body = {'schema': schema_str}
result, code = self._send_request(url, method='POST', body=body)
if (code == 401 or code == 403):
raise ClientError("Unauthorized access. Error code:" + str(code))
Expand All @@ -226,7 +234,7 @@ def register(self, subject, avro_schema):
# result is a dict
schema_id = result['id']
# cache it
self._cache_schema(avro_schema, schema_id, subject)
self._cache_schema(avro_schema, schema_id, subject, schema_str=schema_str)
return schema_id

def delete_subject(self, subject):
Expand Down Expand Up @@ -272,12 +280,22 @@ def get_by_id(self, schema_id):
try:
result = loads(schema_str)
# cache it
self._cache_schema(result, schema_id)
self._cache_schema(result, schema_id, schema_str=schema_str)
return result
except ClientError as e:
# bad schema - should not happen
raise ClientError("Received bad schema (id %s) from registry: %s" % (schema_id, e))

def get_by_id_fast(self, schema_id):
    """
    Retrieves a fastavro-parsed schema by id, if possible

    :param int schema_id: int value
    :returns: A fastavro schema, or None when no fast schema is cached
    """
    # Fetch (and cache) the regular Avro schema first; an unknown id
    # yields None, in which case no fast schema can exist either.
    if self.get_by_id(schema_id) is None:
        return None
    # Fall back to None when no fastavro-parsed schema was cached for this id.
    return self._id_to_fast_schema.get(schema_id)

def get_latest_schema(self, subject):
"""
GET /subjects/(string: subject)/versions/(versionId: version)
Expand Down Expand Up @@ -306,16 +324,18 @@ def get_latest_schema(self, subject):
return (None, None, None)
schema_id = result['id']
version = result['version']
schema_str = None
if schema_id in self.id_to_schema:
schema = self.id_to_schema[schema_id]
else:
try:
schema = loads(result['schema'])
schema_str = result['schema']
schema = loads(schema_str)
except ClientError:
# bad schema - should not happen
raise

self._cache_schema(schema, schema_id, subject, version)
self._cache_schema(schema, schema_id, subject, version, schema_str)
return (schema_id, schema, version)

def get_version(self, subject, avro_schema):
Expand All @@ -336,7 +356,8 @@ def get_version(self, subject, avro_schema):
return version

url = '/'.join([self.url, 'subjects', subject])
body = {'schema': json.dumps(avro_schema.to_json())}
schema_str = json.dumps(avro_schema.to_json())
body = {'schema': schema_str}

result, code = self._send_request(url, method='POST', body=body)
if code == 404:
Expand All @@ -347,7 +368,7 @@ def get_version(self, subject, avro_schema):
return None
schema_id = result['id']
version = result['version']
self._cache_schema(avro_schema, schema_id, subject, version)
self._cache_schema(avro_schema, schema_id, subject, version, schema_str)
return version

def test_compatibility(self, subject, avro_schema, version='latest'):
Expand Down
18 changes: 18 additions & 0 deletions confluent_kafka/avro/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,21 @@ def _hash_func(self):

except ImportError:
schema = None


HAS_FAST = False
try:
    # fastavro is an optional dependency; when it is not installed we
    # silently fall back and loads_fast() simply returns None.
    import fastavro
    import json
    HAS_FAST = True
except ImportError:
    pass


def loads_fast(schema_str):
    """
    Parse an Avro schema JSON string with fastavro, if available.

    :param str schema_str: Avro schema as a JSON-encoded string
    :returns: the fastavro-parsed schema, or None when fastavro is not
        installed or ``schema_str`` is not a valid Avro schema
    """
    if not HAS_FAST:
        return None
    try:
        return fastavro.parse_schema(json.loads(schema_str))
    except Exception:
        # Deliberately best-effort: any parse failure (invalid JSON or an
        # unparseable Avro schema) means callers fall back to the slow path.
        return None
10 changes: 7 additions & 3 deletions confluent_kafka/avro/serializer/message_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,13 @@ def _get_decoder_func(self, schema_id, payload, is_key=False):
if HAS_FAST:
# try to use fast avro
try:
writer_schema = writer_schema_obj.to_json()
reader_schema = reader_schema_obj.to_json()
schemaless_reader(payload, writer_schema)
writer_schema = self.registry_client.get_by_id_fast(schema_id)
if writer_schema is None:
writer_schema = writer_schema_obj.to_json()
reader_schema = None
if reader_schema_obj is not None:
reader_schema = reader_schema_obj.to_json()
schemaless_reader(payload, writer_schema, reader_schema)

# If we reach this point, this means we have fastavro and it can
# do this deserialization. Rewind since this method just determines
Expand Down