Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 73 additions & 61 deletions smoketests/tests/replication.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
from .. import COMPOSE_FILE, Smoketest, requires_docker, spacetime
from ..docker import DockerManager

import re
import time
from typing import Callable
import unittest

def get_int(text):
    """Return the first run of decimal digits found in ``text`` as an int.

    Raises:
        Exception: If ``text`` contains no digits at all.
    """
    match = re.search(r'\d+', text)
    if match is not None:
        return int(match.group())
    raise Exception("no numbers found in string")

def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
"""Retry a function on failure with delay."""
for attempt in range(1, max_retries + 1):
Expand All @@ -25,6 +18,21 @@ def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
print("Max retries reached. Skipping the exception.")
return False

def parse_sql_result(res: str) -> list[dict]:
    """Parse tabular output from an SQL query into a list of dicts.

    The first line is read as the ``|``-separated header. The second line
    is skipped (assumed to be the rule between header and body). Every
    following non-blank line is split on ``|`` and zipped with the headers.

    Args:
        res: Raw text of the query result.

    Returns:
        One dict per data row, mapping stripped header names to stripped
        cell strings. Empty when ``res`` has no lines or no data rows.
    """
    lines = res.splitlines()
    if not lines:
        # Nothing was printed at all; there is no header to index.
        return []
    # ``split`` already yields a single-element list when there is no '|'.
    headers = [header.strip() for header in lines[0].split('|')]
    rows = []
    for line in lines[2:]:
        if not line.strip():
            # A blank line would otherwise become a bogus {header: ''} row.
            continue
        cols = [col.strip() for col in line.split('|')]
        rows.append(dict(zip(headers, cols)))
    return rows

def int_vals(rows: list[dict]) -> list[dict]:
    """Return new dicts mirroring ``rows`` with every value passed through ``int``."""
    converted = []
    for row in rows:
        converted.append({key: int(value) for key, value in row.items()})
    return converted

class Cluster:
"""Manages leader-related operations and state for SpaceTime database cluster."""

Expand All @@ -35,56 +43,47 @@ def __init__(self, docker_manager, smoketest: Smoketest):
# Ensure all containers are up.
self.docker.compose("up", "-d")

def read_controldb(self, sql):
    """Run ``sql`` against the ``spacetime-control`` database and return the raw result."""
    output = self.test.spacetime("sql", "spacetime-control", sql)
    return output
def sql(self, sql: str) -> list[dict]:
    """Run ``sql`` against the test database and return the parsed rows."""
    return parse_sql_result(str(self.test.sql(sql)))

def read_controldb(self, sql: str) -> list[dict]:
    """Run ``sql`` against the ``spacetime-control`` database and return the parsed rows."""
    raw = self.test.spacetime("sql", "spacetime-control", sql)
    text = str(raw)
    return parse_sql_result(text)

def get_db_id(self):
"""Query database ID."""
sql = f"select id from database where database_identity=0x{self.test.database_identity}"
db_id_tb = self.read_controldb(sql)
return get_int(db_id_tb)

res = self.read_controldb(sql)
return int(res[0]['id'])

def get_all_replicas(self):
    """Get all replica nodes in the cluster.

    Returns:
        One dict per row of ``replica`` belonging to the test database,
        keyed by the selected columns ``id`` and ``node_id`` with int values.
    """
    database_id = self.get_db_id()
    sql = f"select id, node_id from replica where database_id={database_id}"
    # read_controldb already yields one dict per row; only the cast is left.
    return int_vals(self.read_controldb(sql))

def get_leader_info(self):
"""Get current leader's node information including ID, hostname, and container ID."""

database_id = self.get_db_id()
# Query leader replica ID
sql = f"select leader from replication_state where database_id={database_id}"
leader_tb = self.read_controldb(sql)
leader_id = get_int(leader_tb)

# Query leader node ID
sql = f"select node_id from replica where id={leader_id}"
leader_node_tb = self.read_controldb(sql)
leader_node_id = get_int(leader_node_tb)

# Query leader hostname
sql = f"select network_addr from node_v2 where id={leader_node_id}"
leader_host_tb = str(self.read_controldb(sql))
lines = leader_host_tb.splitlines()
sql = f""" \
select node_v2.id, node_v2.network_addr from node_v2 \
join replica on replica.node_id=node_v2.id \
join replication_state on replication_state.leader=replica.id \
where replication_state.database_id={database_id} \
"""
rows = self.read_controldb(sql)
if not rows:
raise Exception("Could not find current leader's node")

leader_node_id = int(rows[0]['id'])
hostname = ""
if len(lines) == 3: # actual row starts from 3rd line
leader_row = lines[2]
if "(some =" in leader_row:
address = leader_row.split('"')[1]
hostname = address.split(':')[0]
if "(some =" in rows[0]['network_addr']:
address = rows[0]['network_addr'].split('"')[1]
hostname = address.split(':')[0]

# Find container ID
container_id = ""
Expand Down Expand Up @@ -114,15 +113,16 @@ def wait_for_leader_change(self, previous_leader_node, max_attempts=10, delay=2)
time.sleep(delay)
return None

def ensure_leader_health(self, id, wait_time=0.6):
    """Verify leader is healthy by inserting a row.

    Calls the ``start`` reducer with ``id`` (retrying on failure), then
    checks that a row with exactly that id is present in ``counter``.

    Args:
        id: Row id to insert and look up.
        wait_time: Seconds to sleep after the row was found; a falsy value
            skips the sleep.

    Raises:
        ValueError: If the row is not present in ``counter``.
    """
    retry(lambda: self.test.call("start", id, 1))
    rows = self.sql(f"select id from counter where id={id}")
    # Compare the parsed value: a substring test would accept e.g. 10 for 1.
    if not rows or int(rows[0]['id']) != id:
        raise ValueError(f"Could not find {id} in counter table")
    # Wait for at least one tick to ensure buffers are flushed.
    # TODO: Replace with confirmed read.
    if wait_time:
        time.sleep(wait_time)


def fail_leader(self, action='kill'):
Expand Down Expand Up @@ -247,31 +247,42 @@ def start(self, id: int, count: int):
"""Send a message to the database."""
retry(lambda: self.call("start", id, count))

def collect_counter_rows(self):
    """Fetch every row of the ``counter`` table with all values cast to int."""
    rows = self.cluster.sql("select * from counter")
    return int_vals(rows)


class LeaderElection(ReplicationTest):
    def test_leader_election_in_loop(self):
        """This test fails a leader, wait for new leader to be elected and verify if commits replicated to new leader"""
        iterations = 5
        # Two row ids per iteration: one written before the failure, one after.
        row_ids = [101 + i for i in range(iterations * 2)]
        for (first_id, second_id) in zip(row_ids[::2], row_ids[1::2]):
            cur_leader = self.cluster.wait_for_leader_change(None)
            print(f"ensure leader health {first_id}")
            self.cluster.ensure_leader_health(first_id)

            print(f"killing current leader: {cur_leader}")
            container_id = self.cluster.fail_leader()

            self.assertIsNotNone(container_id)

            next_leader = self.cluster.wait_for_leader_change(cur_leader)
            self.assertNotEqual(cur_leader, next_leader)
            # this check if leader election happened
            print(f"ensure_leader_health {second_id}")
            self.cluster.ensure_leader_health(second_id)
            # restart the old leader, so that we can maintain quorum for next iteration
            print(f"reconnect leader {container_id}")
            self.cluster.restore_leader(container_id, 'start')

        # Ensure we have a current leader
        last_row_id = row_ids[-1] + 1
        self.cluster.ensure_leader_health(last_row_id)
        row_ids.append(last_row_id)

        # Verify that all inserted rows are present. Comparing sets of parsed
        # ids is exact, unlike a substring match on the printed table.
        stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
        self.assertEqual(set(stored_row_ids), set(row_ids))

class LeaderDisconnect(ReplicationTest):
def test_leader_c_disconnect_in_loop(self):
Expand Down Expand Up @@ -300,12 +311,15 @@ def test_leader_c_disconnect_in_loop(self):
# restart the old leader, so that we can maintain quorum for next iteration
print(f"reconnect leader {container_id}")
self.cluster.restore_leader(container_id, 'connect')
time.sleep(1)

# verify if all past rows are present in new leader
for row_id in row_ids:
table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}")
self.assertIn(f"{row_id}", str(table))
# Ensure we have a current leader
last_row_id = row_ids[-1] + 1
self.cluster.ensure_leader_health(last_row_id)
row_ids.append(last_row_id)

# Verify that all inserted rows are present
stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
self.assertEqual(set(stored_row_ids), set(row_ids))


@unittest.skip("drain_node not yet supported")
Expand Down Expand Up @@ -342,18 +356,16 @@ def test_prefer_leader(self):
if replica['node_id'] != cur_leader_node_id:
prefer_replica = replica
break
prefer_replica_id = prefer_replica['replica_id']
prefer_replica_id = prefer_replica['id']
self.spacetime("call", "spacetime-control", "prefer_leader", f"{prefer_replica_id}")

next_leader_node_id = self.cluster.wait_for_leader_change(cur_leader_node_id)
self.cluster.ensure_leader_health(402)
self.assertEqual(prefer_replica['node_id'], next_leader_node_id)


# verify if all past rows are present in new leader
for row_id in [401, 402]:
table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}")
self.assertIn(f"{row_id}", str(table))
stored_row_ids = [row['id'] for row in self.collect_counter_rows()]
self.assertEqual(set(stored_row_ids), set([401, 402]))


class ManyTransactions(ReplicationTest):
Expand Down
Loading