Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate tribuo anomaly detection based on libSVM #96

Merged
merged 3 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class DefaultDataFrame extends AbstractDataFrame{
List<Row> rows;
ColumnMeta[] columnMetas;

DefaultDataFrame(final ColumnMeta[] columnMetas){
public DefaultDataFrame(final ColumnMeta[] columnMetas){
super(DataFrameType.DEFAULT);
this.columnMetas = columnMetas;
this.rows = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.parameter;

import lombok.Builder;
import lombok.Data;
import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.annotation.MLAlgoParameter;

import java.io.IOException;
import java.util.Locale;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

@Data
@MLAlgoParameter(algorithms={FunctionName.ANOMALY_DETECTION})
public class AnomalyDetectionParams implements MLAlgoParams {
public static final String PARSE_FIELD_NAME = FunctionName.ANOMALY_DETECTION.name();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(PARSE_FIELD_NAME),
it -> parse(it)
);

public static final String KERNEL_FIELD = "kernel";
public static final String GAMMA_FIELD = "gamma";
public static final String NU_FIELD = "nu";
public static final String COST_FIELD = "cost";
public static final String COEFF_FIELD = "coeff";
public static final String EPSILON_FIELD = "epsilon";
public static final String DEGREE_FIELD = "degree";
private ADKernelType kernelType;
private Double gamma;
private Double nu;
private Double cost;
private Double coeff;
private Double epsilon;
private Integer degree;


/**
 * Creates a parameter set from explicit values.
 *
 * <p>No argument is validated here and each one may be {@code null}; the
 * serialization methods of this class treat a {@code null} field as absent.
 * The {@code @Builder} annotation makes this constructor the target of the
 * generated builder.
 *
 * @param kernelType kernel selection, or {@code null} if unspecified
 * @param gamma      value for the "gamma" field, or {@code null}
 * @param nu         value for the "nu" field, or {@code null}
 * @param cost       value for the "cost" field, or {@code null}
 * @param coeff      value for the "coeff" field, or {@code null}
 * @param epsilon    value for the "epsilon" field, or {@code null}
 * @param degree     value for the "degree" field, or {@code null}
 */
@Builder
public AnomalyDetectionParams(
        ADKernelType kernelType,
        Double gamma,
        Double nu,
        Double cost,
        Double coeff,
        Double epsilon,
        Integer degree) {
    // Plain field copies; the assignments are independent of one another.
    this.degree = degree;
    this.epsilon = epsilon;
    this.coeff = coeff;
    this.cost = cost;
    this.nu = nu;
    this.gamma = gamma;
    this.kernelType = kernelType;
}

public AnomalyDetectionParams(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.kernelType = in.readEnum(ADKernelType.class);
}
this.gamma = in.readOptionalDouble();
this.nu = in.readOptionalDouble();
this.cost = in.readOptionalDouble();
this.coeff = in.readOptionalDouble();
this.epsilon = in.readOptionalDouble();
this.degree = in.readOptionalInt();
}

/**
 * Builds an {@link AnomalyDetectionParams} from the object the parser is
 * currently positioned on.
 *
 * <p>The current token must be {@code START_OBJECT}. Recognized field names
 * are "kernel", "gamma", "nu", "cost", "coeff", "epsilon" and "degree"; any
 * other field has its children skipped. Fields that do not appear stay
 * {@code null} in the result. The kernel text is upper-cased with
 * {@link Locale#ROOT} before being resolved to an {@link ADKernelType}.
 *
 * @param parser parser positioned on the start of the parameter object
 * @return the parsed parameters
 * @throws IOException if the parser fails while reading
 */
public static MLAlgoParams parse(XContentParser parser) throws IOException {
    ADKernelType parsedKernel = null;
    Double parsedGamma = null;
    Double parsedNu = null;
    Double parsedCost = null;
    Double parsedCoeff = null;
    Double parsedEpsilon = null;
    Integer parsedDegree = null;

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    for (XContentParser.Token token = parser.nextToken();
            token != XContentParser.Token.END_OBJECT;
            token = parser.nextToken()) {
        final String name = parser.currentName();
        // Advance from the field name to its value.
        parser.nextToken();

        switch (name) {
            case KERNEL_FIELD: {
                final String kernelText = parser.text();
                parsedKernel = ADKernelType.valueOf(kernelText.toUpperCase(Locale.ROOT));
                break;
            }
            case GAMMA_FIELD:
                parsedGamma = parser.doubleValue();
                break;
            case NU_FIELD:
                parsedNu = parser.doubleValue();
                break;
            case COST_FIELD:
                parsedCost = parser.doubleValue();
                break;
            case COEFF_FIELD:
                parsedCoeff = parser.doubleValue();
                break;
            case EPSILON_FIELD:
                parsedEpsilon = parser.doubleValue();
                break;
            case DEGREE_FIELD:
                parsedDegree = parser.intValue();
                break;
            default:
                // Unknown field: skip over any nested structure.
                parser.skipChildren();
                break;
        }
    }

    return new AnomalyDetectionParams(
            parsedKernel, parsedGamma, parsedNu, parsedCost, parsedCoeff, parsedEpsilon, parsedDegree);
}

/**
 * Returns the name under which this parameter type is registered.
 *
 * @return {@link #PARSE_FIELD_NAME}, the same name used for the XContent registry entry
 */
@Override
public String getWriteableName() {
return PARSE_FIELD_NAME;
}

/**
 * Serializes this parameter set to a stream.
 *
 * <p>The kernel type is written as a presence flag followed, when present,
 * by the enum value. The remaining fields are written as optional values in
 * the order gamma, nu, cost, coeff, epsilon, degree. The stream constructor
 * reads them back in that same order.
 *
 * @param out destination stream
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    final boolean kernelPresent = kernelType != null;
    out.writeBoolean(kernelPresent);
    if (kernelPresent) {
        out.writeEnum(kernelType);
    }

    out.writeOptionalDouble(gamma);
    out.writeOptionalDouble(nu);
    out.writeOptionalDouble(cost);
    out.writeOptionalDouble(coeff);
    out.writeOptionalDouble(epsilon);
    out.writeOptionalInt(degree);
}

/**
 * Writes this parameter set as a single object, emitting only the fields
 * that are non-null.
 *
 * @param builder builder to write into
 * @param params  serialization parameters (not consulted by this method)
 * @return the builder that was passed in, after the object has been closed
 * @throws IOException if the builder fails while writing
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject();

    // Each field is optional: absent values are omitted rather than written as null.
    if (kernelType != null) { builder.field(KERNEL_FIELD, kernelType); }
    if (gamma != null) { builder.field(GAMMA_FIELD, gamma); }
    if (nu != null) { builder.field(NU_FIELD, nu); }
    if (cost != null) { builder.field(COST_FIELD, cost); }
    if (coeff != null) { builder.field(COEFF_FIELD, coeff); }
    if (epsilon != null) { builder.field(EPSILON_FIELD, epsilon); }
    if (degree != null) { builder.field(DEGREE_FIELD, degree); }

    builder.endObject();
    return builder;
}

@Override
public int getVersion() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: is this version always 1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We set the version to 1 for the first release. It may change to a higher version in the future.

return 1;
}

/**
 * Kernel choices accepted for the "kernel" parameter.
 *
 * <p>{@code parse} upper-cases the incoming text with {@code Locale.ROOT}
 * before calling {@code valueOf}, so request values match these constant
 * names regardless of case.
 *
 * <p>NOTE(review): presumably each constant selects the corresponding libSVM
 * kernel function; confirm against the algorithm implementation.
 */
public enum ADKernelType {
LINEAR,
POLY,
RBF,
SIGMOID
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
public enum FunctionName {
LINEAR_REGRESSION,
KMEANS,
ANOMALY_DETECTION,
SAMPLE_ALGO,
LOCAL_SAMPLE_CALCULATOR,
ANOMALY_LOCALIZATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
/**
 * Tags identifying the kind of an ML output.
 *
 * <p>NOTE(review): how each constant is produced and consumed is not visible
 * from this declaration; confirm against the output classes that use it.
 */
public enum MLOutputType {
TRAINING,
PREDICTION,
ANOMALY_DETECTION_LIBSVM,
SAMPLE_ALGO,
SAMPLE_CALCULATOR
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.parameter;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.util.function.Function;

import static org.junit.Assert.assertEquals;

/**
 * Tests for {@link AnomalyDetectionParams} covering XContent parsing (via
 * {@code TestHelper.testParse}) and stream serialization round trips, each
 * with a fully populated instance and with an instance built from no values.
 */
public class AnomalyDetectionParamsTest {

    /** Fully populated fixture, rebuilt before every test. */
    AnomalyDetectionParams params;

    /** Parser adapter handed to {@code TestHelper.testParse}. */
    private Function<XContentParser, AnomalyDetectionParams> parseFunction =
            AnomalyDetectionParamsTest::parseUnchecked;

    /**
     * Delegates to {@link AnomalyDetectionParams#parse(XContentParser)} and
     * rethrows any {@link IOException} unchecked so the call fits a
     * {@link Function}.
     */
    private static AnomalyDetectionParams parseUnchecked(XContentParser parser) {
        try {
            return (AnomalyDetectionParams) AnomalyDetectionParams.parse(parser);
        } catch (IOException e) {
            throw new RuntimeException("failed to parse AnomalyDetectionParams", e);
        }
    }

    @Before
    public void setUp() {
        params = AnomalyDetectionParams.builder()
                .kernelType(AnomalyDetectionParams.ADKernelType.POLY)
                .gamma(1.0)
                .nu(0.5)
                .cost(1.0)
                .coeff(0.1)
                .epsilon(0.2)
                .degree(2)
                .build();
    }

    @Test
    public void parse_AnomalyDetectionParams() throws IOException {
        TestHelper.testParse(params, parseFunction);
    }

    @Test
    public void parse_Emptyparse_AnomalyDetectionParams() throws IOException {
        final AnomalyDetectionParams emptyParams = AnomalyDetectionParams.builder().build();
        TestHelper.testParse(emptyParams, parseFunction);
    }

    @Test
    public void readInputStream_Success() throws IOException {
        assertStreamRoundTrip(params);
    }

    @Test
    public void readInputStream_Success_EmptyParams() throws IOException {
        final AnomalyDetectionParams emptyParams = AnomalyDetectionParams.builder().build();
        assertStreamRoundTrip(emptyParams);
    }

    /**
     * Writes {@code original} to an in-memory stream, reads it back through
     * the stream constructor, and asserts the copy equals the original.
     */
    private void assertStreamRoundTrip(AnomalyDetectionParams original) throws IOException {
        final BytesStreamOutput output = new BytesStreamOutput();
        original.writeTo(output);

        final StreamInput input = output.bytes().streamInput();
        final AnomalyDetectionParams restored = new AnomalyDetectionParams(input);
        assertEquals(original, restored);
    }
}
9 changes: 1 addition & 8 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#

opensearch_version = 1.3.0-SNAPSHOT
opensearchBaseVersion = 1.3.0
9 changes: 1 addition & 8 deletions gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#

#Mon Jan 04 08:34:55 PST 2021
distributionUrl=https\://services.gradle.org/distributions/gradle-6.6.1-all.zip
Expand Down
9 changes: 1 addition & 8 deletions gradlew
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
#!/usr/bin/env sh

#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#

##############################################################################
##
Expand Down
1 change: 1 addition & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies {
compile group: 'org.reflections', name: 'reflections', version: '0.9.12'
compile group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.2.0'
compile group: 'org.tribuo', name: 'tribuo-regression-sgd', version: '4.2.0'
compile group: 'org.tribuo', name: 'tribuo-anomaly-libsvm', version: '4.2.0'
compile group: 'commons-io', name: 'commons-io', version: '2.11.0'
testCompile group: 'junit', name: 'junit', version: '4.12'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.9.0'
Expand Down
Loading