Skip to content

Commit

Permalink
Adding Principal name in resource group selector criteria
Browse files Browse the repository at this point in the history
Adding principal name in the resource group selector criteria.
  • Loading branch information
swapsmagic committed Sep 6, 2023
1 parent 0419278 commit 9c7c230
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import javax.annotation.PreDestroy;
import javax.inject.Inject;

import java.security.Principal;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Executor;
Expand Down Expand Up @@ -299,7 +300,8 @@ private <C> void createQueryInternal(QueryId queryId, String slug, int retryCoun
sessionContext.getResourceEstimates(),
queryType.map(Enum::name),
Optional.ofNullable(sessionContext.getClientInfo()),
Optional.ofNullable(sessionContext.getSchema())));
Optional.ofNullable(sessionContext.getSchema()),
sessionContext.getIdentity().getPrincipal().map(Principal::getName)));

// apply system default session properties (does not override user set properties)
session = sessionPropertyDefaults.newSessionWithDefaultProperties(session, queryType.map(Enum::name), Optional.of(selectionContext.getResourceGroupId()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ protected List<ResourceGroupSelector> buildSelectors(ManagerSpec managerSpec)
spec.getQueryType(),
spec.getClientInfoRegex(),
spec.getSchema(),
spec.getPrincipalRegex(),
spec.getGroup()));
}
return selectors.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public class SelectorSpec
private final Optional<List<String>> clientTags;
private final Optional<SelectorResourceEstimate> selectorResourceEstimate;
private final Optional<Pattern> clientInfoRegex;

private final Optional<String> schema;
private final Optional<Pattern> principalRegex;
private final ResourceGroupIdTemplate group;

@JsonCreator
Expand All @@ -45,6 +45,7 @@ public SelectorSpec(
@JsonProperty("selectorResourceEstimate") Optional<SelectorResourceEstimate> selectorResourceEstimate,
@JsonProperty("clientInfo") Optional<Pattern> clientInfoRegex,
@JsonProperty("schema") Optional<String> schema,
@JsonProperty("principal") Optional<Pattern> principal,
@JsonProperty("group") ResourceGroupIdTemplate group)
{
this.userRegex = requireNonNull(userRegex, "userRegex is null");
Expand All @@ -55,6 +56,7 @@ public SelectorSpec(
this.group = requireNonNull(group, "group is null");
this.clientInfoRegex = requireNonNull(clientInfoRegex, "clientInfoRegex is null");
this.schema = requireNonNull(schema, "schema is null");
this.principalRegex = requireNonNull(principal, "principal is null");
}

public Optional<Pattern> getUserRegex()
Expand Down Expand Up @@ -97,6 +99,11 @@ public Optional<String> getSchema()
return schema;
}

public Optional<Pattern> getPrincipalRegex()
{
return principalRegex;
}

@Override
public boolean equals(Object other)
{
Expand All @@ -116,7 +123,8 @@ public boolean equals(Object other)
clientTags.equals(that.clientTags) &&
clientInfoRegex.map(Pattern::pattern).equals(that.clientInfoRegex.map(Pattern::pattern)) &&
clientInfoRegex.map(Pattern::flags).equals(that.clientInfoRegex.map(Pattern::flags)) &&
schema.equals(that.schema);
schema.equals(that.schema) &&
principalRegex.equals(that.principalRegex);
}

@Override
Expand All @@ -129,7 +137,9 @@ public int hashCode()
sourceRegex.map(Pattern::pattern),
sourceRegex.map(Pattern::flags),
queryType,
clientTags);
clientTags,
principalRegex.map(Pattern::pattern),
principalRegex.map(Pattern::flags));
}

@Override
Expand All @@ -144,6 +154,8 @@ public String toString()
.add("queryType", queryType)
.add("clientTags", clientTags)
.add("clientInfoRegex", clientInfoRegex)
.add("principalRegex", principalRegex)
.add("principalFlags", principalRegex.map(Pattern::flags))
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class StaticSelector
private final Optional<Pattern> clientInfoRegex;

private final Optional<String> schema;
private final Optional<Pattern> principalRegex;
private final ResourceGroupIdTemplate group;
private final Set<String> variableNames;

Expand All @@ -63,6 +64,7 @@ public StaticSelector(
Optional<String> queryType,
Optional<Pattern> clientInfoRegex,
Optional<String> schema,
Optional<Pattern> principalRegex,
ResourceGroupIdTemplate group)
{
this.userRegex = requireNonNull(userRegex, "userRegex is null");
Expand All @@ -73,6 +75,7 @@ public StaticSelector(
this.queryType = requireNonNull(queryType, "queryType is null");
this.clientInfoRegex = requireNonNull(clientInfoRegex, "clientInfoRegex is null");
this.schema = requireNonNull(schema, "schema is null");
this.principalRegex = requireNonNull(principalRegex, "principalRegex is null");
this.group = requireNonNull(group, "group is null");

HashSet<String> variableNames = new HashSet<>(ImmutableList.of(USER_VARIABLE, SOURCE_VARIABLE, SCHEMA_VARIABLE));
Expand Down Expand Up @@ -106,6 +109,16 @@ public Optional<SelectionContext<VariableMap>> match(SelectionCriteria criteria)

addVariableValues(sourceRegex.get(), source, variables);
}

if (principalRegex.isPresent()) {
String principal = criteria.getPrincipal().orElse("");
if (!principalRegex.get().matcher(principal).matches()) {
return Optional.empty();
}

addVariableValues(principalRegex.get(), principal, variables);
}

if (!clientTags.isEmpty() && !criteria.getTags().containsAll(clientTags)) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ public synchronized ManagerSpec getManagerSpec()
selectorRecord.getSelectorResourceEstimate(),
selectorRecord.getClientInfoRegex(),
selectorRecord.getSchema(),
selectorRecord.getPrincipalRegex(),
resourceGroupIdTemplateMap.get(selectorRecord.getResourceGroupId()))
).collect(toList());
ManagerSpec managerSpec = new ManagerSpec(rootGroups, selectors, getCpuQuotaPeriodFromDb());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public interface ResourceGroupsDao
@UseRowMapper(ResourceGroupSpecBuilder.Mapper.class)
List<ResourceGroupSpecBuilder> getResourceGroups(@Bind("environment") String environment);

@SqlQuery("SELECT S.resource_group_id, S.priority, S.user_regex, S.source_regex, S.query_type, S.client_tags, S.selector_resource_estimate, S.client_info_regex, S.schema\n" +
@SqlQuery("SELECT S.resource_group_id, S.priority, S.user_regex, S.source_regex, S.query_type, S.client_tags, S.selector_resource_estimate, S.client_info_regex, S.schema, S.principal_regex\n" +
"FROM selectors S\n" +
"JOIN resource_groups R ON (S.resource_group_id = R.resource_group_id)\n" +
"WHERE R.environment = :environment\n" +
Expand All @@ -82,6 +82,7 @@ public interface ResourceGroupsDao
" selector_resource_estimate VARCHAR(1024),\n" +
" client_info_regex VARCHAR(1024),\n" +
" schema VARCHAR(1024),\n" +
" principal_regex VARCHAR(1024),\n" +
" id BIGINT NOT NULL AUTO_INCREMENT,\n" +
" PRIMARY KEY (id),\n" +
" FOREIGN KEY (resource_group_id) REFERENCES resource_groups (resource_group_id)\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ public class SelectorRecord
private final Optional<List<String>> clientTags;
private final Optional<SelectorResourceEstimate> selectorResourceEstimate;
private final Optional<Pattern> clientInfoRegex;

private final Optional<String> schema;
private final Optional<Pattern> principalRegex;

public SelectorRecord(
long resourceGroupId,
Expand All @@ -51,7 +51,8 @@ public SelectorRecord(
Optional<List<String>> clientTags,
Optional<SelectorResourceEstimate> selectorResourceEstimate,
Optional<Pattern> clientInfoRegex,
Optional<String> schema)
Optional<String> schema,
Optional<Pattern> principalRegex)
{
this.resourceGroupId = resourceGroupId;
this.priority = priority;
Expand All @@ -62,6 +63,7 @@ public SelectorRecord(
this.selectorResourceEstimate = requireNonNull(selectorResourceEstimate, "selectorResourceEstimate is null");
this.clientInfoRegex = requireNonNull(clientInfoRegex, "clientInfoRegex is null");
this.schema = requireNonNull(schema, "schema is null");
this.principalRegex = requireNonNull(principalRegex, "principalRegex is null");
}

public long getResourceGroupId()
Expand Down Expand Up @@ -109,6 +111,11 @@ public Optional<String> getSchema()
return schema;
}

public Optional<Pattern> getPrincipalRegex()
{
return principalRegex;
}

public static class Mapper
implements RowMapper<SelectorRecord>
{
Expand All @@ -128,7 +135,8 @@ public SelectorRecord map(ResultSet resultSet, StatementContext context)
Optional.ofNullable(resultSet.getString("client_tags")).map(LIST_STRING_CODEC::fromJson),
Optional.ofNullable(resultSet.getString("selector_resource_estimate")).map(SELECTOR_RESOURCE_ESTIMATE_JSON_CODEC::fromJson),
Optional.ofNullable(resultSet.getString("client_info_regex")).map(Pattern::compile),
Optional.ofNullable(resultSet.getString("schema")));
Optional.ofNullable(resultSet.getString("schema")),
Optional.ofNullable(resultSet.getString("principal_regex")).map(Pattern::compile));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ public void testQueryTypeConfiguration() throws IOException
{
FileResourceGroupConfigurationManager manager = parse("resource_groups_config_query_type.json");
List<ResourceGroupSelector> selectors = manager.getSelectors();
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("select"), Optional.empty(), Optional.empty()), "global.select");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("explain"), Optional.empty(), Optional.empty()), "global.explain");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("insert"), Optional.empty(), Optional.empty()), "global.insert");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("delete"), Optional.empty(), Optional.empty()), "global.delete");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("describe"), Optional.empty(), Optional.empty()), "global.describe");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("data_definition"), Optional.empty(), Optional.empty()), "global.data_definition");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("sth_else"), Optional.empty(), Optional.empty()), "global.other");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("sth_else"), Optional.of("client2_34"), Optional.empty()), "global.other-2");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("select"), Optional.empty(), Optional.empty(), Optional.empty()), "global.select");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("explain"), Optional.empty(), Optional.empty(), Optional.empty()), "global.explain");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("insert"), Optional.empty(), Optional.empty(), Optional.empty()), "global.insert");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("delete"), Optional.empty(), Optional.empty(), Optional.empty()), "global.delete");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("describe"), Optional.empty(), Optional.empty(), Optional.empty()), "global.describe");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("data_definition"), Optional.empty(), Optional.empty(), Optional.empty()), "global.data_definition");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("sth_else"), Optional.empty(), Optional.empty(), Optional.empty()), "global.other");
assertMatch(selectors, new SelectionCriteria(true, "test_user", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.of("sth_else"), Optional.of("client2_34"), Optional.empty(), Optional.empty()), "global.other-2");
}

@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Selector specifies an invalid query type: invalid_query_type")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public void testExtraction()
ResourceGroupId expected = new ResourceGroupId(new ResourceGroupId(new ResourceGroupId(new ResourceGroupId("test"), "pipeline"), "job_testpipeline_user:user"), "user");

Pattern sourcePattern = Pattern.compile("scheduler.important.(?<pipeline>[^\\[]*).*");
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.important.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty());
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.important.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());

assertEquals(selector.match(context).map(SelectionContext::getResourceGroupId), Optional.of(expected));
}
Expand All @@ -78,8 +78,8 @@ public void testNoSource()
ResourceGroupId expected = new ResourceGroupId(new ResourceGroupId(new ResourceGroupId(new ResourceGroupId("test"), "pipeline"), "testpipeline"), "_s");

Pattern userPattern = Pattern.compile("scheduler.important.(?<pipeline>[^\\[]*).*");
StaticSelector selector = new StaticSelector(Optional.of(userPattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "scheduler.important.testpipeline[5]", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty());
StaticSelector selector = new StaticSelector(Optional.of(userPattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "scheduler.important.testpipeline[5]", Optional.empty(), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());

assertEquals(selector.match(context).map(SelectionContext::getResourceGroupId), Optional.of(expected));
}
Expand All @@ -89,8 +89,8 @@ public void testNoMatch()
{
ResourceGroupIdTemplate template = new ResourceGroupIdTemplate("test.pipeline.${pipeline}.${USER}");
Pattern sourcePattern = Pattern.compile("scheduler.important.(?<pipeline>[^\\[]*).*");
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty());
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());

assertFalse(selector.match(context).isPresent());
}
Expand All @@ -100,8 +100,8 @@ public void testUnresolvedVariableLoadTime()
{
ResourceGroupIdTemplate template = new ResourceGroupIdTemplate("test.pipeline.${pipeline}.${user}");
Pattern sourcePattern = Pattern.compile("scheduler.important.(?<pipeline>[^\\[]*).*");
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.important.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty());
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.important.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
selector.match(context);
}

Expand All @@ -110,8 +110,8 @@ public void testUnresolvedVariableRunTime()
{
ResourceGroupIdTemplate template = new ResourceGroupIdTemplate("test.pipeline.${pipeline}.${USER}");
Pattern sourcePattern = Pattern.compile("scheduler.important.(testpipeline\\[|(?<pipeline>[^\\[]*)).*");
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.important.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty());
StaticSelector selector = new StaticSelector(Optional.empty(), Optional.of(sourcePattern), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), template);
SelectionCriteria context = new SelectionCriteria(true, "user", Optional.of("scheduler.important.testpipeline[5]"), ImmutableSet.of(), EMPTY_RESOURCE_ESTIMATES, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
selector.match(context);
}
}
Loading

0 comments on commit 9c7c230

Please sign in to comment.