Skip to content

Implement DB Associations API and a little refactoring. #304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip
  • Loading branch information
tazend committed Jun 10, 2023
commit 258c747affaf7fc84f7a5f100bcd81a222b0d51a
10 changes: 10 additions & 0 deletions pyslurm/db/assoc.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ from pyslurm.db.tres cimport (
merge_tres_str,
tres_ids_to_names,
TrackableResources,
TrackableResourceLimits,
)
from pyslurm.db.connection cimport Connection
from pyslurm.utils cimport cstr
Expand All @@ -70,6 +71,15 @@ cdef class Association:
QualitiesOfService qos_data
TrackableResources tres_data

cdef public:
group_tres
group_tres_mins
group_tres_run_mins
max_tres_mins_per_job
max_tres_run_mins_per_user
max_tres_per_job
max_tres_per_node

@staticmethod
cdef Association from_ptr(slurmdb_assoc_rec_t *in_ptr)

102 changes: 33 additions & 69 deletions pyslurm/db/assoc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ cdef class Associations(dict):
assoc = Association.from_ptr(<slurmdb_assoc_rec_t*>assoc_ptr.data)
assoc.qos_data = qos_data
assoc.tres_data = tres_data
assoc._parse_tres()
assoc_dict[assoc.id] = assoc

return assoc_dict
Expand Down Expand Up @@ -194,6 +195,23 @@ cdef class Association:
wrap.ptr = in_ptr
return wrap

def _parse_tres(self):
cdef TrackableResources tres = self.tres_data
self.group_tres = TrackableResourceLimits.from_ids(
self.ptr.grp_tres, tres)
self.group_tres_mins = TrackableResourceLimits.from_ids(
self.ptr.grp_tres_mins, tres)
self.group_tres_run_mins = TrackableResourceLimits.from_ids(
self.ptr.grp_tres_mins, tres)
self.max_tres_mins_per_job = TrackableResourceLimits.from_ids(
self.ptr.max_tres_mins_pj, tres)
self.max_tres_run_mins_per_user = TrackableResourceLimits.from_ids(
self.ptr.max_tres_run_mins, tres)
self.max_tres_per_job = TrackableResourceLimits.from_ids(
self.ptr.max_tres_pj, tres)
self.max_tres_per_node = TrackableResourceLimits.from_ids(
self.ptr.max_tres_pn, tres)

def as_dict(self):
"""Database Association information formatted as a dictionary.

Expand All @@ -204,19 +222,20 @@ cdef class Association:

def _validate_tres(self):
self.tres_data = TrackableResources.load(name_is_key=False)
self.group_tres = tres_names_to_ids(self.group_tres, self.tres_data)
self.group_tres_mins = tres_names_to_ids(
self.group_tres_mins, self.tres_data)
self.group_tres_run_mins = tres_names_to_ids(
self.group_tres_run_mins, self.tres_data)
self.max_tres_mins_per_job = tres_names_to_ids(
self.max_tres_mins_per_job, self.tres_data)
self.max_tres_run_mins_per_user = tres_names_to_ids(
self.max_tres_run_mins_per_user, self.tres_data)
self.max_tres_per_job = tres_names_to_ids(
self.max_tres_per_job, self.tres_data)
self.max_tres_per_node = tres_names_to_ids(
self.max_tres_per_node, self.tres_data)
cstr.from_dict(&self.ptr.grp_tres,
self.group_tres._validate(self.tres_data))
cstr.from_dict(&self.ptr.grp_tres_mins,
self.group_tres_mins._validate(self.tres_data))
cstr.from_dict(&self.ptr.grp_tres_run_mins,
self.group_tres_run_mins._validate(self.tres_data))
cstr.from_dict(&self.ptr.max_tres_mins_pj,
self.max_tres_mins_per_job._validate(self.tres_data))
cstr.from_dict(&self.ptr.max_tres_run_mins,
self.max_tres_run_mins_per_user._validate(self.tres_data))
cstr.from_dict(&self.ptr.max_tres_pj,
self.max_tres_per_job._validate(self.tres_data))
cstr.from_dict(&self.ptr.max_tres_pn,
self.max_tres_per_node._validate(self.tres_data))

@staticmethod
def load(name):
Expand Down Expand Up @@ -274,30 +293,6 @@ cdef class Association:
def group_submit_jobs(self, val):
self.ptr.grp_submit_jobs = u32(val, zero_is_noval=False)

@property
def group_tres(self):
return tres_ids_to_names(self.ptr.grp_tres, self.tres_data)

@group_tres.setter
def group_tres(self, val):
cstr.from_dict(&self.ptr.grp_tres, val)

@property
def group_tres_mins(self):
return tres_ids_to_names(self.ptr.grp_tres_mins, self.tres_data)

@group_tres_mins.setter
def group_tres_mins(self, val):
cstr.from_dict(&self.ptr.grp_tres_mins, val)

@property
def group_tres_run_mins(self):
return tres_ids_to_names(self.ptr.grp_tres_run_mins, self.tres_data)

@group_tres_run_mins.setter
def group_tres_run_mins(self, val):
cstr.from_dict(&self.ptr.grp_tres_run_mins, val)

@property
def group_wall_time(self):
return u32_parse(self.ptr.grp_wall, zero_is_noval=False)
Expand Down Expand Up @@ -346,38 +341,6 @@ cdef class Association:
def max_submit_jobs(self, val):
self.ptr.max_submit_jobs = u32(val, zero_is_noval=False)

@property
def max_tres_mins_per_job(self):
return tres_ids_to_names(self.ptr.max_tres_mins_pj, self.tres_data)

@max_tres_mins_per_job.setter
def max_tres_mins_per_job(self, val):
cstr.from_dict(&self.ptr.max_tres_mins_pj, val)

@property
def max_tres_run_mins_per_user(self):
return tres_ids_to_names(self.ptr.max_tres_run_mins, self.tres_data)

@max_tres_run_mins_per_user.setter
def max_tres_run_mins_per_user(self, val):
cstr.from_dict(&self.ptr.max_tres_run_mins, val)

@property
def max_tres_per_job(self):
return tres_ids_to_names(self.ptr.max_tres_pj, self.tres_data)

@max_tres_per_job.setter
def max_tres_per_job(self, val):
cstr.from_dict(&self.ptr.max_tres_pj, val)

@property
def max_tres_per_node(self):
return tres_ids_to_names(self.ptr.max_tres_pn, self.tres_data)

@max_tres_per_node.setter
def max_tres_per_node(self, val):
cstr.from_dict(&self.ptr.max_tres_pn, val)

@property
def max_wall_time_per_job(self):
return u32_parse(self.ptr.max_wall_pj, zero_is_noval=False)
Expand Down Expand Up @@ -424,6 +387,7 @@ cdef class Association:

@qos.setter
def qos(self, val):
# TODO: must be ids
make_char_list(&self.ptr.qos_list, val)

@property
Expand Down
18 changes: 18 additions & 0 deletions pyslurm/db/tres.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,24 @@ cdef merge_tres_str(char **tres_str, typ, val)
cdef tres_ids_to_names(char *tres_str, TrackableResources tres_data)


cdef class TrackableResourceLimits:

cdef public:
cpu
mem
energy
node
billing
fs
vmem
pages
gres
license

@staticmethod
cdef from_ids(char *tres_id_str, TrackableResources tres_data)


cdef class TrackableResourceFilter:
cdef slurmdb_tres_cond_t *ptr

Expand Down
95 changes: 86 additions & 9 deletions pyslurm/db/tres.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,85 @@
from pyslurm.utils.uint import *
from pyslurm.constants import UNLIMITED
from pyslurm.core.error import RPCError
from pyslurm.utils.helpers import instance_to_dict
from pyslurm.utils import cstr
import json


TRES_TYPE_DELIM = "/"


cdef class TrackableResourceLimits:

def __init__(self, **kwargs):
self.fs = {}
self.gres = {}
self.license = {}

for k, v in kwargs.items():
if TRES_TYPE_DELIM in k:
typ, name = self._unflatten_tres(k)
cur_val = getattr(self, typ)

if not isinstance(cur_val, dict):
raise ValueError(f"TRES Type {typ} cannot have a name "
f"({name}). Invalid Value: {typ}/{name}")

cur_val.update({name : int(v)})
setattr(self, typ, cur_val)
else:
setattr(self, k, v)

@staticmethod
cdef from_ids(char *tres_id_str, TrackableResources tres_data):
tres_list = tres_ids_to_names(tres_id_str, tres_data)
if not tres_list:
return None

cdef TrackableResourceLimits out = TrackableResourceLimits()

for tres in tres_list:
typ, name, cnt = tres
cur_val = getattr(out, typ, slurm.NO_VAL64)
if cur_val != slurm.NO_VAL64:
if isinstance(cur_val, dict):
cur_val.update({name : cnt})
setattr(out, typ, cur_val)
else:
setattr(out, typ, cnt)

return out

def _validate(self, TrackableResources tres_data):
id_dict = tres_names_to_ids(self.as_dict(flatten_limits=True),
tres_data)
return id_dict

def _unflatten_tres(self, type_and_name):
typ, name = type_and_name.split(TRES_TYPE_DELIM, 1)
return typ, name

def _flatten_tres(self, typ, vals):
cdef dict out = {}
for name, cnt in vals.items():
out[f"{typ}{TRES_TYPE_DELIM}{name}"] = cnt

return out

def as_dict(self, flatten_limits=False):
cdef dict inst_dict = instance_to_dict(self)

if flatten_limits:
vals = inst_dict.pop("fs")
inst_dict.update(self._flatten_tres("fs", vals))

vals = inst_dict.pop("license")
inst_dict.update(self._flatten_tres("license", vals))

vals = inst_dict.pop("gres")
inst_dict.update(self._flatten_tres("gres", vals))

return inst_dict


cdef class TrackableResourceFilter:
Expand Down Expand Up @@ -162,7 +241,7 @@ cdef class TrackableResource:
def type_and_name(self):
type_and_name = self.type
if self.name:
type_and_name = f"{type_and_name}/{self.name}"
type_and_name = f"{type_and_name}{TRES_TYPE_DELIM}{self.name}"

return type_and_name

Expand Down Expand Up @@ -206,24 +285,22 @@ cdef merge_tres_str(char **tres_str, typ, val):

cdef tres_ids_to_names(char *tres_str, TrackableResources tres_data):
if not tres_str:
return {}
return None

cdef:
dict tdict = cstr.to_dict(tres_str)
dict out = {}
list out = []

if not tres_data:
return tdict
return None

for tid, cnt in tdict.items():
if isinstance(tid, str) and tid.isdigit():
_tid = int(tid)
if _tid in tres_data:
out[tres_data[_tid].type_and_name] = cnt
continue

# If we can't find the TRES ID in our data, return it raw.
out[tid] = cnt
out.append(
(tres_data[_tid].type, tres_data[_tid].name, int(cnt))
)

return out

Expand Down