Skip to content

Commit

Permalink
fix post_process_function on rerank_pipeline_with_bge-rerank-m3-v2_mo…
Browse files Browse the repository at this point in the history
…del_deployed_on_Sagemaker.md (opensearch-project#3296)

* fix post_process_function bug on sort results for rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md (opensearch-project#3247)

Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com>

* fix typo

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: tkykenmt <tkykenmto+github.com@gmail.com>
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Co-authored-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
tkykenmt and ylwu-amzn authored Jan 3, 2025
1 parent bf48f99 commit d5f47b4
Showing 1 changed file with 127 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,39 @@ result = predictor.predict(data={
]
})

print(json.dumps(sorted(result, key=lambda x: x['index']), indent=2))
print(json.dumps(result, indent=2))
```

The reranking results are as follows:
The reranking results are ordered by score, highest first:
```
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
},
{
"index": 1,
"score": 0.000593021
},
{
"index": 3,
"score": 0.00012148176
}
]
```

You can sort the results by index number:

```python
print(json.dumps(sorted(result, key=lambda x: x['index']),indent=2))

```

The results are as follows:

```
[
Expand Down Expand Up @@ -121,9 +150,51 @@ POST /_plugins/_ml/connectors/_create
"headers": {
"content-type": "application/json"
},
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"query": "${parameters.query}",
"texts": ${parameters.texts}
}
""",
"post_process_function": """
if (params.result == null || params.result.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
Expand Down Expand Up @@ -152,9 +223,51 @@ POST /_plugins/_ml/connectors/_create
"headers": {
"content-type": "application/json"
},
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"query": "${parameters.query}",
"texts": ${parameters.texts}
}
""",
"post_process_function": """
if (params.result == null || params.result.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
Expand Down Expand Up @@ -188,7 +301,7 @@ POST _plugins/_ml/models/your_model_id/_predict
}
```

Each item in the `inputs` array comprises a `query_text` and a `text_docs` string, separated by a ` . `
Each item in the array comprises a `query_text` and a `text_docs` string, separated by ` . `.

Alternatively, you can test the model as follows:
```json
Expand All @@ -209,6 +322,10 @@ The connector `pre_process_function` transforms the input into the format requir
By default, the SageMaker model output has the following format:
```json
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
Expand All @@ -217,18 +334,14 @@ By default, the SageMaker model output has the following format:
"index": 1,
"score": 0.000593021
},
{
"index": 2,
"score": 0.92879725
},
{
"index": 3,
"score": 0.00012148176
}
]
```

The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret. This adapted format is as follows:
The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret and orders the results by index. This adapted format is as follows:
```json
{
"inference_results": [
Expand Down

0 comments on commit d5f47b4

Please sign in to comment.