Skip to content

Commit

Permalink
add more AD params; add UT; fix license header
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Feb 4, 2022
1 parent 3740101 commit 0030e0a
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 78 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.parameter;
Expand All @@ -22,6 +16,7 @@
import org.opensearch.ml.common.annotation.MLAlgoParameter;

import java.io.IOException;
import java.util.Locale;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -35,44 +30,87 @@ public class AnomalyDetectionParams implements MLAlgoParams {
it -> parse(it)
);

public static final String KERNEL_FIELD = "kernel";
public static final String GAMMA_FIELD = "gamma";
public static final String NU_FIELD = "nu";
public static final String COST_FIELD = "cost";
public static final String COEFF_FIELD = "coeff";
public static final String EPSILON_FIELD = "epsilon";
public static final String DEGREE_FIELD = "degree";
private ADKernelType kernelType;
private Double gamma;
private Double nu;
private Double cost;
private Double coeff;
private Double epsilon;
private Integer degree;


@Builder
public AnomalyDetectionParams(Double gamma, Double nu) {
public AnomalyDetectionParams(ADKernelType kernelType, Double gamma, Double nu, Double cost, Double coeff, Double epsilon, Integer degree) {
this.kernelType = kernelType;
this.gamma = gamma;
this.nu = nu;
this.cost = cost;
this.coeff = coeff;
this.epsilon = epsilon;
this.degree = degree;
}

public AnomalyDetectionParams(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.kernelType = in.readEnum(ADKernelType.class);
}
this.gamma = in.readOptionalDouble();
this.nu = in.readOptionalDouble();
this.cost = in.readOptionalDouble();
this.coeff = in.readOptionalDouble();
this.epsilon = in.readOptionalDouble();
this.degree = in.readOptionalInt();
}

private static MLAlgoParams parse(XContentParser parser) throws IOException {
public static MLAlgoParams parse(XContentParser parser) throws IOException {
ADKernelType kernelType = null;
Double gamma = null;
Double nu = null;
Double cost = null;
Double coeff = null;
Double epsilon = null;
Integer degree = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case KERNEL_FIELD:
kernelType = ADKernelType.valueOf(parser.text().toUpperCase(Locale.ROOT));
break;
case GAMMA_FIELD:
gamma = parser.doubleValue();
break;
case NU_FIELD:
nu = parser.doubleValue();
break;
case COST_FIELD:
cost = parser.doubleValue();
break;
case COEFF_FIELD:
coeff = parser.doubleValue();
break;
case EPSILON_FIELD:
epsilon = parser.doubleValue();
break;
case DEGREE_FIELD:
degree = parser.intValue();
break;
default:
parser.skipChildren();
break;
}
}
return new AnomalyDetectionParams(gamma, nu);
return new AnomalyDetectionParams(kernelType, gamma, nu, cost, coeff, epsilon, degree);
}

@Override
Expand All @@ -82,15 +120,44 @@ public String getWriteableName() {

@Override
public void writeTo(StreamOutput out) throws IOException {
if (kernelType == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeEnum(kernelType);
}
out.writeOptionalDouble(gamma);
out.writeOptionalDouble(nu);
out.writeOptionalDouble(cost);
out.writeOptionalDouble(coeff);
out.writeOptionalDouble(epsilon);
out.writeOptionalInt(degree);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(GAMMA_FIELD, gamma);
builder.field(NU_FIELD, nu);
if (kernelType != null) {
builder.field(KERNEL_FIELD, kernelType);
}
if (gamma != null) {
builder.field(GAMMA_FIELD, gamma);
}
if (nu != null) {
builder.field(NU_FIELD, nu);
}
if (cost != null) {
builder.field(COST_FIELD, cost);
}
if (coeff != null) {
builder.field(COEFF_FIELD, coeff);
}
if (epsilon != null) {
builder.field(EPSILON_FIELD, epsilon);
}
if (degree != null) {
builder.field(DEGREE_FIELD, degree);
}
builder.endObject();
return builder;
}
Expand All @@ -99,4 +166,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public int getVersion() {
return 1;
}

public enum ADKernelType {
LINEAR,
POLY,
RBF,
SIGMOID
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.parameter;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.util.function.Function;

import static org.junit.Assert.assertEquals;

public class AnomalyDetectionParamsTest {

AnomalyDetectionParams params;
private Function<XContentParser, AnomalyDetectionParams> function = parser -> {
try {
return (AnomalyDetectionParams)AnomalyDetectionParams.parse(parser);
} catch (IOException e) {
throw new RuntimeException("failed to parse AnomalyDetectionParams", e);
}
};

@Before
public void setUp() {
params = AnomalyDetectionParams.builder()
.kernelType(AnomalyDetectionParams.ADKernelType.POLY)
.degree(2)
.build();
}

@Test
public void parse_AnomalyDetectionParams() throws IOException {
TestHelper.testParse(params, function);
}

@Test
public void parse_Emptyparse_AnomalyDetectionParams() throws IOException {
TestHelper.testParse(AnomalyDetectionParams.builder().build(), function);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(params);
}

@Test
public void readInputStream_Success_EmptyParams() throws IOException {
readInputStream(AnomalyDetectionParams.builder().build());
}

private void readInputStream(AnomalyDetectionParams params) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
params.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
AnomalyDetectionParams parsedParams = new AnomalyDetectionParams(streamInput);
assertEquals(params, parsedParams);
}
}
9 changes: 1 addition & 8 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#

opensearch_version = 1.3.0-SNAPSHOT
opensearchBaseVersion = 1.3.0
9 changes: 1 addition & 8 deletions gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#

#Mon Jan 04 08:34:55 PST 2021
distributionUrl=https\://services.gradle.org/distributions/gradle-6.6.1-all.zip
Expand Down
9 changes: 1 addition & 8 deletions gradlew
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
#!/usr/bin/env sh

#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#

##############################################################################
##
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*
*/

package org.opensearch.ml.engine.algorithms.ad;
Expand Down Expand Up @@ -52,6 +45,7 @@ public class AnomalyDetection implements Trainable, Predictable {
public static final int VERSION = 1;
private static double DEFAULT_GAMMA = 1.0;
private static double DEFAULT_NU = 0.1;
private static KernelType DEFAULT_KERNEL_TYPE = KernelType.RBF;

private AnomalyDetectionParams parameters;

Expand Down Expand Up @@ -89,6 +83,7 @@ public MLOutput predict(DataFrame dataFrame, Model model) {
List<Map<String, Object>> adResults = new ArrayList<>();
predictions.forEach(e -> {
Map<String, Object> result = new HashMap<>();
result.put("score", e.getOutput().getScore());
result.put("anomaly_type", e.getOutput().getType().name());
adResults.add(result);
});
Expand All @@ -98,11 +93,24 @@ public MLOutput predict(DataFrame dataFrame, Model model) {

@Override
public Model train(DataFrame dataFrame) {
SVMParameters params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), KernelType.RBF);
KernelType kernelType = parseKernelType();
SVMParameters params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType);
Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA);
Double nu = Optional.ofNullable(parameters.getNu()).orElse(DEFAULT_NU);
params.setGamma(gamma);
params.setNu(nu);
if (parameters.getCost() != null) {
params.setCost(parameters.getCost());
}
if (parameters.getCoeff() != null) {
params.setCoeff(parameters.getCoeff());
}
if (parameters.getEpsilon() != null) {
params.setEpsilon(parameters.getEpsilon());
}
if (parameters.getDegree() != null) {
params.setDegree(parameters.getDegree());
}
MutableDataset<Event> data = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(),
"Anomaly detection LibSVM training data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM);

Expand All @@ -117,4 +125,27 @@ public Model train(DataFrame dataFrame) {
return model;
}

private KernelType parseKernelType() {
KernelType kernelType = DEFAULT_KERNEL_TYPE;
if (parameters.getKernelType() == null) {
return kernelType;
}
switch (parameters.getKernelType()){
case LINEAR:
kernelType = KernelType.LINEAR;
break;
case POLY:
kernelType = KernelType.POLY;
break;
case RBF:
kernelType = KernelType.RBF;
break;
case SIGMOID:
kernelType = KernelType.SIGMOID;
break;
default:
break;
}
return kernelType;
}
}
Loading

0 comments on commit 0030e0a

Please sign in to comment.