Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Fix flaky test of visualization tool (#2416) #2435

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,32 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.rest.RestBaseAgentToolsIT;
import org.opensearch.ml.utils.TestHelper;

import com.sun.net.httpserver.HttpServer;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;

@Log4j2
public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT {

private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 30;
private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;

protected HttpServer server;
protected String modelId;
protected String agentId;
Expand Down Expand Up @@ -62,9 +71,39 @@ public void stopMockLLM() {
@After
public void deleteModel() throws IOException {
undeployModel(modelId);
waitModelUndeployed(modelId);
deleteModel(client(), modelId, null);
}

@SneakyThrows
private void waitModelUndeployed(String modelId) {
Predicate<Response> condition = response -> {
try {
Map<String, Object> responseInMap = parseResponseToMap(response);
MLModelState state = MLModelState.from(responseInMap.get(MLModel.MODEL_STATE_FIELD).toString());
return Set.of(MLModelState.UNDEPLOYED, MLModelState.DEPLOY_FAILED).contains(state);
} catch (Exception e) {
return false;
}
};
waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, condition);
}

@SneakyThrows
protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate<Response> condition) {
for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) {
Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
if (condition.test(response)) {
return response;
}
logger.info("The {}-th response: {}", i, response.toString());
Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
}
fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
return null;
}

private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException {
int retryTimes = 0;
String connectorId = null;
Expand Down
Loading