Skip to content

Commit

Permalink
Fix flaky test of visualization tool (#2416)
Browse files Browse the repository at this point in the history
* add wait for model to be undeployed

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* spotless

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* update model undeploy status

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* reduce total wait time

Signed-off-by: Hailong Cui <ihailong@amazon.com>

---------

Signed-off-by: Hailong Cui <ihailong@amazon.com>
  • Loading branch information
Hailong-am authored May 9, 2024
1 parent 89f23d2 commit aa09014
Showing 1 changed file with 39 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,33 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;

import org.apache.hc.core5.http.ParseException;
import org.junit.After;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.rest.RestBaseAgentToolsIT;
import org.opensearch.ml.utils.TestHelper;

import com.sun.net.httpserver.HttpServer;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;

@Log4j2
public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT {

private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 30;
private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;

protected HttpServer server;
protected String modelId;
protected String agentId;
Expand Down Expand Up @@ -63,9 +72,39 @@ public void stopMockLLM() {
@After
public void deleteModel() throws IOException {
undeployModel(modelId);
waitModelUndeployed(modelId);
deleteModel(client(), modelId, null);
}

@SneakyThrows
private void waitModelUndeployed(String modelId) {
Predicate<Response> condition = response -> {
try {
Map<String, Object> responseInMap = parseResponseToMap(response);
MLModelState state = MLModelState.from(responseInMap.get(MLModel.MODEL_STATE_FIELD).toString());
return Set.of(MLModelState.UNDEPLOYED, MLModelState.DEPLOY_FAILED).contains(state);
} catch (Exception e) {
return false;
}
};
waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, condition);
}

@SneakyThrows
protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate<Response> condition) {
for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) {
Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
if (condition.test(response)) {
return response;
}
logger.info("The {}-th response: {}", i, response.toString());
Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
}
fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
return null;
}

private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException {
int retryTimes = 0;
String connectorId = null;
Expand Down

0 comments on commit aa09014

Please sign in to comment.