Skip to content

Commit

Permalink
Add Search Task API and Refactor search actions and handlers (#149)
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt authored Feb 23, 2022
1 parent 69195fd commit afa4788
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 137 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.opensearch.ml.common.transport.task;

import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;

public class MLTaskSearchAction extends ActionType<SearchResponse> {
// External Action which used for public facing RestAPIs.
public static final String NAME = "cluster:admin/opensearch/ml/tasks/search";
public static final MLTaskSearchAction INSTANCE = new MLTaskSearchAction();

private MLTaskSearchAction() {
super(NAME, SearchResponse::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,25 @@

package org.opensearch.ml.action.handler;

import static org.opensearch.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR;

import lombok.extern.log4j.Log4j2;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.indices.InvalidIndexNameException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.rest.RestStatus;

import com.google.common.base.Throwables;

/**
* Handle general get and search request in ml common.
Expand All @@ -22,4 +37,63 @@ public MLSearchHandler(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, listener);
} catch (Exception e) {
log.error(e);
listener.onFailure(e);
}
}

/**
* Wrap action listener to avoid return verbose error message and wrong 500 error to user.
* Suggestion for exception handling in ML common:
* 1. If the error is caused by wrong input, throw IllegalArgumentException exception.
* 2. For other errors, please use MLException or its subclass, or use
* OpenSearchStatusException.
*
* TODO: tune this function for wrapped exception, return root exception error message
*
* @param actionListener action listener
* @param generalErrorMessage general error message
* @param <T> action listener response type
* @return wrapped action listener
*/
public static <T> ActionListener<T> wrapRestActionListener(ActionListener<T> actionListener, String generalErrorMessage) {
return ActionListener.<T>wrap(r -> { actionListener.onResponse(r); }, e -> {
log.error("Wrap exception before sending back to user", e);
Throwable cause = Throwables.getRootCause(e);
if (isProperExceptionToReturn(e)) {
actionListener.onFailure(e);
} else if (isProperExceptionToReturn(cause)) {
actionListener.onFailure((Exception) cause);
} else {
RestStatus status = isBadRequest(e) ? BAD_REQUEST : INTERNAL_SERVER_ERROR;
String errorMessage = generalErrorMessage;
if (isBadRequest(e) || e instanceof MLException) {
errorMessage = e.getMessage();
} else if (cause != null && (isBadRequest(cause) || cause instanceof MLException)) {
errorMessage = cause.getMessage();
}
actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
}
});
}

public static boolean isProperExceptionToReturn(Throwable e) {
if (e == null) {
return false;
}
return e instanceof OpenSearchStatusException || e instanceof IndexNotFoundException || e instanceof InvalidIndexNameException;
}

public static boolean isBadRequest(Throwable e) {
if (e == null) {
return false;
}
return e instanceof IllegalArgumentException || e instanceof MLResourceNotFoundException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,98 +5,31 @@

package org.opensearch.ml.action.models;

import static org.opensearch.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR;

import lombok.extern.log4j.Log4j2;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.indices.InvalidIndexNameException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.rest.RestStatus;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import com.google.common.base.Throwables;

@Log4j2
public class SearchModelTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
Client client;
private MLSearchHandler mlSearchHandler;

@Inject
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
super(MLModelSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.client = client;
this.mlSearchHandler = mlSearchHandler;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.search(request, listener);
} catch (Exception e) {
log.error(e);
listener.onFailure(e);
}
}

/**
* Wrap action listener to avoid return verbose error message and wrong 500 error to user.
* Suggestion for exception handling in ML common:
* 1. If the error is caused by wrong input, throw IllegalArgumentException exception.
* 2. For other errors, please use MLException or its subclass, or use
* OpenSearchStatusException.
*
* TODO: tune this function for wrapped exception, return root exception error message
*
* @param actionListener action listener
* @param generalErrorMessage general error message
* @param <T> action listener response type
* @return wrapped action listener
*/
public static <T> ActionListener<T> wrapRestActionListener(ActionListener<T> actionListener, String generalErrorMessage) {
return ActionListener.<T>wrap(r -> { actionListener.onResponse(r); }, e -> {
log.error("Wrap exception before sending back to user", e);
Throwable cause = Throwables.getRootCause(e);
if (isProperExceptionToReturn(e)) {
actionListener.onFailure(e);
} else if (isProperExceptionToReturn(cause)) {
actionListener.onFailure((Exception) cause);
} else {
RestStatus status = isBadRequest(e) ? BAD_REQUEST : INTERNAL_SERVER_ERROR;
String errorMessage = generalErrorMessage;
if (isBadRequest(e) || e instanceof MLException) {
errorMessage = e.getMessage();
} else if (cause != null && (isBadRequest(cause) || cause instanceof MLException)) {
errorMessage = cause.getMessage();
}
actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
}
});
}

public static boolean isProperExceptionToReturn(Throwable e) {
if (e == null) {
return false;
}
return e instanceof OpenSearchStatusException || e instanceof IndexNotFoundException || e instanceof InvalidIndexNameException;
}

public static boolean isBadRequest(Throwable e) {
if (e == null) {
return false;
}
return e instanceof IllegalArgumentException || e instanceof MLResourceNotFoundException;
mlSearchHandler.search(request, actionListener);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.models;

import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class SearchTaskTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private MLSearchHandler mlSearchHandler;

@Inject
public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) {
super(MLTaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.mlSearchHandler = mlSearchHandler;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
mlSearchHandler.search(request, actionListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.ml.action.models.DeleteModelTransportAction;
import org.opensearch.ml.action.models.GetModelTransportAction;
import org.opensearch.ml.action.models.SearchModelTransportAction;
import org.opensearch.ml.action.models.SearchTaskTransportAction;
import org.opensearch.ml.action.prediction.TransportPredictionTaskAction;
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesTransportAction;
Expand All @@ -55,6 +56,7 @@
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.engine.MLEngineClassLoader;
Expand Down Expand Up @@ -128,7 +130,8 @@ public Setting<Boolean> legacySetting() {
new ActionHandler<>(MLModelDeleteAction.INSTANCE, DeleteModelTransportAction.class),
new ActionHandler<>(MLModelSearchAction.INSTANCE, SearchModelTransportAction.class),
new ActionHandler<>(MLTaskGetAction.INSTANCE, GetTaskTransportAction.class),
new ActionHandler<>(MLTaskDeleteAction.INSTANCE, DeleteTaskTransportAction.class)
new ActionHandler<>(MLTaskDeleteAction.INSTANCE, DeleteTaskTransportAction.class),
new ActionHandler<>(MLTaskSearchAction.INSTANCE, SearchTaskTransportAction.class)
);
}

Expand Down Expand Up @@ -250,6 +253,7 @@ public List<RestHandler> getRestHandlers(
RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction();
RestMLGetTaskAction restMLGetTaskAction = new RestMLGetTaskAction();
RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction();
RestMLSearchTaskAction restMLSearchTaskAction = new RestMLSearchTaskAction();

return ImmutableList
.of(
Expand All @@ -262,7 +266,8 @@ public List<RestHandler> getRestHandlers(
restMLDeleteModelAction,
restMLSearchModelAction,
restMLGetTaskAction,
restMLDeleteTaskAction
restMLDeleteTaskAction,
restMLSearchTaskAction
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package org.opensearch.ml.rest;

import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.ml.utils.RestActionUtils.getSourceContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.rest.action.RestResponseListener;
import org.opensearch.search.builder.SearchSourceBuilder;

public abstract class AbstractMLSearchAction<T extends ToXContentObject> extends BaseRestHandler {

protected final List<String> urlPaths;
protected final String index;
protected final Class<T> clazz;
protected final ActionType<SearchResponse> actionType;

public AbstractMLSearchAction(List<String> urlPaths, String index, Class<T> clazz, ActionType<SearchResponse> actionType) {
this.urlPaths = urlPaths;
this.index = index;
this.clazz = clazz;
this.actionType = actionType;
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.fetchSource(getSourceContext(request));
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
return channel -> client.execute(actionType, searchRequest, search(channel));
}

protected RestResponseListener<SearchResponse> search(RestChannel channel) {
return new RestResponseListener<SearchResponse>(channel) {
@Override
public RestResponse buildResponse(SearchResponse response) throws Exception {
if (response.isTimedOut()) {
return new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, response.toString());
}
return new BytesRestResponse(RestStatus.OK, response.toXContent(channel.newBuilder(), EMPTY_PARAMS));
}
};
}

@Override
public List<Route> routes() {
List<Route> routes = new ArrayList<>();
for (String path : urlPaths) {
routes.add(new Route(RestRequest.Method.POST, path));
routes.add(new Route(RestRequest.Method.GET, path));
}
return routes;
}
}
Loading

0 comments on commit afa4788

Please sign in to comment.