Skip to content

Commit

Permalink
Enhance profile API to add model centric result controled by view par…
Browse files Browse the repository at this point in the history
…ameter

Signed-off-by: Zan Niu <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jan 28, 2023
1 parent 856498e commit f3a2888
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@

package org.opensearch.ml.action.profile;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
Expand All @@ -18,10 +23,6 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.profile.MLModelProfile;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

@Getter
@NoArgsConstructor
public class MLProfileModelResponse implements ToXContentFragment, Writeable {
Expand Down Expand Up @@ -55,27 +56,27 @@ public MLProfileModelResponse(StreamInput in) throws IOException {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (targetWorkerNodes != null) {
builder.field("target_worker_nodes", targetWorkerNodes);
builder.field("target_worker_nodes", targetWorkerNodes);
}
if (workerNodes != null) {
builder.field("worker_nodes", workerNodes);
}
if (mlModelProfileMap.size() > 0) {
builder.startObject("nodes");
for (Map.Entry<String, MLModelProfile> entry : mlModelProfileMap.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
builder.startObject("nodes");
for (Map.Entry<String, MLModelProfile> entry : mlModelProfileMap.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
if (mlTaskMap.size() > 0) {
builder.startObject("tasks");
for (Map.Entry<String, MLTask> entry : mlTaskMap.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
builder.startObject("tasks");
for (Map.Entry<String, MLTask> entry : mlTaskMap.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
builder.endObject();
return builder;
builder.endObject();
return builder;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.Optional;
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableMap;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
Expand All @@ -44,6 +43,7 @@
import org.opensearch.rest.RestStatus;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

@Log4j2
public class RestMLProfileAction extends BaseRestHandler {
Expand Down Expand Up @@ -144,18 +144,25 @@ private Map<String, MLProfileModelResponse> buildModelCentricResult(List<MLProfi
for (Map.Entry<String, MLModelProfile> entry : modelProfileMap.entrySet()) {
MLProfileModelResponse mlProfileModelResponse = modelCentricMap.get(entry.getKey());
if (mlProfileModelResponse == null) {
mlProfileModelResponse = new MLProfileModelResponse(entry.getValue().getTargetWorkerNodes(),
entry.getValue().getWorkerNodes());
mlProfileModelResponse = new MLProfileModelResponse(
entry.getValue().getTargetWorkerNodes(),
entry.getValue().getWorkerNodes()
);
modelCentricMap.put(entry.getKey(), mlProfileModelResponse);
}
if (mlProfileModelResponse.getTargetWorkerNodes() == null || mlProfileModelResponse.getWorkerNodes() == null) {
mlProfileModelResponse.setTargetWorkerNodes(entry.getValue().getTargetWorkerNodes());
mlProfileModelResponse.setWorkerNodes(entry.getValue().getWorkerNodes());
}
// Create a new object and remove targetWorkerNodes and workerNodes.
MLModelProfile modelProfile = new MLModelProfile(entry.getValue().getModelState(),
entry.getValue().getPredictor(), null, null, entry.getValue().getModelInferenceStats(),
entry.getValue().getPredictRequestStats());
MLModelProfile modelProfile = new MLModelProfile(
entry.getValue().getModelState(),
entry.getValue().getPredictor(),
null,
null,
entry.getValue().getModelInferenceStats(),
entry.getValue().getPredictRequestStats()
);
mlProfileModelResponse.getMlModelProfileMap().putAll(ImmutableMap.of(nodeId, modelProfile));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

package org.opensearch.ml.action.profile;

import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.junit.Before;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.ToXContent;
Expand All @@ -24,12 +30,6 @@
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class MLProfileModelResponseTests extends OpenSearchTestCase {

MLTask mlTask;
Expand Down Expand Up @@ -64,8 +64,8 @@ public void setup() {
}

public void test_create_MLProfileModelResponse_withArgs() throws IOException {
String[] targetWorkerNodes = new String[]{"node1", "node2"};
String[] workerNodes = new String[]{"node1"};
String[] targetWorkerNodes = new String[] { "node1", "node2" };
String[] workerNodes = new String[] { "node1" };
Map<String, MLModelProfile> profileMap = new HashMap<>();
Map<String, MLTask> taskMap = new HashMap<>();
profileMap.put("node1", mlModelProfile);
Expand Down Expand Up @@ -93,8 +93,8 @@ public void test_create_MLProfileModelResponse_NoArgs() throws IOException {
}

public void test_toXContent() throws IOException {
String[] targetWorkerNodes = new String[]{"node1", "node2"};
String[] workerNodes = new String[]{"node1"};
String[] targetWorkerNodes = new String[] { "node1", "node2" };
String[] workerNodes = new String[] { "node1" };
Map<String, MLModelProfile> profileMap = new HashMap<>();
Map<String, MLTask> taskMap = new HashMap<>();
profileMap.put("node1", mlModelProfile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.ImmutableMap;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
Expand All @@ -41,7 +40,6 @@
import org.opensearch.cluster.node.DiscoveryNodeRole;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Strings;
import org.opensearch.common.collect.MapBuilder;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.common.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -69,6 +67,7 @@
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

public class RestMLProfileActionTests extends OpenSearchTestCase {
@Rule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public static RestRequest getStatsRestRequest(MLStatsInput input) throws IOExcep
}

public static RestRequest getProfileRestRequest(MLProfileInput input) throws IOException {
return new FakeRestRequest.Builder(getXContentRegistry())
return new FakeRestRequest.Builder(getXContentRegistry())
.withContent(new BytesArray(buildRequestContent(input)), XContentType.JSON)
.build();
}
Expand Down

0 comments on commit f3a2888

Please sign in to comment.