Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Search Task API and Refactor search actions and handlers #149

Merged
merged 1 commit into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.opensearch.ml.common.transport.task;

import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;

public class MLTaskSearchAction extends ActionType<SearchResponse> {
// External Action which used for public facing RestAPIs.
public static final String NAME = "cluster:admin/opensearch/ml/tasks/search";
public static final MLTaskSearchAction INSTANCE = new MLTaskSearchAction();

private MLTaskSearchAction() {
super(NAME, SearchResponse::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,25 @@

package org.opensearch.ml.action.handler;

import static org.opensearch.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR;

import lombok.extern.log4j.Log4j2;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.indices.InvalidIndexNameException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.rest.RestStatus;

import com.google.common.base.Throwables;

/**
* Handle general get and search request in ml common.
Expand All @@ -22,4 +37,63 @@ public MLSearchHandler(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, listener);
} catch (Exception e) {
log.error(e);
listener.onFailure(e);
}
}

/**
* Wrap action listener to avoid return verbose error message and wrong 500 error to user.
* Suggestion for exception handling in ML common:
* 1. If the error is caused by wrong input, throw IllegalArgumentException exception.
* 2. For other errors, please use MLException or its subclass, or use
* OpenSearchStatusException.
*
* TODO: tune this function for wrapped exception, return root exception error message
*
* @param actionListener action listener
* @param generalErrorMessage general error message
* @param <T> action listener response type
* @return wrapped action listener
*/
public static <T> ActionListener<T> wrapRestActionListener(ActionListener<T> actionListener, String generalErrorMessage) {
return ActionListener.<T>wrap(r -> { actionListener.onResponse(r); }, e -> {
log.error("Wrap exception before sending back to user", e);
Throwable cause = Throwables.getRootCause(e);
if (isProperExceptionToReturn(e)) {
actionListener.onFailure(e);
} else if (isProperExceptionToReturn(cause)) {
actionListener.onFailure((Exception) cause);
} else {
RestStatus status = isBadRequest(e) ? BAD_REQUEST : INTERNAL_SERVER_ERROR;
String errorMessage = generalErrorMessage;
if (isBadRequest(e) || e instanceof MLException) {
errorMessage = e.getMessage();
} else if (cause != null && (isBadRequest(cause) || cause instanceof MLException)) {
errorMessage = cause.getMessage();
}
actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
}
});
}

public static boolean isProperExceptionToReturn(Throwable e) {
if (e == null) {
return false;
}
return e instanceof OpenSearchStatusException || e instanceof IndexNotFoundException || e instanceof InvalidIndexNameException;
}

public static boolean isBadRequest(Throwable e) {
if (e == null) {
return false;
}
return e instanceof IllegalArgumentException || e instanceof MLResourceNotFoundException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,98 +5,31 @@

package org.opensearch.ml.action.models;

import static org.opensearch.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR;

import lombok.extern.log4j.Log4j2;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.indices.InvalidIndexNameException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.rest.RestStatus;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import com.google.common.base.Throwables;

@Log4j2
public class SearchModelTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
Client client;
private MLSearchHandler mlSearchHandler;

@Inject
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
super(MLModelSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.client = client;
this.mlSearchHandler = mlSearchHandler;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, listener);
} catch (Exception e) {
log.error(e);
listener.onFailure(e);
}
}

/**
* Wrap action listener to avoid return verbose error message and wrong 500 error to user.
* Suggestion for exception handling in ML common:
* 1. If the error is caused by wrong input, throw IllegalArgumentException exception.
* 2. For other errors, please use MLException or its subclass, or use
* OpenSearchStatusException.
*
* TODO: tune this function for wrapped exception, return root exception error message
*
* @param actionListener action listener
* @param generalErrorMessage general error message
* @param <T> action listener response type
* @return wrapped action listener
*/
public static <T> ActionListener<T> wrapRestActionListener(ActionListener<T> actionListener, String generalErrorMessage) {
return ActionListener.<T>wrap(r -> { actionListener.onResponse(r); }, e -> {
log.error("Wrap exception before sending back to user", e);
Throwable cause = Throwables.getRootCause(e);
if (isProperExceptionToReturn(e)) {
actionListener.onFailure(e);
} else if (isProperExceptionToReturn(cause)) {
actionListener.onFailure((Exception) cause);
} else {
RestStatus status = isBadRequest(e) ? BAD_REQUEST : INTERNAL_SERVER_ERROR;
String errorMessage = generalErrorMessage;
if (isBadRequest(e) || e instanceof MLException) {
errorMessage = e.getMessage();
} else if (cause != null && (isBadRequest(cause) || cause instanceof MLException)) {
errorMessage = cause.getMessage();
}
actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
}
});
}

public static boolean isProperExceptionToReturn(Throwable e) {
if (e == null) {
return false;
}
return e instanceof OpenSearchStatusException || e instanceof IndexNotFoundException || e instanceof InvalidIndexNameException;
}

public static boolean isBadRequest(Throwable e) {
if (e == null) {
return false;
}
return e instanceof IllegalArgumentException || e instanceof MLResourceNotFoundException;
mlSearchHandler.search(request, actionListener);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.models;

import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class SearchTaskTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private MLSearchHandler mlSearchHandler;

@Inject
public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
super(MLTaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.mlSearchHandler = mlSearchHandler;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
mlSearchHandler.search(request, actionListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.ml.action.models.DeleteModelTransportAction;
import org.opensearch.ml.action.models.GetModelTransportAction;
import org.opensearch.ml.action.models.SearchModelTransportAction;
import org.opensearch.ml.action.models.SearchTaskTransportAction;
import org.opensearch.ml.action.prediction.TransportPredictionTaskAction;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesTransportAction;
Expand All @@ -55,6 +56,7 @@
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.engine.MLEngineClassLoader;
Expand Down Expand Up @@ -128,7 +130,8 @@ public Setting<Boolean> legacySetting() {
new ActionHandler<>(MLModelDeleteAction.INSTANCE, DeleteModelTransportAction.class),
new ActionHandler<>(MLModelSearchAction.INSTANCE, SearchModelTransportAction.class),
new ActionHandler<>(MLTaskGetAction.INSTANCE, GetTaskTransportAction.class),
new ActionHandler<>(MLTaskDeleteAction.INSTANCE, DeleteTaskTransportAction.class)
new ActionHandler<>(MLTaskDeleteAction.INSTANCE, DeleteTaskTransportAction.class),
new ActionHandler<>(MLTaskSearchAction.INSTANCE, SearchTaskTransportAction.class)
);
}

Expand Down Expand Up @@ -250,6 +253,7 @@ public List<RestHandler> getRestHandlers(
RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction();
RestMLGetTaskAction restMLGetTaskAction = new RestMLGetTaskAction();
RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction();
RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction();

return ImmutableList
.of(
Expand All @@ -262,7 +266,8 @@ public List<RestHandler> getRestHandlers(
restMLDeleteModelAction,
restMLSearchModelAction,
restMLGetTaskAction,
restMLDeleteTaskAction
restMLDeleteTaskAction,
restMLSearchTaskAction
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package org.opensearch.ml.rest;

import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.ml.utils.RestActionUtils.getSourceContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.rest.action.RestResponseListener;
import org.opensearch.search.builder.SearchSourceBuilder;

public abstract class AbstractMLSearchAction<T extends ToXContentObject> extends BaseRestHandler {

protected final List<String> urlPaths;
protected final String index;
protected final Class<T> clazz;
protected final ActionType<SearchResponse> actionType;

public AbstractMLSearchAction(List<String> urlPaths, String index, Class<T> clazz, ActionType<SearchResponse> actionType) {
this.urlPaths = urlPaths;
this.index = index;
this.clazz = clazz;
this.actionType = actionType;
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.fetchSource(getSourceContext(request));
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
return channel -> client.execute(actionType, searchRequest, search(channel));
}

protected RestResponseListener<SearchResponse> search(RestChannel channel) {
return new RestResponseListener<SearchResponse>(channel) {
@Override
public RestResponse buildResponse(SearchResponse response) throws Exception {
if (response.isTimedOut()) {
return new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, response.toString());
}
return new BytesRestResponse(RestStatus.OK, response.toXContent(channel.newBuilder(), EMPTY_PARAMS));
}
};
}

@Override
public List<Route> routes() {
List<Route> routes = new ArrayList<>();
for (String path : urlPaths) {
routes.add(new Route(RestRequest.Method.POST, path));
routes.add(new Route(RestRequest.Method.GET, path));
}
return routes;
}
}
Loading