replay log on initialize #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into master
2 changes: 1 addition & 1 deletion tests/processor/test_base_processor.py
@@ -18,7 +18,7 @@ def test_initialiseBaseProcessor():
     mock_task = mock.Mock()
     mock_task.application_id = 'test_id'
     mock_task_id = TaskId('test_group', 0)
-    mock_context = ProcessorContext(mock_task_id, mock_task, None, None, {})
+    mock_context = ProcessorContext(mock_task_id, mock_task, None, {}, None)
     bp = wks_processor.BaseProcessor()
     bp.initialise('my-name', mock_context)
2 changes: 1 addition & 1 deletion tests/processor/test_sink_processor.py
@@ -26,7 +26,7 @@ def test_sinkProcessorProcess():
     mock_task = mock.Mock()
     mock_task.application_id = 'test_id'
     mock_task_id = TaskId('test_group', 0)
-    processor_context = wks_processor.ProcessorContext(mock_task_id, mock_task, None, None, {})
+    processor_context = wks_processor.ProcessorContext(mock_task_id, mock_task, None, {}, None)
     processor_context.record_collector = mock.MagicMock()

     sink = wks_processor.SinkProcessor('topic1')
100 changes: 100 additions & 0 deletions tests/state/test_change_logging.py
@@ -0,0 +1,100 @@
from collections import deque
from typing import Iterator, Tuple

import pytest

from winton_kafka_streams.processor.serialization.serdes import IntegerSerde, StringSerde
from winton_kafka_streams.state.in_memory.in_memory_state_store import InMemoryStateStore
from winton_kafka_streams.state.logging.change_logging_state_store import ChangeLoggingStateStore
from winton_kafka_streams.state.logging.store_change_logger import StoreChangeLogger


class MockChangeLogger(StoreChangeLogger):
    def __init__(self):
        super(MockChangeLogger, self).__init__()
        self.change_log = deque()

    def log_change(self, key: bytes, value: bytes) -> None:
        self.change_log.append((key, value))

    def __iter__(self) -> Iterator[Tuple[bytes, bytes]]:
        return self.change_log.__iter__()


def _get_store():
    inner_store = InMemoryStateStore('teststore', StringSerde(), IntegerSerde(), False)
    store = ChangeLoggingStateStore('teststore', StringSerde(), IntegerSerde(), False, inner_store)
    store._get_change_logger = lambda context: MockChangeLogger()
    store.initialize(None, None)
    return store


def test_change_store_is_dict():
    store = _get_store()
    kv_store = store.get_key_value_store()

    kv_store['a'] = 1
    assert kv_store['a'] == 1

    kv_store['a'] = 2
    assert kv_store['a'] == 2

    del kv_store['a']
    assert kv_store.get('a') is None
    with pytest.raises(KeyError):
        _ = kv_store['a']


def test_change_log_is_written_to():
    store = _get_store()
    kv_store = store.get_key_value_store()

    kv_store['a'] = 12
    assert len(store.change_logger.change_log) == 1
    assert store.change_logger.change_log[0] == (b'a', b'\x0c\0\0\0')

    del kv_store['a']
    assert len(store.change_logger.change_log) == 2
    assert store.change_logger.change_log[1] == (b'a', b'')


def test_can_replay_log():
    store = _get_store()
    kv_store = store.get_key_value_store()

    kv_store['a'] = 12
    kv_store['b'] = 123
    del kv_store['a']

    keys = []
    values = []

    for k, v in store.change_logger:
        keys.append(k)
        values.append(v)

    assert keys == [b'a', b'b', b'a']
    assert values == [b'\x0c\0\0\0', b'\x7b\0\0\0', b'']


Contributor:

Am I missing something? Where exactly does it replay into the actual store from the changelog?

Contributor:

ChangeLoggingStateStore.initialize?

Contributor:

Ah. Hooking that up with the stream task is another PR?

Contributor (Author):

A StreamTask calls StateStore.initialize(context, root) as part of its init(). If the StateStore is a ChangeLoggingStateStore then it will replay the changelog into its inner_state_store. See change_logging_state_store.py:25.
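
In outline, the flow is as below (a self-contained toy sketch, not the library's real classes: StreamTask here is a paraphrased stub and the serdes are elided; only the replay loop mirrors this PR's code in change_logging_state_store.py):

    class InMemoryStore:
        """Stub standing in for InMemoryStateStore; kv is the backing dict."""
        def __init__(self):
            self.kv = {}

        def initialize(self, context, root):
            pass


    class ChangeLoggingStore:
        """Stub standing in for ChangeLoggingStateStore."""
        def __init__(self, inner, change_log):
            self.inner = inner
            self.change_log = change_log  # iterable of (key, value) byte pairs

        def initialize(self, context, root):
            self.inner.initialize(context, root)
            # Replay: an empty value is a tombstone (delete); anything else is an upsert.
            for k, v in self.change_log:
                if v == b'':
                    self.inner.kv.pop(k, None)
                else:
                    self.inner.kv[k] = v


    class StreamTask:
        """Paraphrased stub: the real init() does more, but it does call
        StateStore.initialize(context, root) on every store it owns."""
        def __init__(self, stores):
            self.stores = stores

        def init(self):
            for store in self.stores:
                store.initialize(None, None)


    log = [(b'a', b'\x0c\0\0\0'), (b'b', b'\x7b\0\0\0'), (b'a', b'')]
    inner = InMemoryStore()
    StreamTask([ChangeLoggingStore(inner, log)]).init()
    assert inner.kv == {b'b': b'\x7b\0\0\0'}  # 'a' was tombstoned; 'b' survives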

def test_rebuild_state_from_log():
    store = _get_store()
    kv_store = store.get_key_value_store()

    kv_store['a'] = 12
    kv_store['b'] = 123
    del kv_store['a']
    kv_store['c'] = 1234

    log = store.change_logger

    # reattach previous changelog and run initialize()
    store = _get_store()
    kv_store = store.get_key_value_store()
    store._get_change_logger = lambda context: log
    store.initialize(None, None)

    with pytest.raises(KeyError):
        _ = kv_store['a']
    assert kv_store['b'] == 123
    assert kv_store['c'] == 1234
2 changes: 1 addition & 1 deletion tests/state/test_in_memory_key_value_store.py
@@ -4,7 +4,7 @@
 from winton_kafka_streams.state.in_memory.in_memory_state_store import InMemoryStateStore


-def test_inMemoryKeyValueStore():
+def test_in_memory_key_value_store():
     store = InMemoryStateStore('teststore', BytesSerde(), BytesSerde(), False)
     kv_store = store.get_key_value_store()
3 changes: 1 addition & 2 deletions winton_kafka_streams/processor/_context.py
@@ -28,10 +28,9 @@ class Context:

     """

-    def __init__(self, _state_record_collector, _state_stores):
+    def __init__(self, _state_stores):
         self.current_node = None
         self.current_record = None
-        self.state_record_collector = _state_record_collector
         self._state_stores = _state_stores

     def send(self, topic, key, obj):
5 changes: 1 addition & 4 deletions winton_kafka_streams/processor/_stream_task.py
@@ -4,7 +4,6 @@
 from confluent_kafka import TopicPartition
 from confluent_kafka.cimpl import KafkaException, KafkaError

-from winton_kafka_streams.processor.serialization.serdes import BytesSerde
 from ..errors._kafka_error_codes import _get_invalid_producer_epoch_code
 from ._punctuation_queue import PunctuationQueue
 from ._record_collector import RecordCollector
@@ -69,11 +68,9 @@ def __init__(self, _task_id, _application_id, _partitions, _topology_builder, _c
         self.value_serde.configure(self.config, False)

         self.record_collector = RecordCollector(self.producer, self.key_serde, self.value_serde)
-        self.state_record_collector = RecordCollector(self.producer, BytesSerde(), BytesSerde())

         self.queue = queue.Queue()
-        self.context = ProcessorContext(self.task_id, self, self.record_collector,
-                                        self.state_record_collector, self.state_stores)
+        self.context = ProcessorContext(self.task_id, self, self.record_collector, self.state_stores, self.config)

         self.punctuation_queue = PunctuationQueue(self.punctuate)
         # TODO: use the configured timestamp extractor.
5 changes: 3 additions & 2 deletions winton_kafka_streams/processor/processor_context.py
@@ -16,14 +16,15 @@ class ProcessorContext(_context.Context):
     values to downstream processors.

     """
-    def __init__(self, _task_id, _task, _record_collector, _state_record_collector, _state_stores):
+    def __init__(self, _task_id, _task, _record_collector, _state_stores, _config):

-        super().__init__(_state_record_collector, _state_stores)
+        super().__init__(_state_stores)

         self.application_id = _task.application_id
         self.task_id = _task_id
         self.task = _task
         self.record_collector = _record_collector
+        self.config = _config

     def commit(self):
         """
17 changes: 13 additions & 4 deletions winton_kafka_streams/state/logging/change_logging_state_store.py
@@ -3,7 +3,7 @@
 from winton_kafka_streams.processor.serialization import Serde
 from ..key_value_state_store import KeyValueStateStore
 from ..state_store import StateStore
-from .store_change_logger import StoreChangeLogger
+from .store_change_logger import StoreChangeLogger, StoreChangeLoggerImpl

 KT = TypeVar('KT')  # Key type.
 VT = TypeVar('VT')  # Value type.
@@ -14,12 +14,21 @@ def __init__(self, name: str, key_serde: Serde[KT], value_serde: Serde[VT], log
                  inner_state_store: StateStore[KT, VT]) -> None:
         super().__init__(name, key_serde, value_serde, logging_enabled)
         self.inner_state_store = inner_state_store
-        self.change_logger = None
+        self.change_logger: StoreChangeLogger = None
+
+    def _get_change_logger(self, context) -> StoreChangeLogger:
+        return StoreChangeLoggerImpl(self.inner_state_store.name, context)

     def initialize(self, context, root):
         self.inner_state_store.initialize(context, root)
-        self.change_logger = StoreChangeLogger(self.inner_state_store.name, context)
-        # TODO rebuild state into inner here
+        self.change_logger = self._get_change_logger(context)
+        for k, v in self.change_logger:
+            deserialized_key = self.deserialize_key(k)
+            inner_kv_store = self.inner_state_store.get_key_value_store()
+            if v == b'':
+                del inner_kv_store[deserialized_key]
+            else:
+                inner_kv_store[deserialized_key] = self.deserialize_value(v)

     def get_key_value_store(self) -> KeyValueStateStore[KT, VT]:
         parent = self
42 changes: 40 additions & 2 deletions winton_kafka_streams/state/logging/store_change_logger.py
@@ -1,10 +1,48 @@
-class StoreChangeLogger:
+from abc import abstractmethod
+from typing import Iterator, Iterable, Tuple
+
+from confluent_kafka.cimpl import TopicPartition, OFFSET_BEGINNING, KafkaError
+
+from winton_kafka_streams.processor.serialization.serdes import BytesSerde
+from winton_kafka_streams.kafka_client_supplier import KafkaClientSupplier
+from winton_kafka_streams.processor._record_collector import RecordCollector
+
+
+class StoreChangeLogger(Iterable[Tuple[bytes, bytes]]):
+    @abstractmethod
+    def log_change(self, key: bytes, value: bytes) -> None:
+        pass
+
+    @abstractmethod
+    def __iter__(self) -> Iterator[Tuple[bytes, bytes]]:
+        pass
+
+
+class StoreChangeLoggerImpl(StoreChangeLogger):
     def __init__(self, store_name, context) -> None:
         self.topic = f'{context.application_id}-{store_name}-changelog'
         self.context = context
         self.partition = context.task_id.partition
-        self.record_collector = context.state_record_collector
+        self.client_supplier = KafkaClientSupplier(self.context.config)
+        self.record_collector = RecordCollector(self.client_supplier.producer(), BytesSerde(), BytesSerde())

     def log_change(self, key: bytes, value: bytes) -> None:
         if self.record_collector:
             self.record_collector.send(self.topic, key, value, self.context.timestamp, partition=self.partition)
+
+    def __iter__(self) -> Iterator[Tuple[bytes, bytes]]:
+        consumer = self.client_supplier.consumer()
+        partition = TopicPartition(self.topic, self.partition, OFFSET_BEGINNING)
+        consumer.assign([partition])
+
+        class TopicIterator(Iterator[Tuple[bytes, bytes]]):
+            def __next__(self) -> Tuple[bytes, bytes]:
+                msg = consumer.poll(1.0)
+                if msg is None:
+                    raise StopIteration()
+                if msg.error():
+                    if msg.error().code() == KafkaError._PARTITION_EOF:
+                        raise StopIteration()
+                return msg.key(), msg.value()
+
+        return TopicIterator()