Skip to content

Commit

Permalink
refactor random_permutations for simplicity
Browse files Browse the repository at this point in the history
alpha-reorder imports

fix broken import
  • Loading branch information
ankona committed Jun 13, 2023
1 parent 395ffb0 commit b544e68
Showing 1 changed file with 7 additions and 18 deletions.
25 changes: 7 additions & 18 deletions smartsim/entity/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,10 @@ def step_values(
def random_permutations(
param_names: t.List[str], param_values: t.List[t.List[str]], n_models: int = 0
) -> t.List[t.Dict[str, str]]:
# first, check if we've requested more values than possible.
perms = list(product(*param_values))
if n_models >= len(perms):
return create_all_permutations(param_names, param_values)
else:
permutations: t.List[t.Dict[str, str]] = []
permutation_strings = set()
while len(permutations) < n_models:
model_dict = dict(
zip(
param_names,
map(lambda x: x[random.randint(0, len(x) - 1)], param_values),
)
)
if str(model_dict) not in permutation_strings:
permutation_strings.add(str(model_dict))
permutations.append(model_dict)
return permutations
permutations = create_all_permutations(param_names, param_values)

# sample from available permutations if n_models is specified
if n_models and n_models < len(permutations):
permutations = random.sample(permutations, n_models)

return permutations

0 comments on commit b544e68

Please sign in to comment.