-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
429 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
common/src/main/java/org/opensearch/ml/common/parameter/AnomalyDetectionParams.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
*/ | ||
|
||
package org.opensearch.ml.common.parameter; | ||
|
||
import lombok.Builder; | ||
import lombok.Data; | ||
import org.opensearch.common.ParseField; | ||
import org.opensearch.common.io.stream.StreamInput; | ||
import org.opensearch.common.io.stream.StreamOutput; | ||
import org.opensearch.common.xcontent.NamedXContentRegistry; | ||
import org.opensearch.common.xcontent.XContentBuilder; | ||
import org.opensearch.common.xcontent.XContentParser; | ||
import org.opensearch.ml.common.annotation.MLAlgoParameter; | ||
|
||
import java.io.IOException; | ||
|
||
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
@Data | ||
@MLAlgoParameter(algorithms={FunctionName.ANOMALY_DETECTION}) | ||
public class AnomalyDetectionParams implements MLAlgoParams { | ||
public static final String PARSE_FIELD_NAME = FunctionName.ANOMALY_DETECTION.name(); | ||
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( | ||
MLAlgoParams.class, | ||
new ParseField(PARSE_FIELD_NAME), | ||
it -> parse(it) | ||
); | ||
|
||
public static final String GAMMA_FIELD = "gamma"; | ||
public static final String NU_FIELD = "nu"; | ||
private Double gamma; | ||
private Double nu; | ||
|
||
@Builder | ||
public AnomalyDetectionParams(Double gamma, Double nu) { | ||
this.gamma = gamma; | ||
this.nu = nu; | ||
} | ||
|
||
public AnomalyDetectionParams(StreamInput in) throws IOException { | ||
this.gamma = in.readOptionalDouble(); | ||
this.nu = in.readOptionalDouble(); | ||
} | ||
|
||
private static MLAlgoParams parse(XContentParser parser) throws IOException { | ||
Double gamma = null; | ||
Double nu = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case GAMMA_FIELD: | ||
gamma = parser.doubleValue(); | ||
break; | ||
case NU_FIELD: | ||
nu = parser.doubleValue(); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return new AnomalyDetectionParams(gamma, nu); | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return PARSE_FIELD_NAME; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeOptionalDouble(gamma); | ||
out.writeOptionalDouble(nu); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(GAMMA_FIELD, gamma); | ||
builder.field(NU_FIELD, nu); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public int getVersion() { | ||
return 1; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
127 changes: 127 additions & 0 deletions
127
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetection.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.engine.algorithms.ad; | ||
|
||
import org.opensearch.ml.common.dataframe.DataFrame; | ||
import org.opensearch.ml.common.dataframe.DataFrameBuilder; | ||
import org.opensearch.ml.common.parameter.FunctionName; | ||
import org.opensearch.ml.common.parameter.AnomalyDetectionParams; | ||
import org.opensearch.ml.common.parameter.MLAlgoParams; | ||
import org.opensearch.ml.common.parameter.MLOutput; | ||
import org.opensearch.ml.common.parameter.MLPredictionOutput; | ||
import org.opensearch.ml.engine.Model; | ||
import org.opensearch.ml.engine.Predictable; | ||
import org.opensearch.ml.engine.Trainable; | ||
import org.opensearch.ml.engine.annotation.Function; | ||
import org.opensearch.ml.engine.contants.TribuoOutputType; | ||
import org.opensearch.ml.engine.utils.ModelSerDeSer; | ||
import org.opensearch.ml.engine.utils.TribuoUtil; | ||
import org.tribuo.Example; | ||
import org.tribuo.Feature; | ||
import org.tribuo.MutableDataset; | ||
import org.tribuo.Prediction; | ||
import org.tribuo.anomaly.AnomalyFactory; | ||
import org.tribuo.anomaly.Event; | ||
import org.tribuo.anomaly.libsvm.LibSVMAnomalyModel; | ||
import org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer; | ||
import org.tribuo.anomaly.libsvm.SVMAnomalyType; | ||
import org.tribuo.common.libsvm.KernelType; | ||
import org.tribuo.common.libsvm.LibSVMModel; | ||
import org.tribuo.common.libsvm.SVMParameters; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
/** | ||
* Wrap Tribuo's anomaly detection based on one-class SVM (libSVM). | ||
* | ||
*/ | ||
@Function(FunctionName.ANOMALY_DETECTION) | ||
public class AnomalyDetection implements Trainable, Predictable { | ||
public static final int VERSION = 1; | ||
private static double DEFAULT_GAMMA = 1.0; | ||
private static double DEFAULT_NU = 0.1; | ||
|
||
private AnomalyDetectionParams parameters; | ||
|
||
public AnomalyDetection() {} | ||
|
||
public AnomalyDetection(MLAlgoParams parameters) { | ||
this.parameters = parameters == null ? AnomalyDetectionParams.builder().build() : (AnomalyDetectionParams)parameters; | ||
validateParameters(); | ||
} | ||
|
||
private void validateParameters() { | ||
|
||
if (parameters.getGamma() != null && parameters.getGamma() <= 0) { | ||
throw new IllegalArgumentException("gamma should be positive."); | ||
} | ||
|
||
if (parameters.getNu() != null && parameters.getNu() <= 0) { | ||
throw new IllegalArgumentException("nu should be positive."); | ||
} | ||
|
||
} | ||
|
||
@Override | ||
public MLOutput predict(DataFrame dataFrame, Model model) { | ||
if (model == null) { | ||
throw new IllegalArgumentException("No model found for KMeans prediction."); | ||
} | ||
|
||
List<Prediction<Event>> predictions; | ||
MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), | ||
"Anomaly detection LibSVM prediction data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM); | ||
LibSVMModel libSVMAnomalyModel = (LibSVMModel) ModelSerDeSer.deserialize(model.getContent()); | ||
predictions = libSVMAnomalyModel.predict(predictionDataset); | ||
|
||
List<Map<String, Object>> adResults = new ArrayList<>(); | ||
predictions.forEach(e -> { | ||
Map<String, Object> result = new HashMap<>(); | ||
result.put("anomaly_type", e.getOutput().getType().name()); | ||
// result.put("score", e.getOutput().getScore()); | ||
// Example<Event> example = e.getExample(); | ||
// for (Feature feature : example) { | ||
// result.put(feature.getName(), feature.getValue()); | ||
// } | ||
adResults.add(result); | ||
}); | ||
|
||
return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(adResults)).build(); | ||
} | ||
|
||
@Override | ||
public Model train(DataFrame dataFrame) { | ||
SVMParameters params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), KernelType.RBF); | ||
Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA); | ||
Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU); | ||
params.setGamma(gamma); | ||
params.setNu(nu); | ||
MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), | ||
"Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM); | ||
|
||
LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params); | ||
|
||
LibSVMModel libSVMModel = trainer.train(data); | ||
((LibSVMAnomalyModel)libSVMModel).getNumberOfSupportVectors(); | ||
Model model = new Model(); | ||
model.setName(FunctionName.ANOMALY_DETECTION.name()); | ||
model.setVersion(VERSION); | ||
model.setContent(ModelSerDeSer.serialize(libSVMModel)); | ||
return model; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.