Skip to content

Commit

Permalink
Add session setting support for specs
Browse files Browse the repository at this point in the history
This adds support for session settings for the various specs formats.
These session settings will be applied by the http/pg client before
the spec is executed.
  • Loading branch information
mkleen committed Aug 27, 2024
1 parent 9d07688 commit 09110ba
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 18 deletions.
4 changes: 3 additions & 1 deletion cr8/bench_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def from_dict(d):


class Spec:
def __init__(self, setup, teardown, queries=None, load_data=None, meta=None):
def __init__(self, setup, teardown, queries=None, load_data=None, meta=None, session_settings=None):
self.setup = setup
self.teardown = teardown
self.queries = queries
self.load_data = load_data
self.meta = meta or {}
self.session_settings = session_settings or {}

@staticmethod
def from_dict(d):
Expand All @@ -45,6 +46,7 @@ def from_dict(d):
meta=d.get('meta', {}),
queries=d.get('queries', []),
load_data=d.get('load_data', []),
session_settings=d.get('session_settings', {}),
)

@staticmethod
Expand Down
27 changes: 21 additions & 6 deletions cr8/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,25 @@ def _verify_ssl_from_first(hosts):


class AsyncpgClient:
def __init__(self, hosts, pool_size=25):
def __init__(self, hosts, pool_size=25, session_settings=None):
self.dsn = _to_dsn(hosts)
self.pool_size = pool_size
self._pool = None
self.is_cratedb = True
self.session_settings = session_settings or {}

async def _get_pool(self):

async def set_session_settings(conn):
for setting, value in self.session_settings.items():
await conn.execute(f'set {setting}={value}')

if not self._pool:
self._pool = await asyncpg.create_pool(
self.dsn,
min_size=self.pool_size,
max_size=self.pool_size
max_size=self.pool_size,
setup=set_session_settings
)
return self._pool

Expand Down Expand Up @@ -308,7 +315,7 @@ def _append_sql(host):


class HttpClient:
def __init__(self, hosts, conn_pool_limit=25):
def __init__(self, hosts, conn_pool_limit=25, session_settings=None):
self.hosts = hosts
self.urls = itertools.cycle(list(map(_append_sql, hosts)))
self._connector_params = {
Expand All @@ -317,13 +324,21 @@ def __init__(self, hosts, conn_pool_limit=25):
}
self.__session = None
self.is_cratedb = True
self.session_settings = session_settings or {}

@property
async def _session(self):
session = self.__session
if session is None:
conn = aiohttp.TCPConnector(**self._connector_params)
self.__session = session = aiohttp.ClientSession(connector=conn)
for setting, value in self.session_settings.items():
payload = {'stmt': f'set {setting}={value}'}
await _exec(
session,
next(self.urls),
dumps(payload, cls=CrateJsonEncoder)
)
return session

async def execute(self, stmt, args=None):
Expand Down Expand Up @@ -372,10 +387,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.close()


def client(hosts, concurrency=25):
def client(hosts, session_settings=None, concurrency=25):
hosts = hosts or 'localhost:4200'
if hosts.startswith('asyncpg://'):
if not asyncpg:
raise ValueError('Cannot use "asyncpg" scheme if asyncpg is not available')
return AsyncpgClient(hosts, pool_size=concurrency)
return HttpClient(_to_http_hosts(hosts), conn_pool_limit=concurrency)
return AsyncpgClient(hosts, pool_size=concurrency, session_settings=session_settings)
return HttpClient(_to_http_hosts(hosts), conn_pool_limit=concurrency, session_settings=session_settings)
4 changes: 2 additions & 2 deletions cr8/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def _generate_statements(stmt, args, iterations, duration):


class Runner:
def __init__(self, hosts, concurrency, sample_mode):
def __init__(self, hosts, concurrency, sample_mode, session_settings=None):
self.concurrency = concurrency
self.client = client(hosts, concurrency=concurrency)
self.client = client(hosts, session_settings=session_settings, concurrency=concurrency)
self.sampler = get_sampler(sample_mode)

def warmup(self, stmt, num_warmup, concurrency=0, args=None):
Expand Down
16 changes: 10 additions & 6 deletions cr8/run_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ def __init__(self,
result_hosts,
log,
fail_if,
sample_mode):
sample_mode,
session_settings):
self.benchmark_hosts = benchmark_hosts
self.sample_mode = sample_mode
self.session_settings = session_settings
self.spec_dir = spec_dir
self.client = clients.client(benchmark_hosts)
self.client = clients.client(benchmark_hosts, session_settings)
self.result_client = clients.client(result_hosts)
self.server_version_info = aio.run(self.client.get_server_version)
self.server_version = parse_version(self.server_version_info['number'])
Expand Down Expand Up @@ -204,7 +206,7 @@ def run_queries(self, queries: Iterable[dict], meta=None):
f' Concurrency: {concurrency}\n'
f' {mode_desc}: {duration or iterations}')
)
with Runner(self.benchmark_hosts, concurrency, self.sample_mode) as runner:
with Runner(self.benchmark_hosts, concurrency, self.sample_mode, self.session_settings) as runner:
if warmup > 0:
runner.warmup(stmt, warmup, concurrency, args)
timed_stats = runner.run(
Expand Down Expand Up @@ -242,15 +244,17 @@ def do_run_spec(spec,
action=None,
fail_if=None,
re_name=None):
spec_dir = os.path.dirname(spec)
spec = load_spec(spec)
with Executor(
spec_dir=os.path.dirname(spec),
spec_dir=spec_dir,
benchmark_hosts=benchmark_hosts,
result_hosts=result_hosts,
log=log,
fail_if=fail_if,
sample_mode=sample_mode
sample_mode=sample_mode,
session_settings=spec.session_settings
) as executor:
spec = load_spec(spec)
try:
if not action or 'setup' in action:
log.info('# Running setUp')
Expand Down
4 changes: 4 additions & 0 deletions specs/count_countries.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
}
]
},
"session_settings": {
"application_name": "my_app",
"timezone": "UTC"
},
"queries": [{
"iterations": 1000,
"statement": "select count(*) from countries"
Expand Down
4 changes: 2 additions & 2 deletions specs/sample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from itertools import count
from cr8.bench_spec import Spec, Instructions

Expand All @@ -21,4 +20,5 @@ def queries():
setup=Instructions(statements=["create table t (x int)"]),
teardown=Instructions(statements=["drop table t"]),
queries=queries(),
)
session_settings={'application_name': 'my_app', 'timezone': 'UTC'}
)
4 changes: 4 additions & 0 deletions specs/sample.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ statement_files = ["sql/create_countries.sql"]
target = "countries"
cmd = ['echo', '{"capital": "Demo"}']

[session_settings]
application_name = 'my_app'
timezone = 'UTC'

[[queries]]
name = "count countries" # Can be used to give the queries a name for easier analytics of the results
statement = "select count(*) from countries"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def parse(self, string, name='<string>'):
class SourceBuildTest(TestCase):

def test_build_from_branch(self):
self.assertIsNotNone(get_crate('4.1'))
self.assertIsNotNone(get_crate('5.8'))


def load_tests(loader, tests, ignore):
Expand Down
30 changes: 30 additions & 0 deletions tests/test_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
from unittest import TestCase
from doctest import DocTestSuite

from cr8.bench_spec import load_spec

from cr8 import engine


class SpecTest(TestCase):

def test_session_settings_from_spec(self):
spec = self.get_spec('sample.py')
self.assertEqual(spec.session_settings, {'application_name': 'my_app', 'timezone': 'UTC'})

def test_session_settings_from_toml(self):
spec = self.get_spec('sample.toml')
self.assertEqual(spec.session_settings, {'application_name': 'my_app', 'timezone': 'UTC'})

def test_session_settings_from_json(self):
spec = self.get_spec('count_countries.json')
self.assertEqual(spec.session_settings, {'application_name': 'my_app', 'timezone': 'UTC'})

def get_spec(self, name):
return load_spec(os.path.abspath(os.path.join(os.path.dirname(__file__), '../specs/', name)))


def load_tests(loader, tests, ignore):
tests.addTests(DocTestSuite(engine))
return tests

0 comments on commit 09110ba

Please sign in to comment.