Skip to content

Commit

Permalink
test_ad2
Browse files Browse the repository at this point in the history
  • Loading branch information
ylwu-amzn committed Feb 3, 2022
1 parent 826d151 commit f60d273
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class DefaultDataFrame extends AbstractDataFrame{
List<Row> rows;
ColumnMeta[] columnMetas;

DefaultDataFrame(final ColumnMeta[] columnMetas){
public DefaultDataFrame(final ColumnMeta[] columnMetas){
super(DataFrameType.DEFAULT);
this.columnMetas = columnMetas;
this.rows = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.common.parameter;

import lombok.Builder;
import lombok.Data;
import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.annotation.MLAlgoParameter;

import java.io.IOException;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
 * Parameters for the {@link FunctionName#ANOMALY_DETECTION} function.
 *
 * <p>Both settings are optional: a {@code null} value means "not provided" and is
 * carried as such through stream serialization and XContent output.
 */
@Data
@MLAlgoParameter(algorithms={FunctionName.ANOMALY_DETECTION})
public class AnomalyDetectionParams implements MLAlgoParams {
    /** Name under which these parameters are registered and written on the wire. */
    public static final String PARSE_FIELD_NAME = FunctionName.ANOMALY_DETECTION.name();

    /** Registry entry that maps {@link #PARSE_FIELD_NAME} to {@link #parse(XContentParser)}. */
    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
        MLAlgoParams.class,
        new ParseField(PARSE_FIELD_NAME),
        xContentParser -> parse(xContentParser)
    );

    public static final String GAMMA_FIELD = "gamma";
    public static final String NU_FIELD = "nu";

    private Double gamma;
    private Double nu;

    /**
     * Creates the parameters from explicit values.
     *
     * @param gamma optional gamma value, may be {@code null}
     * @param nu    optional nu value, may be {@code null}
     */
    @Builder
    public AnomalyDetectionParams(Double gamma, Double nu) {
        this.gamma = gamma;
        this.nu = nu;
    }

    /**
     * Reads the parameters from a stream, in the same order {@link #writeTo(StreamOutput)} writes them.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public AnomalyDetectionParams(StreamInput in) throws IOException {
        // Order matters: gamma first, then nu.
        this.gamma = in.readOptionalDouble();
        this.nu = in.readOptionalDouble();
    }

    /**
     * Parses an object of the form {@code {"gamma": ..., "nu": ...}}.
     * Fields other than {@code gamma} and {@code nu} are skipped.
     *
     * @param parser parser positioned on the object's START_OBJECT token
     * @return the parsed parameters; fields absent from the input stay {@code null}
     * @throws IOException if the underlying parser fails
     */
    private static MLAlgoParams parse(XContentParser parser) throws IOException {
        Double parsedGamma = null;
        Double parsedNu = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        for (XContentParser.Token token = parser.nextToken();
             token != XContentParser.Token.END_OBJECT;
             token = parser.nextToken()) {
            String fieldName = parser.currentName();
            // Advance from the field name to its value.
            parser.nextToken();

            switch (fieldName) {
                case GAMMA_FIELD:
                    parsedGamma = parser.doubleValue();
                    break;
                case NU_FIELD:
                    parsedNu = parser.doubleValue();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return new AnomalyDetectionParams(parsedGamma, parsedNu);
    }

    @Override
    public String getWriteableName() {
        return PARSE_FIELD_NAME;
    }

    /**
     * Writes gamma and then nu as optional doubles.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalDouble(gamma);
        out.writeOptionalDouble(nu);
    }

    /**
     * Renders the parameters as an object with a {@code gamma} and a {@code nu} field.
     *
     * @param builder builder to write into
     * @param params  rendering parameters (not used here)
     * @return the same builder, for chaining
     * @throws IOException if the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(GAMMA_FIELD, gamma).field(NU_FIELD, nu);
        builder.endObject();
        return builder;
    }

    @Override
    public int getVersion() {
        return 1;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
public enum FunctionName {
LINEAR_REGRESSION,
KMEANS,
ANOMALY_DETECTION_LIBSVM,
ANOMALY_DETECTION,
SAMPLE_ALGO,
LOCAL_SAMPLE_CALCULATOR,
ANOMALY_LOCALIZATION
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*
*/

package org.opensearch.ml.engine.algorithms.ad;

import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.AnomalyDetectionParams;
import org.opensearch.ml.common.parameter.MLAlgoParams;
import org.opensearch.ml.common.parameter.MLOutput;
import org.opensearch.ml.common.parameter.MLPredictionOutput;
import org.opensearch.ml.engine.Model;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.Trainable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.contants.TribuoOutputType;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.opensearch.ml.engine.utils.TribuoUtil;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.anomaly.AnomalyFactory;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.libsvm.LibSVMAnomalyModel;
import org.tribuo.anomaly.libsvm.LibSVMAnomalyTrainer;
import org.tribuo.anomaly.libsvm.SVMAnomalyType;
import org.tribuo.common.libsvm.KernelType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.SVMParameters;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Wraps Tribuo's anomaly detection based on a one-class SVM (libSVM) with an RBF kernel.
 *
 * <p>Training uses the {@code gamma} and {@code nu} values from
 * {@link AnomalyDetectionParams}; when a value is absent, {@link #DEFAULT_GAMMA} or
 * {@link #DEFAULT_NU} is used instead.
 */
@Function(FunctionName.ANOMALY_DETECTION)
public class AnomalyDetection implements Trainable, Predictable {
    public static final int VERSION = 1;
    private static final double DEFAULT_GAMMA = 1.0;
    private static final double DEFAULT_NU = 0.1;

    private final AnomalyDetectionParams parameters;

    /**
     * Creates an instance with no explicit parameters, so training falls back to the defaults.
     */
    public AnomalyDetection() {
        // Never leave parameters null: train() dereferences it.
        this.parameters = AnomalyDetectionParams.builder().build();
    }

    /**
     * Creates an instance with the given parameters.
     *
     * @param parameters anomaly detection parameters; {@code null} means "use defaults".
     *                   Must be an {@link AnomalyDetectionParams} when non-null.
     * @throws IllegalArgumentException if gamma or nu is provided and not positive
     */
    public AnomalyDetection(MLAlgoParams parameters) {
        this.parameters = parameters == null ? AnomalyDetectionParams.builder().build() : (AnomalyDetectionParams)parameters;
        validateParameters();
    }

    /**
     * Rejects non-positive gamma / nu values; absent (null) values are accepted.
     */
    private void validateParameters() {
        if (parameters.getGamma() != null && parameters.getGamma() <= 0) {
            throw new IllegalArgumentException("gamma should be positive.");
        }

        // NOTE(review): libsvm documents nu as a value in (0, 1] for one-class SVM; only the
        // lower bound is checked here — confirm whether the upper bound should be validated too.
        if (parameters.getNu() != null && parameters.getNu() <= 0) {
            throw new IllegalArgumentException("nu should be positive.");
        }
    }

    /**
     * Runs the serialized model against the given data.
     *
     * @param dataFrame data to score
     * @param model     model previously produced by {@link #train(DataFrame)}
     * @return prediction output holding one row per input example with a single
     *         {@code anomaly_type} column containing the predicted event type name
     * @throws IllegalArgumentException if {@code model} is null
     */
    @Override
    public MLOutput predict(DataFrame dataFrame, Model model) {
        if (model == null) {
            throw new IllegalArgumentException("No model found for anomaly detection prediction.");
        }

        MutableDataset<Event> predictionDataset = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(),
            "Anomaly detection LibSVM prediction data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);
        // The serialized content is written by train(), which always stores a LibSVMModel<Event>.
        @SuppressWarnings("unchecked")
        LibSVMModel<Event> libSVMAnomalyModel = (LibSVMModel<Event>) ModelSerDeSer.deserialize(model.getContent());
        List<Prediction<Event>> predictions = libSVMAnomalyModel.predict(predictionDataset);

        List<Map<String, Object>> adResults = new ArrayList<>(predictions.size());
        for (Prediction<Event> prediction : predictions) {
            Map<String, Object> result = new HashMap<>();
            result.put("anomaly_type", prediction.getOutput().getType().name());
            // TODO: also expose the anomaly score and the input feature values in the result.
            adResults.add(result);
        }

        return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(adResults)).build();
    }

    /**
     * Trains a one-class SVM on the given data and returns it in serialized form.
     *
     * @param dataFrame training data
     * @return a model named after {@link FunctionName#ANOMALY_DETECTION}, carrying
     *         {@link #VERSION} and the serialized libSVM model as content
     */
    @Override
    public Model train(DataFrame dataFrame) {
        SVMParameters<Event> params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), KernelType.RBF);
        double gamma = parameters.getGamma() != null ? parameters.getGamma() : DEFAULT_GAMMA;
        double nu = parameters.getNu() != null ? parameters.getNu() : DEFAULT_NU;
        params.setGamma(gamma);
        params.setNu(nu);

        MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(),
            "Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);

        LibSVMAnomalyTrainer trainer = new LibSVMAnomalyTrainer(params);
        LibSVMModel<Event> libSVMModel = trainer.train(data);

        Model model = new Model();
        model.setName(FunctionName.ANOMALY_DETECTION.name());
        model.setVersion(VERSION);
        model.setContent(ModelSerDeSer.serialize(libSVMModel));
        return model;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
@UtilityClass
public class TribuoUtil {
public static Tuple transformDataFrame(DataFrame dataFrame) {
//TODO: remove this line, don't need to do this for everytime
String[] featureNames = Arrays.stream(dataFrame.columnMetas()).map(e -> e.getName()).toArray(String[]::new);
double[][] featureValues = new double[dataFrame.size()][];
Iterator<Row> itr = dataFrame.iterator();
Expand Down Expand Up @@ -68,24 +67,12 @@ public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame
example = new ArrayExample<T>((T) new Regressor("DIM-0", Double.NaN), featureNamesValues.v1(), featureNamesValues.v2()[i]);
break;
case ANOMALY_DETECTION_LIBSVM:
Event.EventType eventType = Event.EventType.UNKNOWN;
String[] columns = featureNamesValues.v1();
if ("anomaly_type".equals(columns[columns.length - 1])) {
double v = featureNamesValues.v2()[i][columns.length - 1];
int anomalyType = Double.valueOf(v).intValue();
switch (anomalyType) {
case 1:
eventType = Event.EventType.ANOMALOUS;
break;
case 0:
eventType = Event.EventType.EXPECTED;
break;
default:
break;
}
}

example = new ArrayExample<T>((T) new Event(eventType), featureNamesValues.v1(), featureNamesValues.v2()[i]);
// Why the default event type is EXPECTED (non-anomalous):
// 1. For training data, Tribuo's LibSVMAnomalyTrainer only supports EXPECTED events at training time.
// 2. For prediction data, we treat the data as non-anomalous by default because the Tribuo library doesn't accept the UNKNOWN type.
Event.EventType defaultEventType = Event.EventType.EXPECTED;
// TODO: support anomaly labels to evaluate prediction result
example = new ArrayExample<T>((T) new Event(defaultEventType), featureNamesValues.v1(), featureNamesValues.v2()[i]);
break;
default:
throw new IllegalArgumentException("unknown type:" + outputType);
Expand Down
Loading

0 comments on commit f60d273

Please sign in to comment.