Skip to content

Commit

Permalink
Use the active role in the session as the matching rule for the resou…
Browse files Browse the repository at this point in the history
…rce group (#18207)

Signed-off-by: HangyuanLiu <460660596@qq.com>
  • Loading branch information
HangyuanLiu authored Feb 23, 2023
1 parent 79992dc commit 5f6e4e0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ private List<String> showClassifier(ResourceGroupClassifier classifier) {
return row;
}

public List<List<String>> showVisible(String user, String roleName, String ip) {
return classifiers.stream().filter(c -> c.isVisible(user, roleName, ip))
public List<List<String>> showVisible(String user, List<String> activeRoles, String ip) {
return classifiers.stream().filter(c -> c.isVisible(user, activeRoles, ip))
.map(this::showClassifier).collect(Collectors.toList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ public void write(DataOutput out) throws IOException {
Text.writeString(out, json);
}

public boolean isSatisfied(String user, String role, QueryType queryType, String sourceIp,
public boolean isSatisfied(String user, List<String> activeRoles, QueryType queryType, String sourceIp,
Set<Long> dbIds) {
if (!isVisible(user, role, sourceIp)) {
if (!isVisible(user, activeRoles, sourceIp)) {
return false;
}
if (CollectionUtils.isNotEmpty(queryTypes) && !this.queryTypes.contains(queryType)) {
Expand All @@ -136,11 +136,11 @@ public boolean isSatisfied(String user, String role, QueryType queryType, String
return true;
}

public boolean isVisible(String user, String role, String sourceIp) {
public boolean isVisible(String user, List<String> activeRoles, String sourceIp) {
if (this.user != null && !this.user.equals(user)) {
return false;
}
if (this.role != null && !this.role.equals(role)) {
if (this.role != null && !activeRoles.contains(role)) {
return false;
}
if (this.sourceIp != null && sourceIp != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.starrocks.catalog;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.gson.annotations.SerializedName;
import com.starrocks.common.AnalysisException;
import com.starrocks.common.DdlException;
Expand All @@ -26,6 +27,8 @@
import com.starrocks.persist.gson.GsonUtils;
import com.starrocks.privilege.PrivilegeBuiltinConstants;
import com.starrocks.privilege.PrivilegeException;
import com.starrocks.privilege.PrivilegeManager;
import com.starrocks.privilege.RolePrivilegeCollection;
import com.starrocks.qe.ConnectContext;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.sql.ast.AlterResourceGroupStmt;
Expand Down Expand Up @@ -154,34 +157,44 @@ private String getUnqualifiedUser(ConnectContext ctx) {
return userParts[userParts.length - 1];
}

private String getUnqualifiedRole(ConnectContext ctx) {
private List<String> getUnqualifiedRole(ConnectContext ctx) {
Preconditions.checkArgument(ctx != null);
String roleName = null;

if (GlobalStateMgr.getCurrentState().isUsingNewPrivilege()) {
try {
List<String> roleNameList = GlobalStateMgr.getCurrentState().getPrivilegeManager()
.getRoleNamesByUser(ctx.getCurrentUserIdentity());
roleNameList =
roleNameList.stream().filter(r -> !PrivilegeBuiltinConstants.BUILT_IN_ROLE_NAMES.contains(r))
.collect(Collectors.toList());
if (roleNameList.isEmpty()) {
return null;
} else {
return roleNameList.get(0);
PrivilegeManager manager = GlobalStateMgr.getCurrentState().getPrivilegeManager();
List<String> validRoles = new ArrayList<>();

Set<Long> activeRoles = ctx.getCurrentRoleIds();
if (activeRoles == null) {
activeRoles = manager.getRoleIdsByUser(ctx.getCurrentUserIdentity());
}

for (Long roleId : activeRoles) {
RolePrivilegeCollection rolePrivilegeCollection =
manager.getRolePrivilegeCollectionUnlocked(roleId, false);
if (rolePrivilegeCollection != null) {
validRoles.add(rolePrivilegeCollection.getName());
}
}

return validRoles.stream().filter(r -> !PrivilegeBuiltinConstants.BUILT_IN_ROLE_NAMES.contains(r))
.collect(Collectors.toList());
} catch (PrivilegeException e) {
LOG.info("getUnqualifiedRole failed for resource group, error message: " + e.getMessage());
return null;
}
}

String roleName = null;
String qualifiedRoleName = GlobalStateMgr.getCurrentState().getAuth()
.getRoleName(ctx.getCurrentUserIdentity());
if (qualifiedRoleName != null) {
//default_cluster:role
String[] roleParts = qualifiedRoleName.split(":");
roleName = roleParts[roleParts.length - 1];
}
return roleName;
return Lists.newArrayList(roleName);
}

public List<List<String>> showAllResourceGroups(ConnectContext ctx, Boolean isListAll) {
Expand All @@ -194,9 +207,9 @@ public List<List<String>> showAllResourceGroups(ConnectContext ctx, Boolean isLi
.flatMap(Collection::stream).collect(Collectors.toList());
} else {
String user = getUnqualifiedUser(ctx);
String role = getUnqualifiedRole(ctx);
List<String> activeRoles = getUnqualifiedRole(ctx);
String remoteIp = ctx.getRemoteIP();
return resourceGroupList.stream().map(rg -> rg.showVisible(user, role, remoteIp))
return resourceGroupList.stream().map(rg -> rg.showVisible(user, activeRoles, remoteIp))
.flatMap(Collection::stream).collect(Collectors.toList());
}
} finally {
Expand Down Expand Up @@ -502,7 +515,7 @@ public TWorkGroup chooseResourceGroupByID(long wgID) {
}

public TWorkGroup chooseResourceGroup(ConnectContext ctx, ResourceGroupClassifier.QueryType queryType, Set<Long> databases) {
String role = getUnqualifiedRole(ctx);
List<String> activeRoles = getUnqualifiedRole(ctx);

readLock();
try {
Expand All @@ -513,7 +526,7 @@ public TWorkGroup chooseResourceGroup(ConnectContext ctx, ResourceGroupClassifie
if (shortQueryResourceGroup != null) {
List<ResourceGroupClassifier> shortQueryClassifierList =
shortQueryResourceGroup.classifiers.stream().filter(
f -> f.isSatisfied(user, role, queryType, remoteIp, databases))
f -> f.isSatisfied(user, activeRoles, queryType, remoteIp, databases))
.sorted(Comparator.comparingDouble(ResourceGroupClassifier::weight))
.collect(Collectors.toList());
if (!shortQueryClassifierList.isEmpty()) {
Expand All @@ -522,7 +535,7 @@ public TWorkGroup chooseResourceGroup(ConnectContext ctx, ResourceGroupClassifie
}

List<ResourceGroupClassifier> classifierList =
classifierMap.values().stream().filter(f -> f.isSatisfied(user, role, queryType, remoteIp, databases))
classifierMap.values().stream().filter(f -> f.isSatisfied(user, activeRoles, queryType, remoteIp, databases))
.sorted(Comparator.comparingDouble(ResourceGroupClassifier::weight))
.collect(Collectors.toList());
if (classifierList.isEmpty()) {
Expand Down

0 comments on commit 5f6e4e0

Please sign in to comment.