Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,18 @@ public class MLIndexInsightGetRequest extends ActionRequest {
String indexName;
MLIndexInsightType targetIndexInsight;
String tenantId;
String cmkRoleArn;
String assumeRoleArn;

public MLIndexInsightGetRequest(
String indexName,
MLIndexInsightType targetIndexInsight,
String tenantId,
String cmkRoleArn,
String assumeRoleArn
) {

public MLIndexInsightGetRequest(String indexName, MLIndexInsightType targetIndexInsight, String tenantId) {
this.indexName = indexName;
this.targetIndexInsight = targetIndexInsight;
this.tenantId = tenantId;
this.cmkRoleArn = cmkRoleArn;
this.assumeRoleArn = assumeRoleArn;
}

public MLIndexInsightGetRequest(StreamInput in) throws IOException {
super(in);
this.indexName = in.readString();
this.targetIndexInsight = MLIndexInsightType.fromString(in.readString());
this.tenantId = in.readOptionalString();
this.cmkRoleArn = in.readOptionalString();
this.assumeRoleArn = in.readOptionalString();
}

@Override
Expand All @@ -61,8 +49,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.indexName);
out.writeString(this.targetIndexInsight.name());
out.writeOptionalString(tenantId);
out.writeOptionalString(cmkRoleArn);
out.writeOptionalString(assumeRoleArn);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class MLIndexInsightGetRequestTests {
public void constructor() {
indexName = "test-abc";
mlIndexInsightType = FIELD_DESCRIPTION;
MLIndexInsightGetRequest mlConfigGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId, null, null);
MLIndexInsightGetRequest mlConfigGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);
assertEquals(mlConfigGetRequest.getIndexName(), indexName);
assertEquals(mlConfigGetRequest.getTargetIndexInsight(), mlIndexInsightType);
}
Expand All @@ -37,13 +37,7 @@ public void constructor() {
public void writeTo() throws IOException {
indexName = "test-abc";
mlIndexInsightType = FIELD_DESCRIPTION;
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
mlIndexInsightType,
tenantId,
null,
null
);
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);
BytesStreamOutput output = new BytesStreamOutput();
mlIndexInsightGetRequest.writeTo(output);

Expand All @@ -60,13 +54,7 @@ public void writeTo_WithTenantId() throws IOException {
indexName = "test-abc";
mlIndexInsightType = FIELD_DESCRIPTION;
String tenantId = "demo_id";
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
mlIndexInsightType,
tenantId,
null,
null
);
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);
BytesStreamOutput output = new BytesStreamOutput();
mlIndexInsightGetRequest.writeTo(output);

Expand All @@ -84,13 +72,7 @@ public void writeTo_WithTenantId() throws IOException {
public void validate_Success() {
indexName = "not-null";
mlIndexInsightType = FIELD_DESCRIPTION;
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
mlIndexInsightType,
tenantId,
null,
null
);
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);

assertEquals(null, mlIndexInsightGetRequest.validate());
}
Expand All @@ -99,13 +81,7 @@ public void validate_Success() {
public void validate_Failure_index() {
indexName = null;
mlIndexInsightType = FIELD_DESCRIPTION;
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
mlIndexInsightType,
tenantId,
null,
null
);
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);
assertEquals(null, mlIndexInsightGetRequest.getIndexName());

ActionRequestValidationException exception = addValidationError("Index insight's target index can't be null", null);
Expand All @@ -116,13 +92,7 @@ public void validate_Failure_index() {
public void validate_Failure_type() {
indexName = "not-null";
mlIndexInsightType = null;
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
mlIndexInsightType,
tenantId,
null,
null
);
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);
assertEquals(null, mlIndexInsightGetRequest.getTargetIndexInsight());

ActionRequestValidationException exception = addValidationError("Index insight's target type can't be null", null);
Expand All @@ -133,27 +103,15 @@ public void validate_Failure_type() {
public void fromActionRequest_Success() throws IOException {
indexName = "test-abc";
mlIndexInsightType = FIELD_DESCRIPTION;
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
mlIndexInsightType,
tenantId,
null,
null
);
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);
assertEquals(mlIndexInsightGetRequest.fromActionRequest(mlIndexInsightGetRequest), mlIndexInsightGetRequest);
}

@Test
public void fromActionRequest_Success_fromActionRequest() throws IOException {
indexName = "test-abc";
mlIndexInsightType = FIELD_DESCRIPTION;
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
mlIndexInsightType,
tenantId,
null,
null
);
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(indexName, mlIndexInsightType, tenantId);

ActionRequest actionRequest = new ActionRequest() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ private void executePlanningLoop(
MLAgentExecutor.MESSAGE_HISTORY_LIMIT,
allParams.getOrDefault(EXECUTOR_MESSAGE_HISTORY_LIMIT, DEFAULT_EXECUTOR_MESSAGE_HISTORY_LIMIT)
);

if (allParams.containsKey(MEMORY_CONTAINER_ID_FIELD)) {
reactParams.put(MEMORY_CONTAINER_ID_FIELD, allParams.get(MEMORY_CONTAINER_ID_FIELD));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD;
import static org.opensearch.ml.common.input.Constants.CMK_ASSUME_ROLE_FIELD;
import static org.opensearch.ml.common.input.Constants.CMK_ROLE_FIELD;

import java.util.Map;

Expand Down Expand Up @@ -82,9 +80,7 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
MLIndexInsightGetRequest mlIndexInsightGetRequest = new MLIndexInsightGetRequest(
indexName,
taskType,
parameters.getOrDefault(TENANT_ID_FIELD, null),
parameters.getOrDefault(CMK_ROLE_FIELD, null),
parameters.getOrDefault(CMK_ASSUME_ROLE_FIELD, null)
parameters.getOrDefault(TENANT_ID_FIELD, null)
);
client.execute(MLIndexInsightGetAction.INSTANCE, mlIndexInsightGetRequest, ActionListener.wrap(r -> {
IndexInsight indexInsight = r.getIndexInsight();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.action.IndexInsight;

import static org.opensearch.ml.common.input.Constants.CMK_ASSUME_ROLE_FIELD;
import static org.opensearch.ml.common.input.Constants.CMK_ROLE_FIELD;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED;

import java.time.Instant;
Expand All @@ -13,6 +15,7 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.MLIndex;
Expand Down Expand Up @@ -185,31 +188,16 @@ private void returnCombinedResult(
}

IndexInsightTask createTask(MLIndexInsightGetRequest request) {
ThreadContext threadContext = client.threadPool().getThreadContext();
String cmkRoleArn = threadContext.getHeader(CMK_ROLE_FIELD);
String cmkAssumeRoleArn = threadContext.getHeader(CMK_ASSUME_ROLE_FIELD);
switch (request.getTargetIndexInsight()) {
case STATISTICAL_DATA:
return new StatisticalDataTask(
request.getIndexName(),
client,
sdkClient,
request.getCmkRoleArn(),
request.getAssumeRoleArn()
);
return new StatisticalDataTask(request.getIndexName(), client, sdkClient, cmkRoleArn, cmkAssumeRoleArn);
case FIELD_DESCRIPTION:
return new FieldDescriptionTask(
request.getIndexName(),
client,
sdkClient,
request.getCmkRoleArn(),
request.getAssumeRoleArn()
);
return new FieldDescriptionTask(request.getIndexName(), client, sdkClient, cmkRoleArn, cmkAssumeRoleArn);
case LOG_RELATED_INDEX_CHECK:
return new LogRelatedIndexCheckTask(
request.getIndexName(),
client,
sdkClient,
request.getCmkRoleArn(),
request.getAssumeRoleArn()
);
return new LogRelatedIndexCheckTask(request.getIndexName(), client, sdkClient, cmkRoleArn, cmkAssumeRoleArn);
default:
throw new IllegalArgumentException("Unsupported task type: " + request.getTargetIndexInsight());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.opensearch.ml.utils.RestActionUtils.getAlgorithm;
import static org.opensearch.ml.utils.RestActionUtils.hasMcpHeaders;
import static org.opensearch.ml.utils.RestActionUtils.isAsync;
import static org.opensearch.ml.utils.RestActionUtils.putCMKRelatedRoleFromHeaders;
import static org.opensearch.ml.utils.RestActionUtils.putMcpRequestHeaders;
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;

Expand Down Expand Up @@ -162,6 +163,8 @@ MLExecuteTaskRequest getRequest(RestRequest request, NodeClient client) throws I
);
}
putMcpRequestHeaders(request, client);
putCMKRelatedRoleFromHeaders(request, client);

}
} else if (uri.startsWith(ML_BASE_URI + "/tools/")) {
if (!mlFeatureEnabledSetting.isToolExecuteEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.common.indexInsight.MLIndexInsightType.STATISTICAL_DATA;
import static org.opensearch.ml.common.input.Constants.CMK_ASSUME_ROLE_FIELD;
import static org.opensearch.ml.common.input.Constants.CMK_ROLE_FIELD;
import static org.opensearch.ml.common.utils.ToolUtils.getAttributeFromHeader;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_INDEX_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
import static org.opensearch.ml.utils.RestActionUtils.putCMKRelatedRoleFromHeaders;
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;

import java.io.IOException;
Expand All @@ -31,6 +29,9 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class RestMLGetIndexInsightAction extends BaseRestHandler {
private static final String ML_GET_INDEX_INSIGHT_ACTION = "ml_get_index_insight_action";

Expand Down Expand Up @@ -58,6 +59,7 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
putCMKRelatedRoleFromHeaders(restRequest, client);
MLIndexInsightGetRequest mlIndexInsightGetRequest = getRequest(restRequest);
return channel -> client.execute(MLIndexInsightGetAction.INSTANCE, mlIndexInsightGetRequest, new RestToXContentListener<>(channel));
}
Expand All @@ -73,9 +75,7 @@ MLIndexInsightGetRequest getRequest(RestRequest request) throws IOException {
if (insightType == null) {
insightType = STATISTICAL_DATA.name();
}
String cmkRoleArn = getAttributeFromHeader(CMK_ROLE_FIELD, request);
String assumeRoleArn = getAttributeFromHeader(CMK_ASSUME_ROLE_FIELD, request);
MLIndexInsightType type = MLIndexInsightType.fromString(insightType);
return new MLIndexInsightGetRequest(indexName, type, tenantId, cmkRoleArn, assumeRoleArn);
return new MLIndexInsightGetRequest(indexName, type, tenantId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import static org.opensearch.ml.common.MLModel.MODEL_CONTENT_FIELD;
import static org.opensearch.ml.common.MLModel.OLD_MODEL_CONTENT_FIELD;
import static org.opensearch.ml.common.input.Constants.CMK_ASSUME_ROLE_FIELD;
import static org.opensearch.ml.common.input.Constants.CMK_ROLE_FIELD;
import static org.opensearch.ml.common.utils.ToolUtils.getAttributeFromHeader;

import java.security.AccessController;
import java.security.PrivilegedActionException;
Expand Down Expand Up @@ -354,6 +357,25 @@ public static boolean hasMcpHeaders(RestRequest request) {
|| request.header(CommonValue.MCP_HEADER_OPENSEARCH_URL) != null;
}

/**
* Extracts CMK related role from request headers from the REST request and puts them in ThreadContext.
*
* @param request RestRequest containing the CMK headers
* @param client Client to access ThreadContext
*/
public static void putCMKRelatedRoleFromHeaders(RestRequest request, Client client) {
ThreadContext threadContext = client.threadPool().getThreadContext();
String cmkRoleArn = getAttributeFromHeader(CMK_ROLE_FIELD, request);
String assumeRoleArn = getAttributeFromHeader(CMK_ASSUME_ROLE_FIELD, request);
if (cmkRoleArn != null && !cmkRoleArn.isEmpty()) {
threadContext.putHeader(CMK_ROLE_FIELD, cmkRoleArn);
}
if (assumeRoleArn != null && !assumeRoleArn.isEmpty()) {
threadContext.putHeader(CMK_ASSUME_ROLE_FIELD, assumeRoleArn);
}

}

/**
* Extracts MCP (Model Context Protocol) request headers from the REST request and puts them in ThreadContext.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,17 +216,17 @@ public void testGetIndexInsight_FailToAccess() {

@Test
public void testCreateTask() {
MLIndexInsightGetRequest statisticalRequest = new MLIndexInsightGetRequest("test_index", STATISTICAL_DATA, null, null, null);
MLIndexInsightGetRequest statisticalRequest = new MLIndexInsightGetRequest("test_index", STATISTICAL_DATA, null);
IndexInsightTask statisticalTask = getIndexInsightTransportAction.createTask(statisticalRequest);
assertNotNull(statisticalTask);
assertTrue(statisticalTask instanceof StatisticalDataTask);

MLIndexInsightGetRequest fieldRequest = new MLIndexInsightGetRequest("test_index", FIELD_DESCRIPTION, null, null, null);
MLIndexInsightGetRequest fieldRequest = new MLIndexInsightGetRequest("test_index", FIELD_DESCRIPTION, null);
IndexInsightTask fieldTask = getIndexInsightTransportAction.createTask(fieldRequest);
assertNotNull(fieldTask);
assertTrue(fieldTask instanceof FieldDescriptionTask);

MLIndexInsightGetRequest logRequest = new MLIndexInsightGetRequest("test_index", LOG_RELATED_INDEX_CHECK, null, null, null);
MLIndexInsightGetRequest logRequest = new MLIndexInsightGetRequest("test_index", LOG_RELATED_INDEX_CHECK, null);
IndexInsightTask logTask = getIndexInsightTransportAction.createTask(logRequest);
assertNotNull(logTask);
assertTrue(logTask instanceof LogRelatedIndexCheckTask);
Expand Down
Loading