Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix permission when accessing ML system indices #148

Merged
merged 1 commit into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.tasks.Task;
Expand All @@ -45,19 +46,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete

DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);

client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.info("Completed Delete Model Request, model id:{} deleted", modelId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML model " + modelId, e);
actionListener.onFailure(e);
}
});
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.info("Completed Delete Model Request, model id:{} deleted", modelId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML model " + modelId, e);
actionListener.onFailure(e);
}
});
} catch (Exception e) {
log.error("Failed to delete ML model " + modelId, e);
actionListener.onFailure(e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -58,28 +59,34 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
String modelId = mlModelGetRequest.getModelId();
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId);

client.get(getRequest, ActionListener.wrap(r -> {
log.info("Completed Get Model Request, id:{}", modelId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(r -> {
log.info("Completed Get Model Request, id:{}", modelId);

if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModel mlModel = MLModel.parse(parser);
actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build());
} catch (Exception e) {
log.error("Failed to parse ml model" + r.getId(), e);
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModel mlModel = MLModel.parse(parser);
actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build());
} catch (Exception e) {
log.error("Failed to parse ml model" + r.getId(), e);
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model " + modelId));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model " + modelId));
} else {
log.error("Failed to get ML model " + modelId, e);
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model " + modelId));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model " + modelId));
} else {
log.error("Failed to get ML model " + modelId, e);
actionListener.onFailure(e);
}
}));
}));
} catch (Exception e) {
log.error("Failed to get ML model " + modelId, e);
actionListener.onFailure(e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ protected void doExecute(Task task, SearchRequest request, ActionListener<Search
* @param <T> action listener response type
* @return wrapped action listener
*/
public static <T> ActionListener wrapRestActionListener(ActionListener<T> actionListener, String generalErrorMessage) {
public static <T> ActionListener<T> wrapRestActionListener(ActionListener<T> actionListener, String generalErrorMessage) {
return ActionListener.<T>wrap(r -> { actionListener.onResponse(r); }, e -> {
log.error("Wrap exception before sending back to user", e);
Throwable cause = Throwables.getRootCause(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.tasks.Task;
Expand All @@ -42,18 +43,23 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete

DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId);

client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.info("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
}
});
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.info("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
}
});
} catch (Exception e) {
log.error("Failed to delete ML task " + taskId, e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -53,28 +54,34 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
String taskId = mlTaskGetRequest.getTaskId();
GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);

client.get(getRequest, ActionListener.wrap(r -> {
log.info("Completed Get Task Request, id:{}", taskId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(r -> {
log.info("Completed Get Task Request, id:{}", taskId);

if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLTask mlTask = MLTask.parse(parser);
actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
} catch (Exception e) {
log.error("Failed to parse ml task" + r.getId(), e);
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLTask mlTask = MLTask.parse(parser);
actionListener.onResponse(MLTaskGetResponse.builder().mlTask(mlTask).build());
} catch (Exception e) {
log.error("Failed to parse ml task" + r.getId(), e);
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task " + taskId));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task " + taskId));
} else {
log.error("Failed to get ML task " + taskId, e);
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task " + taskId));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task " + taskId));
} else {
log.error("Failed to get ML task " + taskId, e);
actionListener.onFailure(e);
}
}));
}));
} catch (Exception e) {
log.error("Failed to get ML task " + taskId, e);
actionListener.onFailure(e);
}

}
}
10 changes: 8 additions & 2 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
Expand Down Expand Up @@ -191,7 +192,10 @@ public void createMLTask(MLTask mlTask, ActionListener<IndexResponse> listener)
return;
}
IndexRequest request = new IndexRequest(ML_TASK_INDEX);
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
try (
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(mlTask.toXContent(builder, ToXContent.EMPTY_PARAMS)).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, listener);
} catch (Exception e) {
Expand Down Expand Up @@ -256,7 +260,9 @@ public void updateMLTask(
updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
updateRequest.doc(updatedContent);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
ActionListener actionListener = semaphore == null ? listener : ActionListener.runAfter(listener, () -> semaphore.release());
ActionListener<UpdateResponse> actionListener = semaphore == null
? listener
: ActionListener.runAfter(listener, () -> semaphore.release());
client.update(updateRequest, actionListener);
}
}