Skip to content

[ML] Allow NLP truncate option to be updated when span is set #91224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/91224.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 91224
summary: Allow NLP truncate option to be updated when span is set
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public abstract class AbstractTokenizationUpdate implements TokenizationUpdate {

private final Tokenization.Truncate truncate;
private final Integer span;

protected static void declareCommonParserFields(ConstructingObjectParser<? extends AbstractTokenizationUpdate, Void> parser) {
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
}

public AbstractTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
}

public AbstractTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

public Integer getSpan() {
return span;
}

public Tokenization.Truncate getTruncate() {
return truncate;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o instanceof AbstractTokenizationUpdate == false) {
return false;
}
AbstractTokenizationUpdate that = (AbstractTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class BertTokenizationUpdate implements TokenizationUpdate {
public class BertTokenizationUpdate extends AbstractTokenizationUpdate {

public static final ParseField NAME = BertTokenization.NAME;

Expand All @@ -31,29 +27,19 @@ public class BertTokenizationUpdate implements TokenizationUpdate {
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
declareCommonParserFields(PARSER);
}

public static BertTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
super(truncate, span);
}

public BertTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
super(in);
}

@Override
Expand All @@ -66,65 +52,41 @@ public Tokenization apply(Tokenization originalConfig) {
);
}

Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());

if (isNoop()) {
return originalConfig;
}

if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
// When truncate value is incompatible with span wipe out
// the existing span setting to avoid an invalid combination of settings.
// This avoids the user have to set span to the special unset value
return new BertTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
getTruncate(),
null
);
}

return new BertTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
Optional.ofNullable(getTruncate()).orElse(originalConfig.getTruncate()),
Optional.ofNullable(getSpan()).orElse(originalConfig.getSpan())
);
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return BertTokenization.NAME.getPreferredName();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

@Override
public String getName() {
return BertTokenization.NAME.getPreferredName();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BertTokenizationUpdate that = (BertTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class MPNetTokenizationUpdate implements TokenizationUpdate {
public class MPNetTokenizationUpdate extends AbstractTokenizationUpdate {

public static final ParseField NAME = MPNetTokenization.NAME;

Expand All @@ -31,29 +27,19 @@ public class MPNetTokenizationUpdate implements TokenizationUpdate {
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN);
declareCommonParserFields(PARSER);
}

public static MPNetTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
super(truncate, span);
}

public MPNetTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
super(in);
}

@Override
Expand All @@ -70,61 +56,35 @@ public Tokenization apply(Tokenization originalConfig) {
return originalConfig;
}

if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
// When truncate value is incompatible with span wipe out
// the existing span setting to avoid an invalid combination of settings.
// This avoids the user have to set span to the special unset value
return new MPNetTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
getTruncate(),
null
);
}

return new MPNetTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
);
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return MPNetTokenization.NAME.getPreferredName();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

@Override
public String getName() {
return MPNetTokenization.NAME.getPreferredName();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}
Loading