Skip to content

Add ability to set subject name strategies #764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
44 changes: 43 additions & 1 deletion confluent_kafka/avro/cached_schema_registry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
import warnings
from collections import defaultdict

from avro.schema import RecordSchema
from requests import Session, utils

from .error import ClientError
from .error import ClientError, SubjectNameStrategyError
from . import loads

# Python 2 considers int an instance of str
Expand All @@ -39,6 +40,31 @@
VALID_METHODS = ['GET', 'POST', 'PUT', 'DELETE']
VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO', 'SASL_INHERIT']


def topic_name_strategy(topic, schema, is_key):
return topic + ("-key" if is_key else "-value")


def record_name_strategy(topic, schema, is_key):
if isinstance(schema, RecordSchema) and schema.fullname:
return schema.fullname

if is_key:
raise SubjectNameStrategyError("the message key must have a name")
else:
raise SubjectNameStrategyError("the message value must have a name")


def topic_record_name_strategy(topic, schema, is_key):
return topic + "-" + record_name_strategy(topic, schema, is_key)


SUBJECT_NAME_STRATEGIES = {
'TopicNameStrategy': topic_name_strategy,
'RecordNameStrategy': record_name_strategy,
'TopicRecordNameStrategy': topic_record_name_strategy,
}

# Common accept header sent
ACCEPT_HDR = "application/vnd.schemaregistry.v1+json, application/vnd.schemaregistry+json, application/json"
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -109,6 +135,22 @@ def __init__(self, url, max_schemas_per_subject=1000, ca_location=None, cert_loc
s.auth = self._configure_basic_auth(self.url, conf)
self.url = utils.urldefragauth(self.url)

key_subject_name_strategy = conf.pop(
'key.subject.name.strategy',
'TopicNameStrategy'
)
if key_subject_name_strategy not in SUBJECT_NAME_STRATEGIES:
raise ValueError("Invalid Key Subject Name Strategy")
self.key_subject_name_strategy_func = SUBJECT_NAME_STRATEGIES[key_subject_name_strategy] # noqa

value_subject_name_strategy = conf.pop(
'value.subject.name.strategy',
'TopicNameStrategy'
)
if value_subject_name_strategy not in SUBJECT_NAME_STRATEGIES:
raise ValueError("Invalid Value Subject Name Strategy")
self.value_subject_name_strategy_func = SUBJECT_NAME_STRATEGIES[value_subject_name_strategy] # noqa

self._session = s

self.auto_register_schemas = conf.pop("auto.register.schemas", True)
Expand Down
11 changes: 11 additions & 0 deletions confluent_kafka/avro/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,14 @@ def __repr__(self):

def __str__(self):
return self.message


class SubjectNameStrategyError(Exception):
def __init__(self, message):
self.message = message

def __repr__(self):
return "SubjectNameStrategyError(error={error})".format(error=self.message) # noqa

def __str__(self):
return self.message
14 changes: 11 additions & 3 deletions confluent_kafka/avro/serializer/message_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def _get_encoder_func(self, writer_schema):
writer = avro.io.DatumWriter(writer_schema)
return lambda record, fp: writer.write(record, avro.io.BinaryEncoder(fp))

def _get_subject(self, topic, schema, is_key):
if is_key:
subject = self.registry_client.key_subject_name_strategy_func(topic, schema, is_key) # noqa
else:
subject = self.registry_client.value_subject_name_strategy_func(topic, schema, is_key) # noqa
return subject

def encode_record_with_schema(self, topic, schema, record, is_key=False):
"""
Given a parsed avro schema, encode a record for the given topic. The
Expand All @@ -100,14 +107,15 @@ def encode_record_with_schema(self, topic, schema, record, is_key=False):
"""
serialize_err = KeySerializerError if is_key else ValueSerializerError

subject_suffix = ('-key' if is_key else '-value')
# get the latest schema for the subject
subject = topic + subject_suffix
subject = self._get_subject(topic, schema, is_key)

if self.registry_client.auto_register_schemas:
# register it
schema_id = self.registry_client.register(subject, schema)
else:
# get the latest schema for the subject
schema_id = self.registry_client.check_registration(subject, schema)

if not schema_id:
message = "Unable to retrieve schema id for subject %s" % (subject)
raise serialize_err(message)
Expand Down
3 changes: 3 additions & 0 deletions tests/avro/mock_schema_registry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#

from confluent_kafka.avro import ClientError
from confluent_kafka.avro.cached_schema_registry_client import topic_name_strategy


class MockSchemaRegistryClient(object):
Expand All @@ -45,6 +46,8 @@ def __init__(self, max_schemas_per_subject=1000):
self.next_id = 1
self.schema_to_id = {}

self.key_subject_name_strategy_func = topic_name_strategy
self.value_subject_name_strategy_func = topic_name_strategy
self.auto_register_schemas = True

def _get_next_id(self, schema):
Expand Down
90 changes: 89 additions & 1 deletion tests/avro/test_cached_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@

from tests.avro import mock_registry
from tests.avro import data_gen
from confluent_kafka.avro.cached_schema_registry_client import CachedSchemaRegistryClient
from confluent_kafka.avro.cached_schema_registry_client import (
CachedSchemaRegistryClient,

topic_name_strategy,
record_name_strategy,
topic_record_name_strategy,
)
from confluent_kafka import avro


Expand Down Expand Up @@ -233,3 +239,85 @@ def test_invalid_conf(self):
'invalid.conf': 1,
'invalid.conf2': 2
})

def test_default_key_subject_name_strategy(self):
client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
})

expected = topic_name_strategy

self.assertEqual(expected, client.key_subject_name_strategy_func)

def test_invalid_key_subject_name_strategy(self):
with self.assertRaises(ValueError):
CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'key.subject.name.strategy': "InvalidNameStrategy",
})

def test_key_subject_name_strategies(self):
client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'key.subject.name.strategy': "TopicNameStrategy",
})

expected = topic_name_strategy
self.assertEqual(expected, client.key_subject_name_strategy_func)

client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'key.subject.name.strategy': "RecordNameStrategy",
})

expected = record_name_strategy
self.assertEqual(expected, client.key_subject_name_strategy_func)

client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'key.subject.name.strategy': "TopicRecordNameStrategy",
})

expected = topic_record_name_strategy
self.assertEqual(expected, client.key_subject_name_strategy_func)

def test_default_value_subject_name_strategy(self):
client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
})

expected = topic_name_strategy

self.assertEqual(expected, client.value_subject_name_strategy_func)

def test_invalid_value_subject_name_strategy(self):
with self.assertRaises(ValueError):
CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'value.subject.name.strategy': "InvalidNameStrategy",
})

def test_value_subject_name_strategies(self):
client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'value.subject.name.strategy': "TopicNameStrategy",
})

expected = topic_name_strategy
self.assertEqual(expected, client.value_subject_name_strategy_func)

client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'value.subject.name.strategy': "RecordNameStrategy",
})

expected = record_name_strategy
self.assertEqual(expected, client.value_subject_name_strategy_func)

client = CachedSchemaRegistryClient({
'url': 'https://user_url:secret_url@127.0.0.1:65534',
'value.subject.name.strategy': "TopicRecordNameStrategy",
})

expected = topic_record_name_strategy
self.assertEqual(expected, client.value_subject_name_strategy_func)
59 changes: 59 additions & 0 deletions tests/avro/test_message_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from confluent_kafka.avro.serializer.message_serializer import MessageSerializer
from tests.avro.mock_schema_registry_client import MockSchemaRegistryClient
from confluent_kafka import avro
from confluent_kafka.avro.cached_schema_registry_client import (
topic_name_strategy,
record_name_strategy,
topic_record_name_strategy,
)


class TestMessageSerializer(unittest.TestCase):
Expand Down Expand Up @@ -80,5 +85,59 @@ def test_decode_none(self):

self.assertIsNone(self.ms.decode_message(None))

def test__get_subject_for_key_with_topic_name_strategy(self):
basic = avro.loads(data_gen.BASIC_SCHEMA)
topic = "topic"
self.ms.registry_client.key_subject_name_strategy_func = topic_name_strategy # noqa
subject = self.ms._get_subject(topic=topic, schema=basic, is_key=True)

expected = "topic-key"
self.assertEqual(expected, subject)

def test__get_subject_for_key_with_record_name_strategy(self):
basic = avro.loads(data_gen.BASIC_SCHEMA)
topic = "topic"
self.ms.registry_client.key_subject_name_strategy_func = record_name_strategy # noqa
subject = self.ms._get_subject(topic=topic, schema=basic, is_key=True)

expected = "python.test.basic.basic"
self.assertEqual(expected, subject)

def test__get_subject_for_key_with_topic_record_name_strategy(self):
basic = avro.loads(data_gen.BASIC_SCHEMA)
topic = "topic"
self.ms.registry_client.key_subject_name_strategy_func = topic_record_name_strategy # noqa
subject = self.ms._get_subject(topic=topic, schema=basic, is_key=True)

expected = "topic-python.test.basic.basic"
self.assertEqual(expected, subject)

def test__get_subject_for_value_with_topic_name_strategy(self):
basic = avro.loads(data_gen.BASIC_SCHEMA)
topic = "topic"
self.ms.registry_client.value_subject_name_strategy_func = topic_name_strategy # noqa
subject = self.ms._get_subject(topic=topic, schema=basic, is_key=False)

expected = "topic-value"
self.assertEqual(expected, subject)

def test__get_subject_for_value_with_record_name_strategy(self):
basic = avro.loads(data_gen.BASIC_SCHEMA)
topic = "topic"
self.ms.registry_client.value_subject_name_strategy_func = record_name_strategy # noqa
subject = self.ms._get_subject(topic=topic, schema=basic, is_key=False)

expected = "python.test.basic.basic"
self.assertEqual(expected, subject)

def test__get_subject_for_value_with_topic_record_name_strategy(self):
basic = avro.loads(data_gen.BASIC_SCHEMA)
topic = "topic"
self.ms.registry_client.value_subject_name_strategy_func = topic_record_name_strategy # noqa
subject = self.ms._get_subject(topic=topic, schema=basic, is_key=False)

expected = "topic-python.test.basic.basic"
self.assertEqual(expected, subject)

def hash_func(self):
return hash(str(self))
46 changes: 46 additions & 0 deletions tests/avro/test_subject_name_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os.path
import unittest

from confluent_kafka import avro
from confluent_kafka.avro.cached_schema_registry_client import (
topic_name_strategy,
record_name_strategy,
topic_record_name_strategy,
)
from confluent_kafka.avro.error import SubjectNameStrategyError

from tests.avro import data_gen

avsc_dir = os.path.dirname(os.path.realpath(__file__))


class TestSubjectNameStrategies(unittest.TestCase):
def setUp(self):
self.schema = avro.loads(data_gen.ADVANCED_SCHEMA)

def test_topic_name_strategy(self):
subject = topic_name_strategy("topic", self.schema, False)
expected = "topic-value"

self.assertEqual(expected, subject)

def test_record_name_strategy(self):
subject = record_name_strategy("topic", self.schema, False)
expected = self.schema.fullname

self.assertEqual(expected, subject)

def test_topic_record_name_strategy(self):
subject = topic_record_name_strategy("topic", self.schema, False)
expected = "topic-%s" % self.schema.fullname

self.assertEqual(expected, subject)

def test_should_raise_exception_for_schema_without_name(self):
schema = avro.load(os.path.join(avsc_dir, "primitive_string.avsc"))

with self.assertRaises(SubjectNameStrategyError, msg="the message key must have a name"): # noqa
record_name_strategy("topic", schema, is_key=True)

with self.assertRaises(SubjectNameStrategyError, msg="the message value must have a name"): # noqa
record_name_strategy("topic", schema, is_key=False)