Skip to content

Commit

Permalink
check hash value (#878) (#880)
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <jngz@amazon.com>
(cherry picked from commit 22ae4ca)

Co-authored-by: Jing Zhang <jngz@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and jngz-es authored May 4, 2023
1 parent e2210be commit e26a605
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String DESCRIPTION_FIELD = "description";
public static final String VERSION_FIELD = "version";
public static final String URL_FIELD = "url";
public static final String HASH_VALUE_FIELD = "model_content_hash_value";
public static final String MODEL_FORMAT_FIELD = "model_format";
public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String DEPLOY_MODEL_FIELD = "deploy_model";
Expand All @@ -47,6 +48,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
private String version;
private String description;
private String url;
private String hashValue;
private MLModelFormat modelFormat;
private MLModelConfig modelConfig;

Expand All @@ -59,6 +61,7 @@ public MLRegisterModelInput(FunctionName functionName,
String version,
String description,
String url,
String hashValue,
MLModelFormat modelFormat,
MLModelConfig modelConfig,
boolean deployModel,
Expand All @@ -84,6 +87,7 @@ public MLRegisterModelInput(FunctionName functionName,
this.version = version;
this.description = description;
this.url = url;
this.hashValue = hashValue;
this.modelFormat = modelFormat;
this.modelConfig = modelConfig;
this.deployModel = deployModel;
Expand All @@ -97,6 +101,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
this.version = in.readString();
this.description = in.readOptionalString();
this.url = in.readOptionalString();
this.hashValue = in.readOptionalString();
if (in.readBoolean()) {
this.modelFormat = in.readEnum(MLModelFormat.class);
}
Expand All @@ -114,6 +119,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(version);
out.writeOptionalString(description);
out.writeOptionalString(url);
out.writeOptionalString(hashValue);
if (modelFormat != null) {
out.writeBoolean(true);
out.writeEnum(modelFormat);
Expand Down Expand Up @@ -142,6 +148,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (url != null) {
builder.field(URL_FIELD, url);
}
if (hashValue != null) {
builder.field(HASH_VALUE_FIELD, hashValue);
}
if (modelFormat != null) {
builder.field(MODEL_FORMAT_FIELD, modelFormat);
}
Expand All @@ -159,6 +168,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) throws IOException {
FunctionName functionName = null;
String url = null;
String hashValue = null;
String description = null;
MLModelFormat modelFormat = null;
MLModelConfig modelConfig = null;
Expand All @@ -175,6 +185,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case URL_FIELD:
url = parser.text();
break;
case HASH_VALUE_FIELD:
hashValue = parser.text();
break;
case DESCRIPTION_FIELD:
description = parser.text();
break;
Expand All @@ -195,14 +208,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
break;
}
}
return new MLRegisterModelInput(functionName, modelName, version, description, url, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]));
return new MLRegisterModelInput(functionName, modelName, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]));
}

public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
FunctionName functionName = null;
String name = null;
String version = null;
String url = null;
String hashValue = null;
String description = null;
MLModelFormat modelFormat = null;
MLModelConfig modelConfig = null;
Expand All @@ -229,6 +243,9 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case URL_FIELD:
url = parser.text();
break;
case HASH_VALUE_FIELD:
hashValue = parser.text();
break;
case MODEL_FORMAT_FIELD:
modelFormat = MLModelFormat.from(parser.text().toUpperCase(Locale.ROOT));
break;
Expand All @@ -246,6 +263,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
break;
}
}
return new MLRegisterModelInput(functionName, name, version, description, url, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]));
return new MLRegisterModelInput(functionName, name, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi
}
builder.modelConfig(configBuilder.build());
break;
case MLRegisterModelInput.HASH_VALUE_FIELD:
builder.hashValue(entry.getValue().toString());
break;
default:
break;
}
Expand All @@ -141,9 +144,10 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi
* @param modelName model name
* @param version model version
* @param url model file URL
* @param modelContentHash model content hash value
* @param listener action listener
*/
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, ActionListener<Map<String, Object>> listener) {
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener<Map<String, Object>> listener) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
Expand All @@ -153,6 +157,11 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
log.debug("download model to file {}", modelZipFile.getAbsolutePath());
DownloadUtils.download(url, modelPath, new ProgressBar());
verifyModelZipFile(modelFormat, modelPath, modelName);
String hash = calculateFileHash(modelZipFile);
if (modelContentHash != null && !modelContentHash.equals(hash)) {
log.error("Model content hash can't match original hash value when registering");
throw (new IllegalArgumentException("model content changed"));
}

List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
Map<String, Object> result = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class ModelHelperTest {
private MLModelFormat modelFormat;
private String modelId;
private MLEngine mlEngine;
private String hashValue = "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021";

@Mock
ActionListener<Map<String, Object>> actionListener;
Expand All @@ -57,7 +58,8 @@ public void setup() throws URISyntaxException {

@Test
public void testDownloadAndSplit_UrlFailure() {
modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", "http://testurl", actionListener);
modelId = "url_failure_model_id";
modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", "http://testurl", null, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(PrivilegedActionException.class, argumentCaptor.getValue().getClass());
Expand All @@ -66,7 +68,26 @@ public void testDownloadAndSplit_UrlFailure() {
@Test
public void testDownloadAndSplit() throws URISyntaxException {
String modelUrl = getClass().getResource("traced_small_model.zip").toURI().toString();
modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, actionListener);
modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, null, actionListener);
ArgumentCaptor<Map> argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(actionListener).onResponse(argumentCaptor.capture());
assertNotNull(argumentCaptor.getValue());
assertNotEquals(0, argumentCaptor.getValue().size());
}

@Test
public void testDownloadAndSplit_HashFailure() throws URISyntaxException {
String modelUrl = getClass().getResource("traced_small_model.zip").toURI().toString();
modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, "wrong_hash_value", actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(IllegalArgumentException.class, argumentCaptor.getValue().getClass());
}

@Test
public void testDownloadAndSplit_Hash() throws URISyntaxException {
String modelUrl = getClass().getResource("traced_small_model.zip").toURI().toString();
modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, hashValue, actionListener);
ArgumentCaptor<Map> argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(actionListener).onResponse(argumentCaptor.capture());
assertNotNull(argumentCaptor.getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ private void registerModel(
modelName,
version,
registerModelInput.getUrl(),
registerModelInput.getHashValue(),
ActionListener.wrap(result -> {
Long modelSizeInBytes = (Long) result.get(MODEL_SIZE_IN_BYTES);
if (modelSizeInBytes >= MODEL_FILE_SIZE_LIMIT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ public void testRegisterMLModel_InitModelIndexFailure() {

modelManager.registerMLModel(registerModelInput, mlTask);
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any());
verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any(), any());
verify(client, never()).index(any(), any());
}

Expand All @@ -308,7 +308,7 @@ public void testRegisterMLModel_IndexModelMetaFailure() {
modelManager.registerMLModel(registerModelInput, mlTask);
verify(mlIndicesHandler).initModelIndexIfAbsent(any());
verify(client).index(any(), any());
verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any());
verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any(), any());
}

public void testRegisterMLModel_IndexModelChunkFailure() throws IOException {
Expand All @@ -323,7 +323,7 @@ public void testRegisterMLModel_IndexModelChunkFailure() throws IOException {
modelManager.registerMLModel(registerModelInput, mlTask);
verify(mlIndicesHandler).initModelIndexIfAbsent(any());
verify(client, times(2)).index(any(), any());
verify(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any());
verify(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any(), any());
}

public void testRegisterMLModel_DownloadModelFileFailure() {
Expand All @@ -337,7 +337,7 @@ public void testRegisterMLModel_DownloadModelFileFailure() {
modelManager.registerMLModel(registerModelInput, mlTask);
verify(mlIndicesHandler).initModelIndexIfAbsent(any());
verify(client).index(any(), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any());
}

public void testRegisterMLModel_DownloadModelFile() throws IOException {
Expand All @@ -352,7 +352,7 @@ public void testRegisterMLModel_DownloadModelFile() throws IOException {
modelManager.registerMLModel(registerModelInput, mlTask);
verify(mlIndicesHandler).initModelIndexIfAbsent(any());
verify(client, times(3)).index(any(), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any());
}

public void testRegisterMLModel_DeployModel() throws IOException {
Expand All @@ -369,7 +369,7 @@ public void testRegisterMLModel_DeployModel() throws IOException {
modelManager.registerMLModel(mlRegisterModelInput, mlTask);
verify(mlIndicesHandler).initModelIndexIfAbsent(any());
verify(client, times(3)).index(any(), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any());
verify(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());
}

Expand All @@ -387,7 +387,7 @@ public void testRegisterMLModel_DeployModel_failure() throws IOException {
modelManager.registerMLModel(mlRegisterModelInput, mlTask);
verify(mlIndicesHandler).initModelIndexIfAbsent(any());
verify(client, times(3)).index(any(), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any());
verify(client, never()).execute(eq(MLDeployModelAction.INSTANCE), any(), any());
}

Expand All @@ -403,7 +403,7 @@ public void testRegisterMLModel_DownloadModelFile_ModelFileSizeExceedLimit() thr
modelManager.registerMLModel(registerModelInput, mlTask);
verify(mlIndicesHandler).initModelIndexIfAbsent(any());
verify(client, times(1)).index(any(), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any());
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any());
}

public void testRegisterModel_ClientFailedToGetThreadPool() {
Expand Down Expand Up @@ -759,22 +759,22 @@ private void setUpMock_GetModelMeta_FailedToGetLastChunk(MLModel model) {

private void setUpMock_DownloadModelFileFailure() {
doAnswer(invocation -> {
ActionListener<Map<String, Object>> listener = invocation.getArgument(5);
ActionListener<Map<String, Object>> listener = invocation.getArgument(6);
listener.onFailure(new RuntimeException("downloadAndSplit failure"));
return null;
}).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any());
}).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any(), any());
}

private void setUpMock_DownloadModelFile(String[] chunks, Long modelContentSize) {
doAnswer(invocation -> {
ActionListener<Map<String, Object>> listener = invocation.getArgument(5);
ActionListener<Map<String, Object>> listener = invocation.getArgument(6);
Map<String, Object> result = new HashMap<>();
result.put(MODEL_SIZE_IN_BYTES, modelContentSize);
result.put(CHUNK_FILES, Arrays.asList(chunks[0], chunks[1]));
result.put(MODEL_FILE_HASH, randomAlphaOfLength(10));
listener.onResponse(result);
return null;
}).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any());
}).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any(), any());
}

@Mock
Expand Down

0 comments on commit e26a605

Please sign in to comment.