From 6c01110f4fcfc73744d22026eb365c6244de4fef Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 2 May 2023 20:32:16 -0700 Subject: [PATCH] check hash value (#878) Signed-off-by: Jing Zhang --- .../register/MLRegisterModelInput.java | 21 ++++++++++++++-- .../org/opensearch/ml/engine/ModelHelper.java | 11 +++++++- .../text_embedding/ModelHelperTest.java | 25 +++++++++++++++++-- .../opensearch/ml/model/MLModelManager.java | 1 + .../ml/model/MLModelManagerTests.java | 24 +++++++++--------- 5 files changed, 65 insertions(+), 17 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 1700807ab2..94c51f24e6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -37,6 +37,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String DESCRIPTION_FIELD = "description"; public static final String VERSION_FIELD = "version"; public static final String URL_FIELD = "url"; + public static final String HASH_VALUE_FIELD = "model_content_hash_value"; public static final String MODEL_FORMAT_FIELD = "model_format"; public static final String MODEL_CONFIG_FIELD = "model_config"; public static final String DEPLOY_MODEL_FIELD = "deploy_model"; @@ -47,6 +48,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private String version; private String description; private String url; + private String hashValue; private MLModelFormat modelFormat; private MLModelConfig modelConfig; @@ -59,6 +61,7 @@ public MLRegisterModelInput(FunctionName functionName, String version, String description, String url, + String hashValue, MLModelFormat modelFormat, MLModelConfig modelConfig, boolean deployModel, @@ -84,6 +87,7 @@ public MLRegisterModelInput(FunctionName functionName, this.version = version; this.description = description; this.url = url; + this.hashValue = hashValue; this.modelFormat = modelFormat; this.modelConfig = modelConfig; this.deployModel = deployModel; @@ -97,6 +101,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.version = in.readString(); this.description = in.readOptionalString(); this.url = in.readOptionalString(); + this.hashValue = in.readOptionalString(); if (in.readBoolean()) { this.modelFormat = in.readEnum(MLModelFormat.class); } @@ -114,6 +119,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(version); out.writeOptionalString(description); out.writeOptionalString(url); + out.writeOptionalString(hashValue); if (modelFormat != null) { out.writeBoolean(true); out.writeEnum(modelFormat); @@ -142,6 +148,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (url != null) { builder.field(URL_FIELD, url); } + if (hashValue != null) { + builder.field(HASH_VALUE_FIELD, hashValue); + } if (modelFormat != null) { builder.field(MODEL_FORMAT_FIELD, modelFormat); } @@ -159,6 +168,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) throws IOException { FunctionName functionName = null; String url = null; + String hashValue = null; String description = null; MLModelFormat modelFormat = null; MLModelConfig modelConfig = null; @@ -175,6 +185,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case URL_FIELD: url = parser.text(); break; + case HASH_VALUE_FIELD: + hashValue = parser.text(); + break; case DESCRIPTION_FIELD: description = parser.text(); break; @@ -195,7 +208,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName break; } } - return new MLRegisterModelInput(functionName, modelName, version, description, url, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0])); + return new MLRegisterModelInput(functionName, modelName, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0])); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -203,6 +216,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo String name = null; String version = null; String url = null; + String hashValue = null; String description = null; MLModelFormat modelFormat = null; MLModelConfig modelConfig = null; @@ -229,6 +243,9 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case URL_FIELD: url = parser.text(); break; + case HASH_VALUE_FIELD: + hashValue = parser.text(); + break; case MODEL_FORMAT_FIELD: modelFormat = MLModelFormat.from(parser.text().toUpperCase(Locale.ROOT)); break; @@ -246,6 +263,6 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo break; } } - return new MLRegisterModelInput(functionName, name, version, description, url, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0])); + return new MLRegisterModelInput(functionName, name, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0])); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index a19db4923a..446c91fc04 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -120,6 +120,9 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi } builder.modelConfig(configBuilder.build()); break; + case MLRegisterModelInput.HASH_VALUE_FIELD: + builder.hashValue(entry.getValue().toString()); + break; default: break; } @@ -141,9 +144,10 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi * @param modelName model name * @param version model version * @param url model file URL + * @param modelContentHash model content hash value * @param listener action listener */ - public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, ActionListener> listener) { + public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener> listener) { try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version); @@ -153,6 +157,11 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo log.debug("download model to file {}", modelZipFile.getAbsolutePath()); DownloadUtils.download(url, modelPath, new ProgressBar()); verifyModelZipFile(modelFormat, modelPath, modelName); + String hash = calculateFileHash(modelZipFile); + if (modelContentHash != null && !modelContentHash.equals(hash)) { + log.error("Model content hash can't match original hash value when registering"); + throw (new IllegalArgumentException("model content changed")); + } List chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE); Map result = new HashMap<>(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java index 2d14c51878..8ac420fc2c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java @@ -39,6 +39,7 @@ public class ModelHelperTest { private MLModelFormat modelFormat; private String modelId; private MLEngine mlEngine; + private String hashValue = "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021"; @Mock ActionListener> actionListener; @@ -57,7 +58,8 @@ public void setup() throws URISyntaxException { @Test public void testDownloadAndSplit_UrlFailure() { - modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", "http://testurl", actionListener); + modelId = "url_failure_model_id"; + modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", "http://testurl", null, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals(PrivilegedActionException.class, argumentCaptor.getValue().getClass()); @@ -66,7 +68,26 @@ public void testDownloadAndSplit_UrlFailure() { @Test public void testDownloadAndSplit() throws URISyntaxException { String modelUrl = getClass().getResource("traced_small_model.zip").toURI().toString(); - modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, actionListener); + modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, null, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertNotNull(argumentCaptor.getValue()); + assertNotEquals(0, argumentCaptor.getValue().size()); + } + + @Test + public void testDownloadAndSplit_HashFailure() throws URISyntaxException { + String modelUrl = getClass().getResource("traced_small_model.zip").toURI().toString(); + modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, "wrong_hash_value", actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(IllegalArgumentException.class, argumentCaptor.getValue().getClass()); + } + + @Test + public void testDownloadAndSplit_Hash() throws URISyntaxException { + String modelUrl = getClass().getResource("traced_small_model.zip").toURI().toString(); + modelHelper.downloadAndSplit(modelFormat, modelId, "model_name", "1", modelUrl, hashValue, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertNotNull(argumentCaptor.getValue()); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index e5465c06fb..4ff7ff5e0b 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -330,6 +330,7 @@ private void registerModel( modelName, version, registerModelInput.getUrl(), + registerModelInput.getHashValue(), ActionListener.wrap(result -> { Long modelSizeInBytes = (Long) result.get(MODEL_SIZE_IN_BYTES); if (modelSizeInBytes >= MODEL_FILE_SIZE_LIMIT) { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 4d563bff42..d40803c859 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -293,7 +293,7 @@ public void testRegisterMLModel_InitModelIndexFailure() { modelManager.registerMLModel(registerModelInput, mlTask); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); - verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any()); + verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any(), any()); verify(client, never()).index(any(), any()); } @@ -308,7 +308,7 @@ public void testRegisterMLModel_IndexModelMetaFailure() { modelManager.registerMLModel(registerModelInput, mlTask); verify(mlIndicesHandler).initModelIndexIfAbsent(any()); verify(client).index(any(), any()); - verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any()); + verify(modelHelper, never()).downloadAndSplit(any(), any(), any(), any(), any(), any(), any()); } public void testRegisterMLModel_IndexModelChunkFailure() throws IOException { @@ -323,7 +323,7 @@ public void testRegisterMLModel_IndexModelChunkFailure() throws IOException { modelManager.registerMLModel(registerModelInput, mlTask); verify(mlIndicesHandler).initModelIndexIfAbsent(any()); verify(client, times(2)).index(any(), any()); - verify(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any()); + verify(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any(), any()); } public void testRegisterMLModel_DownloadModelFileFailure() { @@ -337,7 +337,7 @@ public void testRegisterMLModel_DownloadModelFileFailure() { modelManager.registerMLModel(registerModelInput, mlTask); verify(mlIndicesHandler).initModelIndexIfAbsent(any()); verify(client).index(any(), any()); - verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any()); + verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any()); } public void testRegisterMLModel_DownloadModelFile() throws IOException { @@ -352,7 +352,7 @@ public void testRegisterMLModel_DownloadModelFile() throws IOException { modelManager.registerMLModel(registerModelInput, mlTask); verify(mlIndicesHandler).initModelIndexIfAbsent(any()); verify(client, times(3)).index(any(), any()); - verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any()); + verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any()); } public void testRegisterMLModel_DeployModel() throws IOException { @@ -369,7 +369,7 @@ public void testRegisterMLModel_DeployModel() throws IOException { modelManager.registerMLModel(mlRegisterModelInput, mlTask); verify(mlIndicesHandler).initModelIndexIfAbsent(any()); verify(client, times(3)).index(any(), any()); - verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any()); + verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any()); verify(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any()); } @@ -387,7 +387,7 @@ public void testRegisterMLModel_DeployModel_failure() throws IOException { modelManager.registerMLModel(mlRegisterModelInput, mlTask); verify(mlIndicesHandler).initModelIndexIfAbsent(any()); verify(client, times(3)).index(any(), any()); - verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any()); + verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any()); verify(client, never()).execute(eq(MLDeployModelAction.INSTANCE), any(), any()); } @@ -403,7 +403,7 @@ public void testRegisterMLModel_DownloadModelFile_ModelFileSizeExceedLimit() thr modelManager.registerMLModel(registerModelInput, mlTask); verify(mlIndicesHandler).initModelIndexIfAbsent(any()); verify(client, times(1)).index(any(), any()); - verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any()); + verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any()); } public void testRegisterModel_ClientFailedToGetThreadPool() { @@ -759,22 +759,22 @@ private void setUpMock_GetModelMeta_FailedToGetLastChunk(MLModel model) { private void setUpMock_DownloadModelFileFailure() { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(5); + ActionListener> listener = invocation.getArgument(6); listener.onFailure(new RuntimeException("downloadAndSplit failure")); return null; - }).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any()); + }).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any(), any()); } private void setUpMock_DownloadModelFile(String[] chunks, Long modelContentSize) { doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(5); + ActionListener> listener = invocation.getArgument(6); Map result = new HashMap<>(); result.put(MODEL_SIZE_IN_BYTES, modelContentSize); result.put(CHUNK_FILES, Arrays.asList(chunks[0], chunks[1])); result.put(MODEL_FILE_HASH, randomAlphaOfLength(10)); listener.onResponse(result); return null; - }).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any()); + }).when(modelHelper).downloadAndSplit(any(), any(), any(), any(), any(), any(), any()); } @Mock