Skip to content

Commit 7dce484

Browse files
more tests
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent 47e1741 commit 7dce484

File tree

12 files changed

+423
-345
lines changed

12 files changed

+423
-345
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ public class CommonValue {
107107
public static final String MCP_CONNECTORS_FIELD = "mcp_connectors";
108108
public static final String MCP_CONNECTOR_ID_FIELD = "mcp_connector_id";
109109
public static final String MCP_DEFAULT_SSE_ENDPOINT = "/sse";
110-
public static final String SSE_ENDPOINT_FILED = "sse_endpoint";
111-
public static final String MCP_DEFAULT_ENDPOINT = "/mcp/";
112-
public static final String ENDPOINT_FILED = "endpoint";
110+
public static final String SSE_ENDPOINT_FIELD = "sse_endpoint";
111+
public static final String MCP_DEFAULT_STREAMABLE_HTTP_ENDPOINT = "/mcp/";
112+
public static final String ENDPOINT_FIELD = "endpoint";
113113

114114
// TOOL Constants
115115
public static final String TOOL_INPUT_SCHEMA_FIELD = "input_schema";

common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
1010
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
1111
import static org.opensearch.ml.common.CommonValue.VERSION_3_0_0;
12+
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
13+
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_STREAMABLE_HTTP;
1214
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1315

1416
import java.io.IOException;
@@ -106,7 +108,7 @@ public MLCreateConnectorInput(
106108
if (protocol == null) {
107109
throw new IllegalArgumentException("Connector protocol is null");
108110
}
109-
if (credential == null || credential.isEmpty()) {
111+
if ((!protocol.equals(MCP_SSE) && !protocol.equals(MCP_STREAMABLE_HTTP)) && (credential == null || credential.isEmpty())) {
110112
throw new IllegalArgumentException("Connector credential is null or empty list");
111113
}
112114
if (actions != null) {

common/src/test/java/org/opensearch/ml/common/connector/ConnectorProtocolsTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ public class ConnectorProtocolsTest {
1717
@Test
1818
public void validateProtocol_Null() {
1919
exceptionRule.expect(IllegalArgumentException.class);
20-
exceptionRule.expectMessage("Connector protocol is null. Please use one of [aws_sigv4, http, mcp_sse]");
20+
exceptionRule.expectMessage("Connector protocol is null. Please use one of [aws_sigv4, http, mcp_sse, mcp_streamable_http]");
2121
ConnectorProtocols.validateProtocol(null);
2222
}
2323

2424
@Test
2525
public void validateProtocol_WrongValue() {
2626
exceptionRule.expect(IllegalArgumentException.class);
27-
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]");
27+
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse, mcp_streamable_http]");
2828
ConnectorProtocols.validateProtocol("abc");
2929
}
3030
}

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public void setUp() {
6262
@Test
6363
public void constructor_InvalidProtocol() {
6464
exceptionRule.expect(IllegalArgumentException.class);
65-
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]");
65+
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse, mcp_streamable_http]");
6666

6767
HttpConnector.builder().protocol("wrong protocol").build();
6868
}

common/src/test/java/org/opensearch/ml/common/connector/McpConnectorTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public void setUp() {
5555
@Test
5656
public void constructor_InvalidProtocol() {
5757
exceptionRule.expect(IllegalArgumentException.class);
58-
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse]");
58+
exceptionRule.expectMessage("Unsupported connector protocol. Please use one of [aws_sigv4, http, mcp_sse, mcp_streamable_http]");
5959

6060
McpConnector.builder().protocol("wrong protocol").build();
6161
}

common/src/test/java/org/opensearch/ml/common/connector/McpStreamableHttpConnectorTest.java

Lines changed: 137 additions & 123 deletions
Large diffs are not rendered by default.

common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import static org.junit.Assert.assertNull;
1111
import static org.junit.Assert.assertThrows;
1212
import static org.junit.Assert.assertTrue;
13+
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
14+
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_STREAMABLE_HTTP;
1315

1416
import java.io.IOException;
1517
import java.util.Arrays;
@@ -207,6 +209,94 @@ public void constructorMLCreateConnectorInput_EmptyCredential() {
207209
assertEquals("Connector credential is null or empty list", exception.getMessage());
208210
}
209211

212+
@Test
213+
public void constructorMLCreateConnectorInput_McpSseWithNullCredential_ShouldNotThrow() {
214+
// MCP SSE connectors should be allowed to have null credentials
215+
MLCreateConnectorInput connector = MLCreateConnectorInput
216+
.builder()
217+
.name(TEST_CONNECTOR_NAME)
218+
.description(TEST_CONNECTOR_DESCRIPTION)
219+
.version(TEST_CONNECTOR_VERSION)
220+
.protocol(MCP_SSE)
221+
.parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE))
222+
.credential(null)
223+
.actions(List.of())
224+
.access(AccessMode.PUBLIC)
225+
.backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2))
226+
.addAllBackendRoles(false)
227+
.build();
228+
229+
assertNotNull(connector);
230+
assertEquals(MCP_SSE, connector.getProtocol());
231+
assertNull(connector.getCredential());
232+
}
233+
234+
@Test
235+
public void constructorMLCreateConnectorInput_McpSseWithEmptyCredential_ShouldNotThrow() {
236+
// MCP SSE connectors should be allowed to have empty credentials
237+
MLCreateConnectorInput connector = MLCreateConnectorInput
238+
.builder()
239+
.name(TEST_CONNECTOR_NAME)
240+
.description(TEST_CONNECTOR_DESCRIPTION)
241+
.version(TEST_CONNECTOR_VERSION)
242+
.protocol(MCP_SSE)
243+
.parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE))
244+
.credential(Map.of())
245+
.actions(List.of())
246+
.access(AccessMode.PUBLIC)
247+
.backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2))
248+
.addAllBackendRoles(false)
249+
.build();
250+
251+
assertNotNull(connector);
252+
assertEquals(MCP_SSE, connector.getProtocol());
253+
assertTrue(connector.getCredential().isEmpty());
254+
}
255+
256+
@Test
257+
public void constructorMLCreateConnectorInput_McpStreamableHttpWithNullCredential_ShouldNotThrow() {
258+
// MCP Streamable HTTP connectors should be allowed to have null credentials
259+
MLCreateConnectorInput connector = MLCreateConnectorInput
260+
.builder()
261+
.name(TEST_CONNECTOR_NAME)
262+
.description(TEST_CONNECTOR_DESCRIPTION)
263+
.version(TEST_CONNECTOR_VERSION)
264+
.protocol(MCP_STREAMABLE_HTTP)
265+
.parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE))
266+
.credential(null)
267+
.actions(List.of())
268+
.access(AccessMode.PUBLIC)
269+
.backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2))
270+
.addAllBackendRoles(false)
271+
.build();
272+
273+
assertNotNull(connector);
274+
assertEquals(MCP_STREAMABLE_HTTP, connector.getProtocol());
275+
assertNull(connector.getCredential());
276+
}
277+
278+
@Test
279+
public void constructorMLCreateConnectorInput_McpStreamableHttpWithEmptyCredential_ShouldNotThrow() {
280+
// MCP Streamable HTTP connectors should be allowed to have empty credentials
281+
MLCreateConnectorInput connector = MLCreateConnectorInput
282+
.builder()
283+
.name(TEST_CONNECTOR_NAME)
284+
.description(TEST_CONNECTOR_DESCRIPTION)
285+
.version(TEST_CONNECTOR_VERSION)
286+
.protocol(MCP_STREAMABLE_HTTP)
287+
.parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE))
288+
.credential(Map.of())
289+
.actions(List.of())
290+
.access(AccessMode.PUBLIC)
291+
.backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2))
292+
.addAllBackendRoles(false)
293+
.build();
294+
295+
assertNotNull(connector);
296+
assertEquals(MCP_STREAMABLE_HTTP, connector.getProtocol());
297+
assertTrue(connector.getCredential().isEmpty());
298+
}
299+
210300
@Test
211301
public void testToXContent_FullFields() throws Exception {
212302
XContentBuilder builder = XContentFactory.jsonBuilder();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_DESCRIPTION_FIELD;
1212
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_INPUT_SCHEMA_FIELD;
1313
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_NAME_FIELD;
14-
import static org.opensearch.ml.common.CommonValue.SSE_ENDPOINT_FILED;
14+
import static org.opensearch.ml.common.CommonValue.SSE_ENDPOINT_FIELD;
1515
import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD;
1616
import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE;
1717

@@ -75,8 +75,8 @@ public McpConnectorExecutor(Connector connector) {
7575

7676
public List<MLToolSpec> getMcpToolSpecs() {
7777
String mcpServerUrl = connector.getUrl();
78-
String sseEndpoint = connector.getParameters() != null && connector.getParameters().containsKey(SSE_ENDPOINT_FILED)
79-
? connector.getParameters().get(SSE_ENDPOINT_FILED)
78+
String sseEndpoint = connector.getParameters() != null && connector.getParameters().containsKey(SSE_ENDPOINT_FIELD)
79+
? connector.getParameters().get(SSE_ENDPOINT_FIELD)
8080
: MCP_DEFAULT_SSE_ENDPOINT;
8181
if (mcpServerUrl == null) {
8282
return Collections.emptyList();
@@ -87,8 +87,10 @@ public List<MLToolSpec> getMcpToolSpecs() {
8787
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
8888

8989
Consumer<HttpRequest.Builder> headerConfig = builder -> {
90-
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
91-
builder.header(entry.getKey(), entry.getValue());
90+
if (connector.getDecryptedHeaders() != null) {
91+
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
92+
builder.header(entry.getKey(), entry.getValue());
93+
}
9294
}
9395
};
9496

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8-
import static org.opensearch.ml.common.CommonValue.ENDPOINT_FILED;
9-
import static org.opensearch.ml.common.CommonValue.MCP_DEFAULT_ENDPOINT;
8+
import static org.opensearch.ml.common.CommonValue.ENDPOINT_FIELD;
9+
import static org.opensearch.ml.common.CommonValue.MCP_DEFAULT_STREAMABLE_HTTP_ENDPOINT;
1010
import static org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT;
1111
import static org.opensearch.ml.common.CommonValue.MCP_TOOLS_FIELD;
1212
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_DESCRIPTION_FIELD;
@@ -76,9 +76,9 @@ public McpStreamableHttpConnectorExecutor(Connector connector) {
7676

7777
public List<MLToolSpec> getMcpToolSpecs() {
7878
String mcpServerUrl = connector.getUrl();
79-
String endpoint = connector.getParameters() != null && connector.getParameters().containsKey(ENDPOINT_FILED)
80-
? connector.getParameters().get(ENDPOINT_FILED)
81-
: MCP_DEFAULT_ENDPOINT;
79+
String endpoint = connector.getParameters() != null && connector.getParameters().containsKey(ENDPOINT_FIELD)
80+
? connector.getParameters().get(ENDPOINT_FIELD)
81+
: MCP_DEFAULT_STREAMABLE_HTTP_ENDPOINT;
8282
if (mcpServerUrl == null) {
8383
return Collections.emptyList();
8484
}
@@ -88,8 +88,10 @@ public List<MLToolSpec> getMcpToolSpecs() {
8888
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
8989

9090
Consumer<HttpRequest.Builder> headerConfig = builder -> {
91-
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
92-
builder.header(entry.getKey(), entry.getValue());
91+
if (connector.getDecryptedHeaders() != null) {
92+
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
93+
builder.header(entry.getKey(), entry.getValue());
94+
}
9395
}
9496
};
9597

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@
7373
import org.opensearch.ml.common.agent.MLToolSpec;
7474
import org.opensearch.ml.common.connector.AwsConnector;
7575
import org.opensearch.ml.common.connector.Connector;
76+
import org.opensearch.ml.common.connector.HttpConnector;
7677
import org.opensearch.ml.common.connector.McpConnector;
78+
import org.opensearch.ml.common.connector.McpStreamableHttpConnector;
7779
import org.opensearch.ml.common.output.model.ModelTensor;
7880
import org.opensearch.ml.common.output.model.ModelTensorOutput;
7981
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -83,10 +85,12 @@
8385
import org.opensearch.ml.engine.MLEngineClassLoader;
8486
import org.opensearch.ml.engine.MLStaticMockBase;
8587
import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor;
88+
import org.opensearch.ml.engine.algorithms.remote.McpStreamableHttpConnectorExecutor;
8689
import org.opensearch.ml.engine.encryptor.Encryptor;
8790
import org.opensearch.ml.engine.function_calling.FunctionCalling;
8891
import org.opensearch.ml.engine.function_calling.FunctionCallingFactory;
8992
import org.opensearch.ml.engine.tools.McpSseTool;
93+
import org.opensearch.ml.engine.tools.McpStreamableHttpTool;
9094
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
9195
import org.opensearch.remote.metadata.client.GetDataObjectResponse;
9296
import org.opensearch.remote.metadata.client.SdkClient;
@@ -1898,4 +1902,102 @@ public void testParseLLMOutput_PathNotFoundExceptionWithEmptyToolCalls() {
18981902
Assert.assertTrue(output.containsKey(FINAL_ANSWER));
18991903
Assert.assertTrue(output.get(FINAL_ANSWER).contains("[]"));
19001904
}
1905+
1906+
@Test
1907+
public void testGetMcpToolSpecs_McpStreamableHttpConnectorSuccess() throws Exception {
1908+
stubGetConnector();
1909+
List<MLToolSpec> expected = List
1910+
.of(MLToolSpec.builder().type(McpStreamableHttpTool.TYPE).name("StreamableHttpTool").description("mock").build());
1911+
1912+
try (
1913+
MockedStatic<Connector> connStatic = mockStatic(Connector.class);
1914+
MockedStatic<MLEngineClassLoader> loadStatic = mockStatic(MLEngineClassLoader.class)
1915+
) {
1916+
// mock McpStreamableHttpConnector, McpStreamableHttpConnectorExecutor, agent, and listener
1917+
mockMcpStreamableHttpConnector(connStatic);
1918+
McpStreamableHttpConnectorExecutor exec = mock(McpStreamableHttpConnectorExecutor.class);
1919+
when(exec.getMcpToolSpecs()).thenReturn(expected);
1920+
loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec);
1921+
1922+
MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
1923+
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
1924+
1925+
// run and verify
1926+
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener);
1927+
verify(listener).onResponse(expected);
1928+
}
1929+
}
1930+
1931+
@Test
1932+
public void testGetMcpToolSpecs_UnsupportedConnectorType() throws Exception {
1933+
stubGetConnector();
1934+
try (MockedStatic<Connector> connStatic = mockStatic(Connector.class)) {
1935+
// Mock an unsupported connector type (neither McpConnector nor McpStreamableHttpConnector)
1936+
HttpConnector mockConnector = mock(HttpConnector.class);
1937+
when(mockConnector.getProtocol()).thenReturn("http");
1938+
doNothing().when(mockConnector).decrypt(anyString(), any(), anyString());
1939+
connStatic.when(() -> Connector.createConnector(any(XContentParser.class))).thenReturn(mockConnector);
1940+
1941+
MLAgent agent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
1942+
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
1943+
1944+
AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener);
1945+
verify(listener).onResponse(Collections.emptyList());
1946+
}
1947+
}
1948+
1949+
@Test
1950+
public void testGetMcpToolSpecs_ExceptionInGetMcpToolSpecs() throws Exception {
1951+
stubGetConnector();
1952+
1953+
try (
1954+
MockedStatic<Connector> connStatic = mockStatic(Connector.class);
1955+
MockedStatic<MLEngineClassLoader> loadStatic = mockStatic(MLEngineClassLoader.class)
1956+
) {
1957+
// mock McpConnector, McpConnectorExecutor, agent, and listener
1958+
mockMcpConnector(connStatic);
1959+
McpConnectorExecutor exec = mock(McpConnectorExecutor.class);
1960+
when(exec.getMcpToolSpecs()).thenThrow(new RuntimeException("Test exception"));
1961+
loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec);
1962+
1963+
MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
1964+
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
1965+
1966+
// run and verify
1967+
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener);
1968+
verify(listener).onResponse(Collections.emptyList());
1969+
}
1970+
}
1971+
1972+
@Test
1973+
public void testGetMcpToolSpecs_ExceptionInGetConnector() throws Exception {
1974+
// Mock getConnector to throw an exception
1975+
threadContext = new ThreadContext(Settings.builder().build());
1976+
when(client.threadPool()).thenReturn(threadPool);
1977+
when(threadPool.getThreadContext()).thenReturn(threadContext);
1978+
1979+
when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class))).thenAnswer(inv -> {
1980+
CompletionStage<GetDataObjectResponse> stage = mock(CompletionStage.class);
1981+
when(stage.whenComplete(any())).thenAnswer(cbInv -> {
1982+
BiConsumer<GetDataObjectResponse, Throwable> cb = cbInv.getArgument(0);
1983+
cb.accept(null, new RuntimeException("Failed to get connector"));
1984+
return stage;
1985+
});
1986+
return stage;
1987+
});
1988+
1989+
MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
1990+
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
1991+
1992+
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener);
1993+
verify(listener).onResponse(Collections.emptyList());
1994+
}
1995+
1996+
// Helper method to mock McpStreamableHttpConnector
1997+
private void mockMcpStreamableHttpConnector(MockedStatic<Connector> connectorStatic) {
1998+
McpStreamableHttpConnector mockConnector = mock(McpStreamableHttpConnector.class);
1999+
when(mockConnector.getProtocol()).thenReturn("mcp_streamable_http");
2000+
doNothing().when(mockConnector).decrypt(anyString(), any(), anyString());
2001+
connectorStatic.when(() -> Connector.createConnector(any(XContentParser.class))).thenReturn(mockConnector);
2002+
}
19012003
}

0 commit comments

Comments
 (0)