77
88import static org .mockito .ArgumentMatchers .any ;
99import static org .mockito .Mockito .argThat ;
10+ import static org .mockito .Mockito .doThrow ;
1011import static org .mockito .Mockito .spy ;
1112import static org .mockito .Mockito .times ;
13+ import static org .mockito .Mockito .verify ;
1214import static org .mockito .Mockito .when ;
1315import static org .opensearch .ml .common .connector .AbstractConnector .ACCESS_KEY_FIELD ;
1416import static org .opensearch .ml .common .connector .AbstractConnector .SECRET_KEY_FIELD ;
1719import static org .opensearch .ml .common .connector .HttpConnector .SERVICE_NAME_FIELD ;
1820import static org .opensearch .ml .engine .algorithms .remote .ConnectorUtils .SKIP_VALIDATE_MISSING_PARAMETERS ;
1921
22+ import java .io .IOException ;
2023import java .util .Arrays ;
24+ import java .util .HashMap ;
2125import java .util .Map ;
2226
2327import org .junit .Assert ;
3034import org .opensearch .common .settings .Settings ;
3135import org .opensearch .common .util .concurrent .ThreadContext ;
3236import org .opensearch .core .action .ActionListener ;
37+ import org .opensearch .core .xcontent .XContentBuilder ;
3338import org .opensearch .ingest .TestTemplateService ;
3439import org .opensearch .ml .common .FunctionName ;
3540import org .opensearch .ml .common .connector .AwsConnector ;
3944import org .opensearch .ml .common .connector .RetryBackoffPolicy ;
4045import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
4146import org .opensearch .ml .common .input .MLInput ;
47+ import org .opensearch .ml .common .input .parameter .MLAlgoParams ;
48+ import org .opensearch .ml .common .input .parameter .clustering .KMeansParams ;
49+ import org .opensearch .ml .common .input .parameter .textembedding .AsymmetricTextEmbeddingParameters ;
50+ import org .opensearch .ml .common .input .parameter .textembedding .SparseEmbeddingFormat ;
4251import org .opensearch .ml .common .output .model .ModelTensors ;
4352import org .opensearch .ml .engine .encryptor .Encryptor ;
4453import org .opensearch .ml .engine .encryptor .EncryptorImpl ;
@@ -64,6 +73,9 @@ public class RemoteConnectorExecutorTest {
6473 @ Mock
6574 ActionListener <Tuple <Integer , ModelTensors >> actionListener ;
6675
76+ @ Mock
77+ private MLAlgoParams mlInputParams ;
78+
6779 @ Before
6880 public void setUp () {
6981 MockitoAnnotations .openMocks (this );
@@ -169,4 +181,165 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault()
169181 );
170182 assert exception .getMessage ().contains ("Some parameter placeholder not filled in payload: role" );
171183 }
184+
185+ @ Test
186+ public void executePreparePayloadAndInvoke_PassingParameter () {
187+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
188+ Connector connector = getConnector (parameters );
189+ AwsConnectorExecutor executor = getExecutor (connector );
190+
191+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
192+ .builder ()
193+ .parameters (Map .of ("input" , "You are a ${parameters.role}" ))
194+ .actionType (PREDICT )
195+ .build ();
196+ String actionType = inputDataSet .getActionType ().toString ();
197+ AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
198+ .builder ()
199+ .sparseEmbeddingFormat (SparseEmbeddingFormat .WORD )
200+ .embeddingContentType (null )
201+ .build ();
202+ MLInput mlInput = MLInput
203+ .builder ()
204+ .algorithm (FunctionName .TEXT_EMBEDDING )
205+ .parameters (inputParams )
206+ .inputDataset (inputDataSet )
207+ .build ();
208+
209+ Exception exception = Assert
210+ .assertThrows (
211+ IllegalArgumentException .class ,
212+ () -> executor .preparePayloadAndInvoke (actionType , mlInput , null , actionListener )
213+ );
214+ assert exception .getMessage ().contains ("Some parameter placeholder not filled in payload: role" );
215+ }
216+
217+ @ Test
218+ public void executePreparePayloadAndInvoke_GetParamsIOException () throws Exception {
219+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
220+ Connector connector = getConnector (parameters );
221+ AwsConnectorExecutor executor = getExecutor (connector );
222+
223+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
224+ .builder ()
225+ .parameters (Map .of ("input" , "test input" ))
226+ .actionType (PREDICT )
227+ .build ();
228+ String actionType = inputDataSet .getActionType ().toString ();
229+ doThrow (new IOException ("UT test IOException" )).when (mlInputParams ).toXContent (any (XContentBuilder .class ), any ());
230+ MLInput mlInput = MLInput
231+ .builder ()
232+ .algorithm (FunctionName .TEXT_EMBEDDING )
233+ .parameters (mlInputParams )
234+ .inputDataset (inputDataSet )
235+ .build ();
236+
237+ executor .preparePayloadAndInvoke (actionType , mlInput , null , actionListener );
238+ verify (actionListener ).onFailure (argThat (e -> e instanceof IOException && e .getMessage ().contains ("UT test IOException" )));
239+ }
240+
241+ @ Test
242+ public void executeGetParams_MissingParameter () {
243+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
244+ Connector connector = getConnector (parameters );
245+ AwsConnectorExecutor executor = getExecutor (connector );
246+
247+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
248+ .builder ()
249+ .parameters (Map .of ("input" , "${parameters.input}" ))
250+ .actionType (PREDICT )
251+ .build ();
252+ String actionType = inputDataSet .getActionType ().toString ();
253+ AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
254+ .builder ()
255+ .sparseEmbeddingFormat (SparseEmbeddingFormat .WORD )
256+ .embeddingContentType (null )
257+ .build ();
258+ MLInput mlInput = MLInput
259+ .builder ()
260+ .algorithm (FunctionName .TEXT_EMBEDDING )
261+ .parameters (inputParams )
262+ .inputDataset (inputDataSet )
263+ .build ();
264+
265+ try {
266+ Map <String , String > paramsMap = RemoteConnectorExecutor .getParams (mlInput );
267+ Map <String , String > expectedMap = new HashMap <>();
268+ expectedMap .put ("sparse_embedding_format" , "WORD" );
269+ Assert .assertEquals (expectedMap , paramsMap );
270+ } catch (IOException e ) {
271+ e .printStackTrace ();
272+ }
273+ }
274+
275+ @ Test
276+ public void executeGetParams_PassingParameter () {
277+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
278+ Connector connector = getConnector (parameters );
279+ AwsConnectorExecutor executor = getExecutor (connector );
280+
281+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
282+ .builder ()
283+ .parameters (Map .of ("input" , "${parameters.input}" ))
284+ .actionType (PREDICT )
285+ .build ();
286+ String actionType = inputDataSet .getActionType ().toString ();
287+ AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
288+ .builder ()
289+ .sparseEmbeddingFormat (SparseEmbeddingFormat .WORD )
290+ .embeddingContentType (AsymmetricTextEmbeddingParameters .EmbeddingContentType .PASSAGE )
291+ .build ();
292+ MLInput mlInput = MLInput
293+ .builder ()
294+ .algorithm (FunctionName .TEXT_EMBEDDING )
295+ .parameters (inputParams )
296+ .inputDataset (inputDataSet )
297+ .build ();
298+
299+ try {
300+ Map <String , String > paramsMap = RemoteConnectorExecutor .getParams (mlInput );
301+ Map <String , String > expectedMap = new HashMap <>();
302+ expectedMap .put ("sparse_embedding_format" , "WORD" );
303+ expectedMap .put ("content_type" , "PASSAGE" );
304+ Assert .assertEquals (expectedMap , paramsMap );
305+ } catch (IOException e ) {
306+ e .printStackTrace ();
307+ }
308+ }
309+
310+ @ Test
311+ public void executeGetParams_ConvertToString () {
312+ Map <String , String > parameters = ImmutableMap .of (SERVICE_NAME_FIELD , "sagemaker" , REGION_FIELD , "us-west-2" );
313+ Connector connector = getConnector (parameters );
314+ AwsConnectorExecutor executor = getExecutor (connector );
315+
316+ RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
317+ .builder ()
318+ .parameters (Map .of ("input" , "${parameters.input}" ))
319+ .actionType (PREDICT )
320+ .build ();
321+ KMeansParams inputParams = KMeansParams
322+ .builder ()
323+ .centroids (5 )
324+ .iterations (100 )
325+ .distanceType (KMeansParams .DistanceType .EUCLIDEAN )
326+ .build ();
327+ MLInput mlInput = MLInput
328+ .builder ()
329+ .algorithm (FunctionName .TEXT_EMBEDDING )
330+ .parameters (inputParams )
331+ .inputDataset (inputDataSet )
332+ .build ();
333+
334+ try {
335+ Map <String , String > paramsMap = RemoteConnectorExecutor .getParams (mlInput );
336+ Map <String , String > expectedMap = new HashMap <>();
337+ expectedMap .put ("centroids" , "5" );
338+ expectedMap .put ("iterations" , "100" );
339+ expectedMap .put ("distance_type" , "EUCLIDEAN" );
340+ Assert .assertEquals (expectedMap , paramsMap );
341+ } catch (IOException e ) {
342+ e .printStackTrace ();
343+ }
344+ }
172345}
0 commit comments