55
66package org .opensearch .ml .engine .algorithms .remote ;
77
8+ import static org .junit .Assert .assertEquals ;
89import static org .mockito .Mockito .mock ;
910import static org .mockito .Mockito .times ;
1011import static org .mockito .Mockito .verify ;
3132import org .opensearch .ml .common .connector .ConnectorAction ;
3233import org .opensearch .ml .common .connector .HttpConnector ;
3334import org .opensearch .ml .common .connector .MLPostProcessFunction ;
34- import org .opensearch .ml .common .exception .MLException ;
3535import org .opensearch .ml .common .output .model .ModelTensors ;
3636import org .opensearch .script .ScriptService ;
3737import org .reactivestreams .Publisher ;
@@ -191,7 +191,7 @@ public void test_onError() {
191191 ArgumentCaptor <Exception > captor = ArgumentCaptor .forClass (Exception .class );
192192 verify (actionListener ).onFailure (captor .capture ());
193193 assert captor .getValue () instanceof OpenSearchStatusException ;
194- assert captor . getValue (). getMessage (). equals ( "Error communicating with remote model: runtime exception" );
194+ assertEquals ( "Error communicating with remote model: runtime exception" , captor . getValue (). getMessage () );
195195 }
196196
197197 @ Test
@@ -209,7 +209,7 @@ public void test_onSubscribe() {
209209 public void test_onNext () {
210210 test_onSubscribe ();// set the subscription to non-null.
211211 responseSubscriber .onNext (ByteBuffer .wrap ("hello world" .getBytes ()));
212- assert mlSdkAsyncHttpResponseHandler .getResponseBody ().toString (). equals ( "hello world" );
212+ assertEquals ( "hello world" , mlSdkAsyncHttpResponseHandler .getResponseBody ().toString ());
213213 }
214214
215215 @ Test
@@ -221,7 +221,7 @@ public void test_MLResponseSubscriber_onError() {
221221 ArgumentCaptor <Exception > captor = ArgumentCaptor .forClass (Exception .class );
222222 verify (actionListener , times (1 )).onFailure (captor .capture ());
223223 assert captor .getValue () instanceof OpenSearchStatusException ;
224- assert captor . getValue (). getMessage (). equals ( "Remote service returned error status 500 with empty body" );
224+ assertEquals ( "Remote service returned error status 500 with empty body" , captor . getValue (). getMessage () );
225225 }
226226
227227 @ Test
@@ -283,7 +283,7 @@ public void test_onComplete_failed() {
283283 mlSdkAsyncHttpResponseHandler .onStream (stream );
284284 ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
285285 verify (actionListener , times (1 )).onFailure (captor .capture ());
286- assert captor . getValue (). getMessage (). equals ( "Error from remote service: Model current status is: FAILED" );
286+ assertEquals ( "Error from remote service: Model current status is: FAILED" , captor . getValue (). getMessage () );
287287 assert captor .getValue ().status ().getStatus () == 500 ;
288288 }
289289
@@ -302,7 +302,7 @@ public void test_onComplete_empty_response_body() {
302302 mlSdkAsyncHttpResponseHandler .onStream (stream );
303303 ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (OpenSearchStatusException .class );
304304 verify (actionListener , times (1 )).onFailure (captor .capture ());
305- assert captor . getValue (). getMessage (). equals ( "Remote service returned empty response body" );
305+ assertEquals ( "Remote service returned empty response body" , captor . getValue (). getMessage () );
306306 }
307307
308308 @ Test
@@ -380,14 +380,12 @@ public void test_onComplete_throttle_exception_onFailure() {
380380
381381 ArgumentCaptor <OpenSearchStatusException > captor = ArgumentCaptor .forClass (RemoteConnectorThrottlingException .class );
382382 verify (actionListener , times (1 )).onFailure (captor .capture ());
383- assert captor
384- .getValue ()
385- .getMessage ()
386- .equals (
387- "Error from remote service: The request was denied due to remote server throttling. "
388- + "To change the retry policy and behavior, please update the connector client_config."
389- );
390383 assert captor .getValue ().status ().getStatus () == HttpStatusCode .BAD_REQUEST ;
384+ assertEquals (
385+ "Error from remote service: The request was denied due to remote server throttling. "
386+ + "To change the retry policy and behavior, please update the connector client_config." ,
387+ captor .getValue ().getMessage ()
388+ );
391389 }
392390
393391 @ Test
@@ -416,8 +414,39 @@ public void test_onComplete_processOutputFail_onFailure() {
416414 };
417415 mlSdkAsyncHttpResponseHandler .onStream (stream );
418416
419- ArgumentCaptor <MLException > captor = ArgumentCaptor .forClass (MLException .class );
417+ ArgumentCaptor <IllegalArgumentException > captor = ArgumentCaptor .forClass (IllegalArgumentException .class );
420418 verify (actionListener , times (1 )).onFailure (captor .capture ());
421- assert captor .getValue ().getMessage ().equals ("Fail to execute PREDICT in aws connector" );
419+ assertEquals ("no PREDICT action found" , captor .getValue ().getMessage ());
420+ }
421+
422+ /**
423+ * Asserts that IllegalArgumentException is propagated where post-processing function fails
424+ * on response
425+ */
426+ @ Test
427+ public void onComplete_InvalidEmbeddingBedRockPostProcessingOccurs_IllegalArgumentExceptionThrown () {
428+ String invalidEmbeddingResponse = "{ \" embedding\" : [[1]] }" ;
429+
430+ mlSdkAsyncHttpResponseHandler .onHeaders (sdkHttpResponse );
431+ Publisher <ByteBuffer > stream = s -> {
432+ try {
433+ s .onSubscribe (mock (Subscription .class ));
434+ s .onNext (ByteBuffer .wrap (invalidEmbeddingResponse .getBytes ()));
435+ s .onComplete ();
436+ } catch (Throwable e ) {
437+ s .onError (e );
438+ }
439+ };
440+ mlSdkAsyncHttpResponseHandler .onStream (stream );
441+
442+ ArgumentCaptor <IllegalArgumentException > exceptionCaptor = ArgumentCaptor .forClass (IllegalArgumentException .class );
443+ verify (actionListener , times (1 )).onFailure (exceptionCaptor .capture ());
444+
445+ // Error message
446+ assertEquals (
447+ "BedrockEmbeddingPostProcessFunction exception message should match" ,
448+ "The embedding should be a non-empty List containing Float values." ,
449+ exceptionCaptor .getValue ().getMessage ()
450+ );
422451 }
423452}
0 commit comments