Skip to content

Commit

Permalink
init master key automatically (#1075)
Browse files Browse the repository at this point in the history
* init master key automatically

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* remove unnecessary escape

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix failed ut

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* tune syncup jot interval

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* tune syncup jot interval

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* remove local config file code

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* set master key when init remote model

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* move init master key to encryptor

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fine tune code

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fine tune code

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent 275cb05 commit 1b4a9f6
Show file tree
Hide file tree
Showing 23 changed files with 415 additions and 65 deletions.
20 changes: 20 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ public class CommonValue {
public static final String UNDEPLOYED = "undeployed";
public static final String NOT_FOUND = "not_found";

public static final String MASTER_KEY = "master_key";
public static final String CREATE_TIME_FIELD = "create_time";

public static final String BOX_TYPE_KEY = "box_type";
//hot node
public static String HOT_BOX_TYPE = "hot";
Expand All @@ -37,6 +40,8 @@ public class CommonValue {
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down Expand Up @@ -301,4 +306,19 @@ public class CommonValue {
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";


public static final String ML_CONFIG_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_CONFIG_INDEX_SCHEMA_VERSION
+ "},\n"
+ " \"properties\": {\n"
+ " \""
+ MASTER_KEY
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ CREATE_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand All @@ -18,29 +19,33 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.encryptor.Encryptor;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

/**
* This is the interface to all ml algorithms.
*/
@Log4j2
public class MLEngine {

public static final String REGISTER_MODEL_FOLDER = "register";
public static final String DEPLOY_MODEL_FOLDER = "deploy";
private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";

@Getter
private final Path mlConfigPath;

@Getter
private final Path mlCachePath;
private final Path mlModelsCachePath;

private final Encryptor encryptor;
private Encryptor encryptor;

public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
mlCachePath = opensearchDataFolder.resolve("ml_cache");
mlModelsCachePath = mlCachePath.resolve("models_cache");
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
this.mlConfigPath = mlCachePath.resolve("config");
this.encryptor = encryptor;
}

Expand Down Expand Up @@ -195,7 +200,4 @@ public String encrypt(String credential) {
return encryptor.encrypt(credential);
}

public void setMasterKey(String masterKey) {
encryptor.setMasterKey(masterKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
} else {
throw new IllegalArgumentException("Wrong input type");
}
Map<String, String> escapedParameters = new HashMap<>();
inputData.getParameters().entrySet().forEach(entry -> {
escapedParameters.put(entry.getKey(), escapeJava(entry.getValue()));
});
inputData.setParameters(escapedParameters);
return inputData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.engine.encryptor;

import java.security.SecureRandom;
import java.util.Base64;

public interface Encryptor {

/**
Expand All @@ -29,4 +32,8 @@ public interface Encryptor {
* @param masterKey masterKey to be set.
*/
void setMasterKey(String masterKey);
String getMasterKey();

String generateMasterKey();

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,39 @@
import com.amazonaws.encryptionsdk.CommitmentPolicy;
import com.amazonaws.encryptionsdk.CryptoResult;
import com.amazonaws.encryptionsdk.jce.JceMasterKey;
import org.opensearch.ml.engine.exceptions.MetaDataException;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.common.exception.MLException;

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;

@Log4j2
public class EncryptorImpl implements Encryptor {

private ClusterService clusterService;
private Client client;
private volatile String masterKey;

public EncryptorImpl(ClusterService clusterService, Client client) {
this.masterKey = null;
this.clusterService = clusterService;
this.client = client;
}

public EncryptorImpl(String masterKey) {
this.masterKey = masterKey;
}
Expand All @@ -28,9 +51,14 @@ public void setMasterKey(String masterKey) {
this.masterKey = masterKey;
}

@Override
public String getMasterKey() {
return masterKey;
}

@Override
public String encrypt(String plainText) {
checkMasterKey();
initMasterKey();
final AwsCrypto crypto = AwsCrypto.builder()
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
.build();
Expand All @@ -46,7 +74,7 @@ public String encrypt(String plainText) {

@Override
public String decrypt(String encryptedText) {
checkMasterKey();
initMasterKey();
final AwsCrypto crypto = AwsCrypto.builder()
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
.build();
Expand All @@ -60,14 +88,45 @@ public String decrypt(String encryptedText) {
return new String(decryptedResult.getResult());
}

private void checkMasterKey() {
if (masterKey == "0000000000000000" || masterKey == null) {
throw new MetaDataException("Please provide a masterKey for credential encryption! Example: PUT /_cluster/settings\n" +
"{\n" +
" \"persistent\" : {\n" +
" \"plugins.ml_commons.encryption.master_key\" : \"1234567x\" \n" +
" }\n" +
"}");
@Override
public String generateMasterKey() {
byte[] keyBytes = new byte[16];
new SecureRandom().nextBytes(keyBytes);
String base64Key = Base64.getEncoder().encodeToString(keyBytes);
return base64Key;
}

private void initMasterKey() {
if (masterKey != null) {
return;
}
AtomicReference<Exception> exceptionRef = new AtomicReference<>();

CountDownLatch latch = new CountDownLatch(1);
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
if (r.isExists()) {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
setMasterKey(masterKey);
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}

if (exceptionRef.get() != null) {
log.debug("Failed to init master key", exceptionRef.get());
if (exceptionRef.get() instanceof RuntimeException) {
throw (RuntimeException) exceptionRef.get();
} else {
throw new MLException(exceptionRef.get());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -128,7 +129,8 @@ public class MetricsCorrelationTest {
ActionListener<MLDeployModelResponse> mlDeployModelResponseActionListener;
private MetricsCorrelation metricsCorrelation;
private MetricsCorrelationInput input, extendedInput;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private MLModel model;

private MetricsCorrelationModelConfig modelConfig;
Expand All @@ -144,7 +146,6 @@ public class MetricsCorrelationTest {

Map<String, Object> params = new HashMap<>();

@Mock
private Encryptor encryptor;

public MetricsCorrelationTest() {
Expand All @@ -155,8 +156,9 @@ public void setUp() throws IOException, URISyntaxException {

System.setProperty("testMode", "true");

djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelConfig = MetricsCorrelationModelConfig.builder()
.modelType(MetricsCorrelation.MODEL_TYPE)
.allConfig(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
processInput_TextDocsInputDataSet_PreprocessFunction(
"{\"input\": ${parameters.input}}",
"{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }",
"[\\\"test_value1\\\",\\\"test_value2\\\"]");
"[\"test_value1\",\"test_value2\"]");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

import java.io.IOException;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -50,12 +52,15 @@ public class ModelHelperTest {
@Mock
ActionListener<MLRegisterModelInput> registerModelListener;

Encryptor encryptor;

@Before
public void setup() throws URISyntaxException {
MockitoAnnotations.openMocks(this);
modelFormat = MLModelFormat.TORCH_SCRIPT;
modelId = "model_id";
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), null);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor);
modelHelper = new ModelHelper(mlEngine);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;

import java.io.File;
Expand Down Expand Up @@ -62,16 +63,18 @@ public class TextEmbeddingModelTest {
private ModelHelper modelHelper;
private Map<String, Object> params;
private TextEmbeddingModel textEmbeddingModel;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private TextDocsInputDataSet inputDataSet;
private int dimension = 384;
private MLEngine mlEngine;
private Encryptor encryptor;

@Before
public void setUp() throws URISyntaxException {
djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelId = "test_model_id";
modelName = "test_model_name";
functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -329,7 +332,7 @@ public void predict_BeforeInitingModel() {

@After
public void tearDown() {
FileUtils.deleteFileQuietly(djlCachePath);
FileUtils.deleteFileQuietly(mlCachePath);
}

private int findSentenceEmbeddingPosition(ModelTensors modelTensors) {
Expand Down
Loading

0 comments on commit 1b4a9f6

Please sign in to comment.