Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick][branch-2.3][BugFix] Support authentication for StarRocks external table (#11871) #12012

Merged
merged 8 commits into from
Oct 10, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import com.starrocks.service.FrontendOptions;
import com.starrocks.sql.ast.AstVisitor;
import com.starrocks.sql.ast.QueryStatement;
import com.starrocks.thrift.TAuthenticateParams;
import com.starrocks.thrift.TUniqueId;
import com.starrocks.transaction.TransactionState;
import com.starrocks.transaction.TransactionState.LoadJobSourceType;
Expand Down Expand Up @@ -323,13 +324,19 @@ public void analyze(Analyzer analyzer) throws UserException {
if (targetTable instanceof ExternalOlapTable) {
LoadJobSourceType sourceType = LoadJobSourceType.INSERT_STREAMING;
ExternalOlapTable externalTable = (ExternalOlapTable) targetTable;
TAuthenticateParams authenticateParams = new TAuthenticateParams();
authenticateParams.setUser(externalTable.getSourceTableUser());
authenticateParams.setPasswd(externalTable.getSourceTablePassword());
authenticateParams.setHost(ConnectContext.get().getRemoteIP());
authenticateParams.setDb_name(externalTable.getSourceTableDbName());
authenticateParams.setTable_names(Lists.newArrayList(externalTable.getSourceTableName()));
transactionId = GlobalStateMgr.getCurrentGlobalTransactionMgr()
.beginRemoteTransaction(externalTable.getSourceTableDbId(),
Lists.newArrayList(externalTable.getSourceTableId()), label,
externalTable.getSourceTableHost(),
externalTable.getSourceTablePort(),
new TxnCoordinator(TxnSourceType.FE, FrontendOptions.getLocalHostAddress()),
sourceType, timeoutSecond);
sourceType, timeoutSecond, authenticateParams);
} else if (targetTable instanceof OlapTable) {
LoadJobSourceType sourceType = LoadJobSourceType.INSERT_STREAMING;
MetricRepo.COUNTER_LOAD_ADD.increase(1L);
Expand Down
7 changes: 7 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/common/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,13 @@ public class Config extends ConfigBase {
@ConfField
public static boolean enable_auth_check = true;

/**
* If set to false, auth check for StarRocks external table will be disabled. The check
* only happens on the target cluster.
*/
@ConfField(mutable = true)
public static boolean enable_starrocks_external_table_auth_check = true;

/**
* ldap server host for authentication_ldap_simple
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.starrocks.catalog.ExternalOlapTable;
import com.starrocks.common.Config;
import com.starrocks.rpc.FrontendServiceProxy;
import com.starrocks.thrift.TAuthenticateParams;
import com.starrocks.thrift.TGetTableMetaRequest;
import com.starrocks.thrift.TGetTableMetaResponse;
import com.starrocks.thrift.TNetworkAddress;
Expand All @@ -25,10 +24,6 @@ public void syncTable(ExternalOlapTable table) {
TGetTableMetaRequest request = new TGetTableMetaRequest();
request.setDb_name(table.getSourceTableDbName());
request.setTable_name(table.getSourceTableName());
TAuthenticateParams authInfo = new TAuthenticateParams();
authInfo.setUser(table.getSourceTableUser());
authInfo.setPasswd(table.getSourceTablePassword());
request.setAuth_info(authInfo);
try {
TGetTableMetaResponse response = FrontendServiceProxy.call(addr,
Config.thrift_rpc_timeout_ms,
Expand Down
33 changes: 32 additions & 1 deletion fe/fe-core/src/main/java/com/starrocks/http/BaseAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ public String toString() {
sb.append(", password: ").append(password).append(", cluster: ").append(cluster);
return sb.toString();
}

public static ActionAuthorizationInfo of(
String fullUserName, String password, String remoteIp, String cluster) {
ActionAuthorizationInfo authInfo = new ActionAuthorizationInfo();
authInfo.fullUserName = fullUserName;
authInfo.remoteIp = remoteIp;
authInfo.password = password;
authInfo.cluster = cluster;
return authInfo;
}
}

protected void checkGlobalAuth(UserIdentity currentUser, PrivPredicate predicate) throws UnauthorizedException {
Expand All @@ -296,7 +306,7 @@ protected void checkTblAuth(UserIdentity currentUser, String db, String tbl, Pri
}

// return currentUserIdentity from StarRocks auth
protected UserIdentity checkPassword(ActionAuthorizationInfo authInfo)
public static UserIdentity checkPassword(ActionAuthorizationInfo authInfo)
throws UnauthorizedException {
List<UserIdentity> currentUser = Lists.newArrayList();
if (!GlobalStateMgr.getCurrentState().getAuth().checkPlainPassword(authInfo.fullUserName,
Expand Down Expand Up @@ -368,6 +378,27 @@ private boolean parseAuthInfo(BaseRequest request, ActionAuthorizationInfo authI
return true;
}

// Refer to {@link #parseAuthInfo(BaseRequest, ActionAuthorizationInfo)}
public static ActionAuthorizationInfo parseAuthInfo(String fullUserName, String password, String host) {
ActionAuthorizationInfo authInfo = new ActionAuthorizationInfo();
final String[] elements = fullUserName.split("@");
if (elements.length < 2) {
authInfo.fullUserName = ClusterNamespace.getFullName(SystemInfoService.DEFAULT_CLUSTER, fullUserName);
authInfo.cluster = SystemInfoService.DEFAULT_CLUSTER;
} else if (elements.length == 2) {
authInfo.fullUserName = ClusterNamespace.getFullName(elements[1], elements[0]);
authInfo.cluster = elements[1];
} else {
authInfo.fullUserName = fullUserName;
}
authInfo.password = password;
authInfo.remoteIp = host;

LOG.debug("Parse result for the input [{} {} {}]: {}", fullUserName, password, host, authInfo);

return authInfo;
}

protected int checkIntParam(String strParam) {
return Integer.parseInt(strParam);
}
Expand Down
10 changes: 9 additions & 1 deletion fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
import com.starrocks.statistic.Constants;
import com.starrocks.statistic.StatisticExecutor;
import com.starrocks.task.LoadEtlTask;
import com.starrocks.thrift.TAuthenticateParams;
import com.starrocks.thrift.TDescriptorTable;
import com.starrocks.thrift.TQueryOptions;
import com.starrocks.thrift.TQueryType;
Expand Down Expand Up @@ -1105,6 +1106,12 @@ public void handleDMLStmt(ExecPlan execPlan, DmlStmt stmt) throws Exception {
long transactionId = -1;
if (targetTable instanceof ExternalOlapTable) {
ExternalOlapTable externalTable = (ExternalOlapTable) targetTable;
TAuthenticateParams authenticateParams = new TAuthenticateParams();
authenticateParams.setUser(externalTable.getSourceTableUser());
authenticateParams.setPasswd(externalTable.getSourceTablePassword());
authenticateParams.setHost(context.getRemoteIP());
authenticateParams.setDb_name(externalTable.getSourceTableDbName());
authenticateParams.setTable_names(Lists.newArrayList(externalTable.getSourceTableName()));
transactionId =
GlobalStateMgr.getCurrentGlobalTransactionMgr()
.beginRemoteTransaction(externalTable.getSourceTableDbId(),
Expand All @@ -1114,7 +1121,8 @@ public void handleDMLStmt(ExecPlan execPlan, DmlStmt stmt) throws Exception {
new TransactionState.TxnCoordinator(TransactionState.TxnSourceType.FE,
FrontendOptions.getLocalHostAddress()),
sourceType,
ConnectContext.get().getSessionVariable().getQueryTimeoutS());
ConnectContext.get().getSessionVariable().getQueryTimeoutS(),
authenticateParams);
} else {
transactionId = GlobalStateMgr.getCurrentGlobalTransactionMgr().beginTransaction(
database.getId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
import com.starrocks.common.ThriftServerEventProcessor;
import com.starrocks.common.UserException;
import com.starrocks.common.util.DebugUtil;
import com.starrocks.http.BaseAction;
import com.starrocks.http.UnauthorizedException;
import com.starrocks.load.loadv2.ManualLoadTxnCommitAttachment;
import com.starrocks.load.routineload.RLTaskTxnCommitAttachment;
import com.starrocks.master.MasterImpl;
Expand Down Expand Up @@ -82,6 +84,7 @@
import com.starrocks.thrift.FrontendServiceVersion;
import com.starrocks.thrift.TAbortRemoteTxnRequest;
import com.starrocks.thrift.TAbortRemoteTxnResponse;
import com.starrocks.thrift.TAuthenticateParams;
import com.starrocks.thrift.TBeginRemoteTxnRequest;
import com.starrocks.thrift.TBeginRemoteTxnResponse;
import com.starrocks.thrift.TColumnDef;
Expand Down Expand Up @@ -1212,13 +1215,84 @@ private String getClientAddrAsString() {
return addr == null ? "unknown" : addr.hostname;
}

// Authenticate a FrontendServiceImpl#beginRemoteTxn RPC for StarRocks external table.
// The beginRemoteTxn is sent by the source cluster, and received by the target cluster.
// The target cluster should do authentication using the TAuthenticateParams. This method
// will check whether the user has an authorization, and whether the user has a
// PrivPredicate.LOAD on the given tables. The implementation is similar with that
// of stream load, and you can refer to RestBaseAction#execute and LoadAction#executeWithoutPassword
// to know more about the related part.
static TStatus checkPasswordAndLoadPrivilege(TAuthenticateParams authParams) {
if (authParams == null) {
LOG.debug("received null TAuthenticateParams");
return new TStatus(TStatusCode.OK);
}

LOG.debug("Receive TAuthenticateParams [user: {}, host: {}, db: {}, tables: {}]",
authParams.user, authParams.getHost(), authParams.getDb_name(), authParams.getTable_names());
if (!Config.enable_starrocks_external_table_auth_check) {
LOG.debug("enable_starrocks_external_table_auth_check is disabled, " +
"and skip to check authorization and privilege for {}", authParams);
return new TStatus(TStatusCode.OK);
}

String configHintMsg = "Set the configuration 'enable_starrocks_external_table_auth_check' to 'false' on the" +
" target cluster if you don't want to check the authorization and privilege.";

// 1. check user and password
BaseAction.ActionAuthorizationInfo authInfo;
UserIdentity userIdentity;
try {
authInfo = BaseAction.parseAuthInfo(authParams.getUser(), authParams.getPasswd(), authParams.getHost());
userIdentity = BaseAction.checkPassword(authInfo);
} catch (Exception e) {
LOG.warn("Failed to check TAuthenticateParams [user: {}, host: {}, db: {}, tables: {}]",
authParams.user, authParams.getHost(), authParams.getDb_name(), authParams.getTable_names(), e);
TStatus status = new TStatus(TStatusCode.NOT_AUTHORIZED);
status.setError_msgs(Lists.newArrayList(e.getMessage(), "Please check that your user or password " +
"is correct", configHintMsg));
return status;
}

// 2. check privilege
try {
String clusterName = authInfo.cluster;
if (Strings.isNullOrEmpty(clusterName)) {
throw new DdlException("No cluster selected");
}
String fullDbName = ClusterNamespace.getFullName(clusterName, authParams.getDb_name());
for (String tableName : authParams.getTable_names()) {
if (!GlobalStateMgr.getCurrentState().getAuth().checkTblPriv(
userIdentity, fullDbName, tableName, PrivPredicate.LOAD)) {
String errMsg = String.format("Access denied; user '%s'@'%s' need (at least one of) the " +
"privilege(s) in [%s] for table '%s' in database '%s'", userIdentity.getQualifiedUser(),
userIdentity.getHost(), PrivPredicate.LOAD.getPrivs().toString().trim(), tableName, fullDbName);
throw new UnauthorizedException(errMsg);
}
}
return new TStatus(TStatusCode.OK);
} catch (Exception e) {
LOG.warn("Failed to check TAuthenticateParams [user: {}, host: {}, db: {}, tables: {}]",
authParams.user, authParams.getHost(), authParams.getDb_name(), authParams.getTable_names(), e);
TStatus status = new TStatus(TStatusCode.NOT_AUTHORIZED);
status.setError_msgs(Lists.newArrayList(e.getMessage(), configHintMsg));
return status;
}
}

@Override
public TGetTableMetaResponse getTableMeta(TGetTableMetaRequest request) throws TException {
return masterImpl.getTableMeta(request);
}

@Override
public TBeginRemoteTxnResponse beginRemoteTxn(TBeginRemoteTxnRequest request) throws TException {
TStatus status = checkPasswordAndLoadPrivilege(request.getAuth_info());
if (status.getStatus_code() != TStatusCode.OK) {
TBeginRemoteTxnResponse response = new TBeginRemoteTxnResponse();
response.setStatus(status);
return response;
}
return masterImpl.beginRemoteTxn(request);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.thrift.TAbortRemoteTxnRequest;
import com.starrocks.thrift.TAbortRemoteTxnResponse;
import com.starrocks.thrift.TAuthenticateParams;
import com.starrocks.thrift.TBeginRemoteTxnRequest;
import com.starrocks.thrift.TBeginRemoteTxnResponse;
import com.starrocks.thrift.TCommitRemoteTxnRequest;
Expand Down Expand Up @@ -113,7 +114,8 @@ public void removeDatabaseTransactionMgr(Long dbId) {
// begin transaction in remote StarRocks cluster
public long beginRemoteTransaction(long dbId, List<Long> tableIds, String label,
String host, int port, TxnCoordinator coordinator,
LoadJobSourceType sourceType, long timeoutSecond)
LoadJobSourceType sourceType, long timeoutSecond,
TAuthenticateParams authenticateParams)
throws AnalysisException, BeginTransactionException {
if (Config.disable_load_job) {
throw new AnalysisException("disable_load_job is set to true, all load jobs are prevented");
Expand All @@ -137,6 +139,7 @@ public long beginRemoteTransaction(long dbId, List<Long> tableIds, String label,
request.setLabel(label);
request.setSource_type(sourceType.ordinal());
request.setTimeout_second(timeoutSecond);
request.setAuth_info(authenticateParams);
TBeginRemoteTxnResponse response;
try {
response = FrontendServiceProxy.call(addr,
Expand All @@ -150,7 +153,7 @@ public long beginRemoteTransaction(long dbId, List<Long> tableIds, String label,
if (response.status.getStatus_code() != TStatusCode.OK) {
String errStr;
if (response.status.getError_msgs() != null) {
errStr = String.join(",", response.status.getError_msgs());
errStr = String.join(". ", response.status.getError_msgs());
} else {
errStr = "";
}
Expand Down
29 changes: 29 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/http/BasicActionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// This file is licensed under the Elastic License 2.0. Copyright 2021-present, StarRocks Limited.

package com.starrocks.http;

import org.junit.Test;

import static org.junit.Assert.assertEquals;

public class BasicActionTest {

@Test
public void testParseAuthInfo() {
BaseAction.ActionAuthorizationInfo authInfo =
BaseAction.parseAuthInfo("abc", "123", "127.0.0.1");
verifyAuthInfo(BaseAction.ActionAuthorizationInfo.of(
"default_cluster:abc", "123", "127.0.0.1", "default_cluster"), authInfo);

authInfo = BaseAction.parseAuthInfo("test@cluster_id", "", "192.168.19.10");
verifyAuthInfo(BaseAction.ActionAuthorizationInfo.of(
"cluster_id:test", "", "192.168.19.10", "cluster_id"), authInfo);
}

private void verifyAuthInfo(BaseAction.ActionAuthorizationInfo expect, BaseAction.ActionAuthorizationInfo actual) {
assertEquals(expect.fullUserName, actual.fullUserName);
assertEquals(expect.remoteIp, actual.remoteIp);
assertEquals(expect.password, actual.password);
assertEquals(expect.cluster, actual.cluster);
}
}
Loading