Skip to content

Commit

Permalink
[ML] Ensure immutability of MlMetadata (elastic#31957)
Browse files Browse the repository at this point in the history
The test failure in elastic#31916 revealed that updating
rules on a job was modifying the detectors list
in-place. That meant the old cluster state and the
updated cluster state had no difference and thus the
change was not propagated to non-master nodes.

This commit fixes that and also reviews all of ML
metadata in order to ensure immutability.

Closes elastic#31916
  • Loading branch information
dimitris-athanasiou authored Jul 12, 2018
1 parent e3707ef commit 2cfe703
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ private DatafeedConfig(String id, String jobId, TimeValue queryDelay, TimeValue
this.jobId = jobId;
this.queryDelay = queryDelay;
this.frequency = frequency;
this.indices = indices;
this.types = types;
this.indices = indices == null ? null : Collections.unmodifiableList(indices);
this.types = types == null ? null : Collections.unmodifiableList(types);
this.query = query;
this.aggregations = aggregations;
this.scriptFields = scriptFields;
this.scriptFields = scriptFields == null ? null : Collections.unmodifiableList(scriptFields);
this.scrollSize = scrollSize;
this.chunkingConfig = chunkingConfig;
this.headers = Objects.requireNonNull(headers);
this.headers = Collections.unmodifiableMap(headers);
}

public DatafeedConfig(StreamInput in) throws IOException {
Expand All @@ -172,19 +172,19 @@ public DatafeedConfig(StreamInput in) throws IOException {
this.queryDelay = in.readOptionalTimeValue();
this.frequency = in.readOptionalTimeValue();
if (in.readBoolean()) {
this.indices = in.readList(StreamInput::readString);
this.indices = Collections.unmodifiableList(in.readList(StreamInput::readString));
} else {
this.indices = null;
}
if (in.readBoolean()) {
this.types = in.readList(StreamInput::readString);
this.types = Collections.unmodifiableList(in.readList(StreamInput::readString));
} else {
this.types = null;
}
this.query = in.readNamedWriteable(QueryBuilder.class);
this.aggregations = in.readOptionalWriteable(AggregatorFactories.Builder::new);
if (in.readBoolean()) {
this.scriptFields = in.readList(SearchSourceBuilder.ScriptField::new);
this.scriptFields = Collections.unmodifiableList(in.readList(SearchSourceBuilder.ScriptField::new));
} else {
this.scriptFields = null;
}
Expand All @@ -195,7 +195,7 @@ public DatafeedConfig(StreamInput in) throws IOException {
}
this.chunkingConfig = in.readOptionalWriteable(ChunkingConfig::new);
if (in.getVersion().onOrAfter(Version.V_6_2_0)) {
this.headers = in.readMap(StreamInput::readString, StreamInput::readString);
this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
} else {
this.headers = Collections.emptyMap();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,18 @@ public String toString() {
return Strings.toString(this);
}

boolean isNoop(DatafeedConfig datafeed) {
return (frequency == null || Objects.equals(frequency, datafeed.getFrequency()))
&& (queryDelay == null || Objects.equals(queryDelay, datafeed.getQueryDelay()))
&& (indices == null || Objects.equals(indices, datafeed.getIndices()))
&& (types == null || Objects.equals(types, datafeed.getTypes()))
&& (query == null || Objects.equals(query, datafeed.getQuery()))
&& (scrollSize == null || Objects.equals(scrollSize, datafeed.getQueryDelay()))
&& (aggregations == null || Objects.equals(aggregations, datafeed.getAggregations()))
&& (scriptFields == null || Objects.equals(scriptFields, datafeed.getScriptFields()))
&& (chunkingConfig == null || Objects.equals(chunkingConfig, datafeed.getChunkingConfig()));
}

public static class Builder {

private String id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,29 +144,29 @@ private AnalysisConfig(TimeValue bucketSpan, String categorizationFieldName, Lis
this.latency = latency;
this.categorizationFieldName = categorizationFieldName;
this.categorizationAnalyzerConfig = categorizationAnalyzerConfig;
this.categorizationFilters = categorizationFilters;
this.categorizationFilters = categorizationFilters == null ? null : Collections.unmodifiableList(categorizationFilters);
this.summaryCountFieldName = summaryCountFieldName;
this.influencers = influencers;
this.influencers = Collections.unmodifiableList(influencers);
this.overlappingBuckets = overlappingBuckets;
this.resultFinalizationWindow = resultFinalizationWindow;
this.multivariateByFields = multivariateByFields;
this.multipleBucketSpans = multipleBucketSpans;
this.multipleBucketSpans = multipleBucketSpans == null ? null : Collections.unmodifiableList(multipleBucketSpans);
this.usePerPartitionNormalization = usePerPartitionNormalization;
}

public AnalysisConfig(StreamInput in) throws IOException {
bucketSpan = in.readTimeValue();
categorizationFieldName = in.readOptionalString();
categorizationFilters = in.readBoolean() ? in.readList(StreamInput::readString) : null;
categorizationFilters = in.readBoolean() ? Collections.unmodifiableList(in.readList(StreamInput::readString)) : null;
if (in.getVersion().onOrAfter(Version.V_6_2_0)) {
categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new);
} else {
categorizationAnalyzerConfig = null;
}
latency = in.readOptionalTimeValue();
summaryCountFieldName = in.readOptionalString();
detectors = in.readList(Detector::new);
influencers = in.readList(StreamInput::readString);
detectors = Collections.unmodifiableList(in.readList(Detector::new));
influencers = Collections.unmodifiableList(in.readList(StreamInput::readString));
overlappingBuckets = in.readOptionalBoolean();
resultFinalizationWindow = in.readOptionalLong();
multivariateByFields = in.readOptionalBoolean();
Expand All @@ -176,7 +176,7 @@ public AnalysisConfig(StreamInput in) throws IOException {
for (int i = 0; i < arraySize; i++) {
spans.add(in.readTimeValue());
}
multipleBucketSpans = spans;
multipleBucketSpans = Collections.unmodifiableList(spans);
} else {
multipleBucketSpans = null;
}
Expand Down Expand Up @@ -487,18 +487,20 @@ public Builder(List<Detector> detectors) {
}

public Builder(AnalysisConfig analysisConfig) {
this.detectors = analysisConfig.detectors;
this.detectors = new ArrayList<>(analysisConfig.detectors);
this.bucketSpan = analysisConfig.bucketSpan;
this.latency = analysisConfig.latency;
this.categorizationFieldName = analysisConfig.categorizationFieldName;
this.categorizationFilters = analysisConfig.categorizationFilters;
this.categorizationFilters = analysisConfig.categorizationFilters == null ? null
: new ArrayList<>(analysisConfig.categorizationFilters);
this.categorizationAnalyzerConfig = analysisConfig.categorizationAnalyzerConfig;
this.summaryCountFieldName = analysisConfig.summaryCountFieldName;
this.influencers = analysisConfig.influencers;
this.influencers = new ArrayList<>(analysisConfig.influencers);
this.overlappingBuckets = analysisConfig.overlappingBuckets;
this.resultFinalizationWindow = analysisConfig.resultFinalizationWindow;
this.multivariateByFields = analysisConfig.multivariateByFields;
this.multipleBucketSpans = analysisConfig.multipleBucketSpans;
this.multipleBucketSpans = analysisConfig.multipleBucketSpans == null ? null
: new ArrayList<>(analysisConfig.multipleBucketSpans);
this.usePerPartitionNormalization = analysisConfig.usePerPartitionNormalization;
}

Expand All @@ -518,6 +520,10 @@ public void setDetectors(List<Detector> detectors) {
this.detectors = sequentialIndexDetectors;
}

public void setDetector(int detectorIndex, Detector detector) {
detectors.set(detectorIndex, detector);
}

public void setBucketSpan(TimeValue bucketSpan) {
this.bucketSpan = bucketSpan;
}
Expand All @@ -543,7 +549,7 @@ public void setSummaryCountFieldName(String summaryCountFieldName) {
}

public void setInfluencers(List<String> influencers) {
this.influencers = influencers;
this.influencers = ExceptionsHelper.requireNonNull(influencers, INFLUENCERS.getPreferredName());
}

public void setOverlappingBuckets(Boolean overlappingBuckets) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ public Detector(StreamInput in) throws IOException {
partitionFieldName = in.readOptionalString();
useNull = in.readBoolean();
excludeFrequent = in.readBoolean() ? ExcludeFrequent.readFromStream(in) : null;
rules = in.readList(DetectionRule::new);
rules = Collections.unmodifiableList(in.readList(DetectionRule::new));
if (in.getVersion().onOrAfter(Version.V_5_5_0)) {
detectorIndex = in.readInt();
} else {
Expand Down Expand Up @@ -508,7 +508,7 @@ public Builder(Detector detector) {
partitionFieldName = detector.partitionFieldName;
useNull = detector.useNull;
excludeFrequent = detector.excludeFrequent;
rules = new ArrayList<>(detector.getRules());
rules = new ArrayList<>(detector.rules);
detectorIndex = detector.detectorIndex;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ private Job(String jobId, String jobType, Version jobVersion, List<String> group
this.jobId = jobId;
this.jobType = jobType;
this.jobVersion = jobVersion;
this.groups = groups;
this.groups = Collections.unmodifiableList(groups);
this.description = description;
this.createTime = createTime;
this.finishedTime = finishedTime;
Expand All @@ -207,7 +207,7 @@ private Job(String jobId, String jobType, Version jobVersion, List<String> group
this.backgroundPersistInterval = backgroundPersistInterval;
this.modelSnapshotRetentionDays = modelSnapshotRetentionDays;
this.resultsRetentionDays = resultsRetentionDays;
this.customSettings = customSettings;
this.customSettings = customSettings == null ? null : Collections.unmodifiableMap(customSettings);
this.modelSnapshotId = modelSnapshotId;
this.modelSnapshotMinVersion = modelSnapshotMinVersion;
this.resultsIndexName = resultsIndexName;
Expand All @@ -223,7 +223,7 @@ public Job(StreamInput in) throws IOException {
jobVersion = null;
}
if (in.getVersion().onOrAfter(Version.V_6_1_0)) {
groups = in.readList(StreamInput::readString);
groups = Collections.unmodifiableList(in.readList(StreamInput::readString));
} else {
groups = Collections.emptyList();
}
Expand All @@ -244,7 +244,8 @@ public Job(StreamInput in) throws IOException {
backgroundPersistInterval = in.readOptionalTimeValue();
modelSnapshotRetentionDays = in.readOptionalLong();
resultsRetentionDays = in.readOptionalLong();
customSettings = in.readMap();
Map<String, Object> readCustomSettings = in.readMap();
customSettings = readCustomSettings == null ? null : Collections.unmodifiableMap(readCustomSettings);
modelSnapshotId = in.readOptionalString();
if (in.getVersion().onOrAfter(Version.V_7_0_0_alpha1) && in.readBoolean()) {
modelSnapshotMinVersion = Version.readVersion(in);
Expand Down Expand Up @@ -627,7 +628,8 @@ public boolean equals(Object other) {
&& Objects.equals(this.lastDataTime, that.lastDataTime)
&& Objects.equals(this.establishedModelMemory, that.establishedModelMemory)
&& Objects.equals(this.analysisConfig, that.analysisConfig)
&& Objects.equals(this.analysisLimits, that.analysisLimits) && Objects.equals(this.dataDescription, that.dataDescription)
&& Objects.equals(this.analysisLimits, that.analysisLimits)
&& Objects.equals(this.dataDescription, that.dataDescription)
&& Objects.equals(this.modelPlotConfig, that.modelPlotConfig)
&& Objects.equals(this.renormalizationWindowDays, that.renormalizationWindowDays)
&& Objects.equals(this.backgroundPersistInterval, that.backgroundPersistInterval)
Expand Down Expand Up @@ -1055,6 +1057,7 @@ public boolean equals(Object o) {
return Objects.equals(this.id, that.id)
&& Objects.equals(this.jobType, that.jobType)
&& Objects.equals(this.jobVersion, that.jobVersion)
&& Objects.equals(this.groups, that.groups)
&& Objects.equals(this.description, that.description)
&& Objects.equals(this.analysisConfig, that.analysisConfig)
&& Objects.equals(this.analysisLimits, that.analysisLimits)
Expand All @@ -1077,7 +1080,7 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(id, jobType, jobVersion, description, analysisConfig, analysisLimits, dataDescription, createTime,
return Objects.hash(id, jobType, jobVersion, groups, description, analysisConfig, analysisLimits, dataDescription, createTime,
finishedTime, lastDataTime, establishedModelMemory, modelPlotConfig, renormalizationWindowDays,
backgroundPersistInterval, modelSnapshotRetentionDays, resultsRetentionDays, customSettings, modelSnapshotId,
modelSnapshotMinVersion, resultsIndexName, deleted);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,33 +373,33 @@ public Set<String> getUpdateFields() {
*/
public Job mergeWithJob(Job source, ByteSizeValue maxModelMemoryLimit) {
Job.Builder builder = new Job.Builder(source);
AnalysisConfig currentAnalysisConfig = source.getAnalysisConfig();
AnalysisConfig.Builder newAnalysisConfig = new AnalysisConfig.Builder(currentAnalysisConfig);

if (groups != null) {
builder.setGroups(groups);
}
if (description != null) {
builder.setDescription(description);
}
if (detectorUpdates != null && detectorUpdates.isEmpty() == false) {
AnalysisConfig ac = source.getAnalysisConfig();
int numDetectors = ac.getDetectors().size();
int numDetectors = currentAnalysisConfig.getDetectors().size();
for (DetectorUpdate dd : detectorUpdates) {
if (dd.getDetectorIndex() >= numDetectors) {
throw ExceptionsHelper.badRequestException("Supplied detector_index [{}] is >= the number of detectors [{}]",
dd.getDetectorIndex(), numDetectors);
}

Detector.Builder detectorbuilder = new Detector.Builder(ac.getDetectors().get(dd.getDetectorIndex()));
Detector.Builder detectorBuilder = new Detector.Builder(currentAnalysisConfig.getDetectors().get(dd.getDetectorIndex()));
if (dd.getDescription() != null) {
detectorbuilder.setDetectorDescription(dd.getDescription());
detectorBuilder.setDetectorDescription(dd.getDescription());
}
if (dd.getRules() != null) {
detectorbuilder.setRules(dd.getRules());
detectorBuilder.setRules(dd.getRules());
}
ac.getDetectors().set(dd.getDetectorIndex(), detectorbuilder.build());
}

AnalysisConfig.Builder acBuilder = new AnalysisConfig.Builder(ac);
builder.setAnalysisConfig(acBuilder);
newAnalysisConfig.setDetector(dd.getDetectorIndex(), detectorBuilder.build());
}
}
if (modelPlotConfig != null) {
builder.setModelPlotConfig(modelPlotConfig);
Expand All @@ -422,9 +422,7 @@ public Job mergeWithJob(Job source, ByteSizeValue maxModelMemoryLimit) {
builder.setResultsRetentionDays(resultsRetentionDays);
}
if (categorizationFilters != null) {
AnalysisConfig.Builder analysisConfigBuilder = new AnalysisConfig.Builder(source.getAnalysisConfig());
analysisConfigBuilder.setCategorizationFilters(categorizationFilters);
builder.setAnalysisConfig(analysisConfigBuilder);
newAnalysisConfig.setCategorizationFilters(categorizationFilters);
}
if (customSettings != null) {
builder.setCustomSettings(customSettings);
Expand All @@ -446,9 +444,48 @@ public Job mergeWithJob(Job source, ByteSizeValue maxModelMemoryLimit) {
if (jobVersion != null) {
builder.setJobVersion(jobVersion);
}

builder.setAnalysisConfig(newAnalysisConfig);
return builder.build();
}

boolean isNoop(Job job) {
return (groups == null || Objects.equals(groups, job.getGroups()))
&& (description == null || Objects.equals(description, job.getDescription()))
&& (modelPlotConfig == null || Objects.equals(modelPlotConfig, job.getModelPlotConfig()))
&& (analysisLimits == null || Objects.equals(analysisLimits, job.getAnalysisLimits()))
&& updatesDetectors(job) == false
&& (renormalizationWindowDays == null || Objects.equals(renormalizationWindowDays, job.getRenormalizationWindowDays()))
&& (backgroundPersistInterval == null || Objects.equals(backgroundPersistInterval, job.getBackgroundPersistInterval()))
&& (modelSnapshotRetentionDays == null || Objects.equals(modelSnapshotRetentionDays, job.getModelSnapshotRetentionDays()))
&& (resultsRetentionDays == null || Objects.equals(resultsRetentionDays, job.getResultsRetentionDays()))
&& (categorizationFilters == null
|| Objects.equals(categorizationFilters, job.getAnalysisConfig().getCategorizationFilters()))
&& (customSettings == null || Objects.equals(customSettings, job.getCustomSettings()))
&& (modelSnapshotId == null || Objects.equals(modelSnapshotId, job.getModelSnapshotId()))
&& (modelSnapshotMinVersion == null || Objects.equals(modelSnapshotMinVersion, job.getModelSnapshotMinVersion()))
&& (establishedModelMemory == null || Objects.equals(establishedModelMemory, job.getEstablishedModelMemory()))
&& (jobVersion == null || Objects.equals(jobVersion, job.getJobVersion()));
}

boolean updatesDetectors(Job job) {
AnalysisConfig analysisConfig = job.getAnalysisConfig();
if (detectorUpdates == null) {
return false;
}
for (DetectorUpdate detectorUpdate : detectorUpdates) {
if (detectorUpdate.description == null && detectorUpdate.rules == null) {
continue;
}
Detector detector = analysisConfig.getDetectors().get(detectorUpdate.detectorIndex);
if (Objects.equals(detectorUpdate.description, detector.getDetectorDescription()) == false
|| Objects.equals(detectorUpdate.rules, detector.getRules()) == false) {
return true;
}
}
return false;
}

@Override
public boolean equals(Object other) {
if (this == other) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public RuleScope() {
}

public RuleScope(Map<String, FilterRef> scope) {
this.scope = Objects.requireNonNull(scope);
this.scope = Collections.unmodifiableMap(scope);
}

public RuleScope(StreamInput in) throws IOException {
Expand Down
Loading

0 comments on commit 2cfe703

Please sign in to comment.