Skip to content

[ML] add nlp config update serialization tests #85867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -44,6 +57,12 @@ private static InferenceConfigUpdate randomInferenceConfigUpdate() {
RegressionConfigUpdateTests.randomRegressionConfigUpdate(),
ClassificationConfigUpdateTests.randomClassificationConfigUpdate(),
ResultsFieldUpdateTests.randomUpdate(),
TextClassificationConfigUpdateTests.randomUpdate(),
TextEmbeddingConfigUpdateTests.randomUpdate(),
NerConfigUpdateTests.randomUpdate(),
FillMaskConfigUpdateTests.randomUpdate(),
ZeroShotClassificationConfigUpdateTests.randomUpdate(),
PassThroughConfigUpdateTests.randomUpdate(),
EmptyConfigUpdateTests.testInstance()
);
}
Expand All @@ -68,6 +87,27 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() {

@Override
protected Request mutateInstanceForVersion(Request instance, Version version) {
return instance;
InferenceConfigUpdate adjustedUpdate;
InferenceConfigUpdate currentUpdate = instance.getUpdate();
if (currentUpdate instanceof NlpConfigUpdate nlpConfigUpdate) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use a fancy instanceof switch statement here?

switch (currentUpdate) {
        case NlpConfigUpdate i -> NlpConfigUpdateTests.mutateForVersion(update, version);
        ....
};

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh snap, you are correct!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Welp, I suppose not. It needs to be added via a compilation flag and it isn't...

if (nlpConfigUpdate instanceof TextClassificationConfigUpdate update) {
adjustedUpdate = TextClassificationConfigUpdateTests.mutateForVersion(update, version);
} else if (nlpConfigUpdate instanceof TextEmbeddingConfigUpdate update) {
adjustedUpdate = TextEmbeddingConfigUpdateTests.mutateForVersion(update, version);
} else if (nlpConfigUpdate instanceof NerConfigUpdate update) {
adjustedUpdate = NerConfigUpdateTests.mutateForVersion(update, version);
} else if (nlpConfigUpdate instanceof FillMaskConfigUpdate update) {
adjustedUpdate = FillMaskConfigUpdateTests.mutateForVersion(update, version);
} else if (nlpConfigUpdate instanceof ZeroShotClassificationConfigUpdate update) {
adjustedUpdate = ZeroShotClassificationConfigUpdateTests.mutateForVersion(update, version);
} else if (nlpConfigUpdate instanceof PassThroughConfigUpdate update) {
adjustedUpdate = PassThroughConfigUpdateTests.mutateForVersion(update, version);
} else {
throw new IllegalArgumentException("Unknown update [" + currentUpdate.getName() + "]");
}
} else {
adjustedUpdate = currentUpdate;
}
return new Request(instance.getModelId(), instance.getObjectsToInfer(), adjustedUpdate, instance.isPreviouslyLicensed());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,27 @@

public class FillMaskConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<FillMaskConfigUpdate> {

public static FillMaskConfigUpdate randomUpdate() {
FillMaskConfigUpdate.Builder builder = new FillMaskConfigUpdate.Builder();
if (randomBoolean()) {
builder.setNumTopClasses(randomIntBetween(1, 4));
}
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
}

public static FillMaskConfigUpdate mutateForVersion(FillMaskConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new FillMaskConfigUpdate(instance.getNumTopClasses(), instance.getResultsField(), null);
}
return instance;
}

@Override
Tuple<Map<String, Object>, FillMaskConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
int topClasses = randomIntBetween(1, 10);
Expand Down Expand Up @@ -103,25 +124,12 @@ protected Writeable.Reader<FillMaskConfigUpdate> instanceReader() {

@Override
protected FillMaskConfigUpdate createTestInstance() {
FillMaskConfigUpdate.Builder builder = new FillMaskConfigUpdate.Builder();
if (randomBoolean()) {
builder.setNumTopClasses(randomIntBetween(1, 4));
}
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
return randomUpdate();
}

@Override
protected FillMaskConfigUpdate mutateInstanceForVersion(FillMaskConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new FillMaskConfigUpdate(instance.getNumTopClasses(), instance.getResultsField(), null);
}
return instance;
return mutateForVersion(instance, version);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@

public class NerConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<NerConfigUpdate> {

public static NerConfigUpdate randomUpdate() {
NerConfigUpdate.Builder builder = new NerConfigUpdate.Builder();
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
}

public static NerConfigUpdate mutateForVersion(NerConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new NerConfigUpdate(instance.getResultsField(), null);
}
return instance;
}

@Override
Tuple<Map<String, Object>, NerConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
NerConfigUpdate expected = new NerConfigUpdate("ml-results", expectedTokenization);
Expand Down Expand Up @@ -86,22 +104,12 @@ protected Writeable.Reader<NerConfigUpdate> instanceReader() {

@Override
protected NerConfigUpdate createTestInstance() {
NerConfigUpdate.Builder builder = new NerConfigUpdate.Builder();
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
return randomUpdate();
}

@Override
protected NerConfigUpdate mutateInstanceForVersion(NerConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new NerConfigUpdate(instance.getResultsField(), null);
}
return instance;
return mutateForVersion(instance, version);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@

public class PassThroughConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<PassThroughConfigUpdate> {

public static PassThroughConfigUpdate randomUpdate() {
PassThroughConfigUpdate.Builder builder = new PassThroughConfigUpdate.Builder();
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
}

public static PassThroughConfigUpdate mutateForVersion(PassThroughConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new PassThroughConfigUpdate(instance.getResultsField(), null);
}
return instance;
}

@Override
Tuple<Map<String, Object>, PassThroughConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
PassThroughConfigUpdate expected = new PassThroughConfigUpdate("ml-results", expectedTokenization);
Expand Down Expand Up @@ -76,22 +94,12 @@ protected Writeable.Reader<PassThroughConfigUpdate> instanceReader() {

@Override
protected PassThroughConfigUpdate createTestInstance() {
PassThroughConfigUpdate.Builder builder = new PassThroughConfigUpdate.Builder();
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
return randomUpdate();
}

@Override
protected PassThroughConfigUpdate mutateInstanceForVersion(PassThroughConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new PassThroughConfigUpdate(instance.getResultsField(), null);
}
return instance;
return mutateForVersion(instance, version);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,35 @@

public class TextClassificationConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<TextClassificationConfigUpdate> {

public static TextClassificationConfigUpdate randomUpdate() {
TextClassificationConfigUpdate.Builder builder = new TextClassificationConfigUpdate.Builder();
if (randomBoolean()) {
builder.setNumTopClasses(randomIntBetween(1, 4));
}
if (randomBoolean()) {
builder.setClassificationLabels(randomList(1, 3, () -> randomAlphaOfLength(4)));
}
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
}

public static TextClassificationConfigUpdate mutateForVersion(TextClassificationConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new TextClassificationConfigUpdate(
instance.getClassificationLabels(),
instance.getNumTopClasses(),
instance.getResultsField(),
null
);
}
return instance;
}

@Override
Tuple<Map<String, Object>, TextClassificationConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
int numClasses = randomIntBetween(1, 10);
Expand Down Expand Up @@ -159,33 +188,12 @@ protected Writeable.Reader<TextClassificationConfigUpdate> instanceReader() {

@Override
protected TextClassificationConfigUpdate createTestInstance() {
TextClassificationConfigUpdate.Builder builder = new TextClassificationConfigUpdate.Builder();
if (randomBoolean()) {
builder.setNumTopClasses(randomIntBetween(1, 4));
}
if (randomBoolean()) {
builder.setClassificationLabels(randomList(1, 3, () -> randomAlphaOfLength(4)));
}
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
return randomUpdate();
}

@Override
protected TextClassificationConfigUpdate mutateInstanceForVersion(TextClassificationConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new TextClassificationConfigUpdate(
instance.getClassificationLabels(),
instance.getNumTopClasses(),
instance.getResultsField(),
null
);
}
return instance;
return mutateForVersion(instance, version);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@

public class TextEmbeddingConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<TextEmbeddingConfigUpdate> {

public static TextEmbeddingConfigUpdate randomUpdate() {
TextEmbeddingConfigUpdate.Builder builder = new TextEmbeddingConfigUpdate.Builder();
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
}

public static TextEmbeddingConfigUpdate mutateForVersion(TextEmbeddingConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new TextEmbeddingConfigUpdate(instance.getResultsField(), null);
}
return instance;
}

@Override
Tuple<Map<String, Object>, TextEmbeddingConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) {
TextEmbeddingConfigUpdate expected = new TextEmbeddingConfigUpdate("ml-results", expectedTokenization);
Expand Down Expand Up @@ -76,22 +94,12 @@ protected Writeable.Reader<TextEmbeddingConfigUpdate> instanceReader() {

@Override
protected TextEmbeddingConfigUpdate createTestInstance() {
TextEmbeddingConfigUpdate.Builder builder = new TextEmbeddingConfigUpdate.Builder();
if (randomBoolean()) {
builder.setResultsField(randomAlphaOfLength(8));
}
if (randomBoolean()) {
builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null));
}
return builder.build();
return randomUpdate();
}

@Override
protected TextEmbeddingConfigUpdate mutateInstanceForVersion(TextEmbeddingConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new TextEmbeddingConfigUpdate(instance.getResultsField(), null);
}
return instance;
return mutateForVersion(instance, version);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@

public class ZeroShotClassificationConfigUpdateTests extends AbstractNlpConfigUpdateTestCase<ZeroShotClassificationConfigUpdate> {

public static ZeroShotClassificationConfigUpdate randomUpdate() {
return new ZeroShotClassificationConfigUpdate(
randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),
randomBoolean() ? null : randomBoolean(),
randomBoolean() ? null : randomAlphaOfLength(5),
randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)
);
}

public static ZeroShotClassificationConfigUpdate mutateForVersion(ZeroShotClassificationConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new ZeroShotClassificationConfigUpdate(instance.getLabels(), instance.getMultiLabel(), instance.getResultsField(), null);
}
return instance;
}

@Override
protected boolean supportsUnknownFields() {
return false;
Expand All @@ -49,10 +65,7 @@ protected ZeroShotClassificationConfigUpdate createTestInstance() {

@Override
protected ZeroShotClassificationConfigUpdate mutateInstanceForVersion(ZeroShotClassificationConfigUpdate instance, Version version) {
if (version.before(Version.V_8_1_0)) {
return new ZeroShotClassificationConfigUpdate(instance.getLabels(), instance.getMultiLabel(), instance.getResultsField(), null);
}
return instance;
return mutateForVersion(instance, version);
}

@Override
Expand Down Expand Up @@ -197,12 +210,7 @@ public void testIsNoop() {
}

public static ZeroShotClassificationConfigUpdate createRandom() {
return new ZeroShotClassificationConfigUpdate(
randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)),
randomBoolean() ? null : randomBoolean(),
randomBoolean() ? null : randomAlphaOfLength(5),
randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)
);
return randomUpdate();
}

@Override
Expand Down