Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions kernelci/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
priority_max=None,
queue_timeout=None,
notify=None,
max_queue_depth=None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -107,6 +108,7 @@ def _set_priority_value(value, default):
self._priority_max = _set_priority_value(priority_max, self._priority)
self._notify = notify or {}
self._queue_timeout = queue_timeout
self._max_queue_depth = max_queue_depth if max_queue_depth is not None else 50

@property
def url(self):
Expand Down Expand Up @@ -142,6 +144,11 @@ def notify(self):
"""Callback parameters for the `notify` part of the jobs"""
return self._notify.copy()

@property
def max_queue_depth(self):
"""Maximum queue depth per device type before skipping job submissions"""
return self._max_queue_depth

@classmethod
def _get_yaml_attributes(cls):
attrs = super()._get_yaml_attributes()
Expand All @@ -153,6 +160,7 @@ def _get_yaml_attributes(cls):
"queue_timeout",
"url",
"notify",
"max_queue_depth",
}
)
return attrs
Expand Down
85 changes: 85 additions & 0 deletions kernelci/runtime/lava.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,91 @@ def _connect(self):
}
return rest_api

def _get_response(self, url, params=None):
resp = self._server.session.get(url, params=params, timeout=30)
resp.raise_for_status()
return resp.json()

def _get_all(self, url, params=None):
resp = self._get_response(url, params=params)
results = resp.get('results', [])
next_url = resp.get('next')
while next_url:
resp = self._get_response(next_url)
results.extend(resp.get('results', []))
next_url = resp.get('next')
return results

def get_devicetype_job_count(self, device_types):
"""Get queued job counts per requested device type.

*device_types* can be a device type name or list of device type names.
This queries /jobs/?state=Submitted&requested_device_type=<type> and
reads the 'count' field from the paginated DRF response.
"""
if self._server.url is None:
raise ValueError("LAVA server URL is not configured")

single_type = isinstance(device_types, str)
if single_type:
requested_types = [device_types]
else:
requested_types = list(device_types or [])
if not requested_types:
return 0 if single_type else {}

jobs_url = urljoin(self._server.url, 'jobs/')
counts = {}
for dev_type in requested_types:
params = {
'state': 'Submitted',
'requested_device_type': dev_type,
}
resp = self._get_response(jobs_url, params=params)
counts[dev_type] = resp.get('count', 0)

if single_type:
return counts.get(requested_types[0], 0)
return counts

def get_device_names_by_type(self, device_type, online_only=False):
"""Get device names for a given LAVA device type.

*device_type* can be a string or list of device type names.
*online_only* filters devices with health == 'Good' when available.
Use this with get_devicetype_job_count() to gate submissions when the
queue per device type exceeds a threshold.
"""
if self._server.url is None:
raise ValueError("LAVA server URL is not configured")

single_type = isinstance(device_type, str)
if single_type:
device_types = [device_type]
else:
device_types = list(device_type or [])
if not device_types:
return [] if single_type else {}

devices_url = urljoin(self._server.url, 'devices/')
result = {}
for dev_type in device_types:
params = {'device_type': dev_type}
devices = self._get_all(devices_url, params=params)
names = []
for device in devices:
if device.get('device_type') != dev_type:
continue
if online_only and device.get('health') not in (None, 'Good'):
continue
hostname = device.get('hostname') or device.get('name')
if hostname:
names.append(hostname)
result[dev_type] = names
if single_type:
return result.get(device_types[0], [])
return result

def _submit(self, job):
if self._server.url is None:
return self._store_job_in_external_storage(job)
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/runtimes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ runtimes:
priority_max: 100
queue_timeout:
days: 1
max_queue_depth: 50
notify: {}
filters:
- blocklist:
Expand All @@ -46,6 +47,7 @@ runtimes:
priority_max: 45
queue_timeout:
hours: 1
max_queue_depth: 50
notify:
callback:
token: some-token-name
Expand Down
130 changes: 130 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,53 @@
# implementation.
# pylint: disable=protected-access

import types

import kernelci.config
import kernelci.runtime


class _FakeResponse:
def __init__(self, payload, status_code=200, text=""):
"""Initialize a fake response with payload and status."""
self._payload = payload
self.status_code = status_code
self.text = text

def raise_for_status(self):
"""Raise an error when the response indicates failure."""
if self.status_code >= 400:
raise RuntimeError(f"HTTP {self.status_code}")

def json(self):
"""Return the preloaded JSON payload."""
return self._payload


class _FakeSession:
def __init__(self, get_handler=None, post_handler=None):
"""Initialize a fake session with optional handlers."""
self._get_handler = get_handler
self._post_handler = post_handler
self.calls = []

def get(self, url, params=None, timeout=30): # pylint: disable=unused-argument
"""Invoke the GET handler and return a fake response."""
if not self._get_handler:
raise AssertionError("GET handler not set")
self.calls.append((url, params))
return _FakeResponse(self._get_handler(url, params))

def post( # pylint: disable=unused-argument
self, url, json=None, allow_redirects=False, timeout=30
):
"""Invoke the POST handler and return its response."""
if not self._post_handler:
raise AssertionError("POST handler not set")
self.calls.append((url, json))
return self._post_handler(url, json)


def test_runtimes_init():
"""Test that all the runtimes can be initialised (offline)"""
config = kernelci.config.load('tests/configs/runtimes.yaml')
Expand Down Expand Up @@ -66,3 +109,90 @@ def test_lava_priority_scale():
spec_priority = int(priority)
print(f"* {plan_name:12s} {lab_priority:3d} {spec_priority:3d}")
assert lab_priority == spec_priority


def test_lava_get_devicetype_job_count():
"""Test queued job count via jobs API with state=Submitted."""
config = kernelci.config.load('tests/configs/lava-runtimes.yaml')
runtime_config = config['runtimes']['lab-min-12-max-40-new-runtime']
lab = kernelci.runtime.get_runtime(runtime_config)

def handler(url, params):
assert url.endswith('jobs/')
assert params.get('state') == 'Submitted'
dev_type = params.get('requested_device_type')
if dev_type == 'type-a':
return {'count': 61, 'next': None, 'previous': None, 'results': []}
if dev_type == 'type-b':
return {'count': 40, 'next': None, 'previous': None, 'results': []}
raise AssertionError(f"Unexpected request: {url} {params}")

lab._server = types.SimpleNamespace(
url='http://lava/api/v0.2/',
session=_FakeSession(get_handler=handler),
)

counts = lab.get_devicetype_job_count(['type-a', 'type-b'])
assert counts == {'type-a': 61, 'type-b': 40}


def test_lava_get_device_names_by_type():
"""Test device name lookups with type filtering and health checks."""
config = kernelci.config.load('tests/configs/lava-runtimes.yaml')
runtime_config = config['runtimes']['lab-min-12-max-40-new-runtime']
lab = kernelci.runtime.get_runtime(runtime_config)

def handler(url, params):
if url.endswith('devices/'):
dev_type = params.get('device_type')
if dev_type == 'type-a':
return {
'results': [
{'device_type': 'type-a', 'hostname': 'dev-1', 'health': 'Good'},
{'device_type': 'type-a', 'hostname': 'dev-2', 'health': 'Bad'},
{'device_type': 'type-b', 'hostname': 'dev-x', 'health': 'Good'},
],
'next': None,
}
if dev_type == 'type-b':
return {
'results': [
{'device_type': 'type-b', 'hostname': 'dev-3', 'health': 'Good'},
],
'next': None,
}
raise AssertionError(f"Unexpected request: {url} {params}")

lab._server = types.SimpleNamespace(
url='http://lava/api/v0.2/',
session=_FakeSession(get_handler=handler),
)

names = lab.get_device_names_by_type('type-a', online_only=True)
assert names == ['dev-1']

names_by_type = lab.get_device_names_by_type(['type-a', 'type-b'])
assert names_by_type == {'type-a': ['dev-1', 'dev-2'], 'type-b': ['dev-3']}


def test_lava_submit_rest():
"""Test LAVA REST submission builds a job with expected payload."""
config = kernelci.config.load('tests/configs/lava-runtimes.yaml')
runtime_config = config['runtimes']['lab-min-12-max-40-new-runtime']
lab = kernelci.runtime.get_runtime(runtime_config)

captured = {}

def post_handler(url, payload):
assert url.endswith('jobs/')
captured['json'] = payload
return _FakeResponse({'job_ids': [123]})

lab._server = types.SimpleNamespace(
url='http://lava/api/v0.2/',
session=_FakeSession(post_handler=post_handler),
)

job_id = lab._submit("jobdef")
assert job_id == 123
assert captured['json']['definition'] == "jobdef"