Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.commons.lang.StringUtils;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
import org.apache.pinot.query.routing.MailboxMetadata;
Expand All @@ -42,8 +41,8 @@
* This utility class serialize/deserialize between {@link Worker.StagePlan} elements to Planner elements.
*/
public class QueryPlanSerDeUtils {
private static final Pattern VIRTUAL_SERVER_PATTERN = Pattern.compile(
"(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");
private static final Pattern VIRTUAL_SERVER_PATTERN =
Pattern.compile("(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");

private QueryPlanSerDeUtils() {
// do not instantiate.
Expand All @@ -57,18 +56,6 @@ public static List<DistributedStagePlan> deserializeStagePlan(Worker.QueryReques
return distributedStagePlans;
}

public static Worker.StagePlan serialize(DispatchableSubPlan dispatchableSubPlan, int stageId,
QueryServerInstance queryServerInstance, List<Integer> workerIds) {
return Worker.StagePlan.newBuilder()
.setStageId(stageId)
.setStageRoot(StageNodeSerDeUtils.serializeStageNode(
(AbstractPlanNode) dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment()
.getFragmentRoot()))
.setStageMetadata(
toProtoStageMetadata(dispatchableSubPlan.getQueryStageList().get(stageId), queryServerInstance, workerIds))
.build();
}

public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
if (!matcher.matches()) {
Expand All @@ -78,8 +65,8 @@ public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
}

// Skipped netty and grpc port as they are not used in worker instance.
return new VirtualServerAddress(matcher.group("host"),
Integer.parseInt(matcher.group("port")), Integer.parseInt(matcher.group("virtualid")));
return new VirtualServerAddress(matcher.group("host"), Integer.parseInt(matcher.group("port")),
Integer.parseInt(matcher.group("virtualid")));
}

public static String addressToProto(VirtualServerAddress serverAddress) {
Expand Down Expand Up @@ -145,17 +132,21 @@ private static MailboxMetadata fromProtoMailbox(Worker.MailboxMetadata protoMail
return mailboxMetadata;
}

private static Worker.StageMetadata toProtoStageMetadata(DispatchablePlanFragment planFragment,
QueryServerInstance queryServerInstance, List<Integer> workerIds) {
Worker.StageMetadata.Builder builder = Worker.StageMetadata.newBuilder();
for (WorkerMetadata workerMetadata : planFragment.getWorkerMetadataList()) {
builder.addWorkerMetadata(toProtoWorkerMetadata(workerMetadata));
/**
 * Assembles the proto {@link Worker.StageMetadata} for a single stage as it is sent to a single server.
 *
 * @param workerMetadataList proto worker metadata entries added to the stage metadata
 * @param customProperties custom properties copied into the stage metadata
 * @param serverInstance server whose hostname and query mailbox port form the {@code host:port} server address
 * @param workerIds worker ids added to the stage metadata
 * @return the built stage metadata
 */
public static Worker.StageMetadata toProtoStageMetadata(List<Worker.WorkerMetadata> workerMetadataList,
    Map<String, String> customProperties, QueryServerInstance serverInstance, List<Integer> workerIds) {
  Worker.StageMetadata.Builder stageMetadataBuilder = Worker.StageMetadata.newBuilder();
  stageMetadataBuilder.addAllWorkerMetadata(workerMetadataList);
  stageMetadataBuilder.putAllCustomProperty(customProperties);
  // The server address is encoded as "<hostname>:<queryMailboxPort>".
  String serverAddress = String.format("%s:%d", serverInstance.getHostname(), serverInstance.getQueryMailboxPort());
  stageMetadataBuilder.setServerAddress(serverAddress);
  stageMetadataBuilder.addAllWorkerIds(workerIds);
  return stageMetadataBuilder.build();
}

public static List<Worker.WorkerMetadata> toProtoWorkerMetadataList(DispatchablePlanFragment planFragment) {
List<WorkerMetadata> workerMetadataList = planFragment.getWorkerMetadataList();
List<Worker.WorkerMetadata> protoWorkerMetadataList = new ArrayList<>(workerMetadataList.size());
for (WorkerMetadata workerMetadata : workerMetadataList) {
protoWorkerMetadataList.add(toProtoWorkerMetadata(workerMetadata));
}
builder.putAllCustomProperty(planFragment.getCustomProperties());
builder.setServerAddress(String.format("%s:%d", queryServerInstance.getHostname(),
queryServerInstance.getQueryMailboxPort()));
builder.addAllWorkerIds(workerIds);
return builder.build();
return protoWorkerMetadataList;
}

private static Worker.WorkerMetadata toProtoWorkerMetadata(WorkerMetadata workerMetadata) {
Expand All @@ -166,8 +157,7 @@ private static Worker.WorkerMetadata toProtoWorkerMetadata(WorkerMetadata worker
return builder.build();
}

private static Map<Integer, Worker.MailboxMetadata> toProtoMailboxMap(
Map<Integer, MailboxMetadata> mailBoxInfosMap) {
private static Map<Integer, Worker.MailboxMetadata> toProtoMailboxMap(Map<Integer, MailboxMetadata> mailBoxInfosMap) {
Map<Integer, Worker.MailboxMetadata> mailboxMetadataMap = new HashMap<>();
for (Map.Entry<Integer, MailboxMetadata> entry : mailBoxInfosMap.entrySet()) {
mailboxMetadataMap.put(entry.getKey(), toProtoMailbox(entry.getValue()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,22 @@
* {@link #getThrowable()} to check if it is null.
*/
class AsyncQueryDispatchResponse {
private final QueryServerInstance _virtualServer;
private final int _stageId;
private final QueryServerInstance _serverInstance;
private final Worker.QueryResponse _queryResponse;
private final Throwable _throwable;

public AsyncQueryDispatchResponse(QueryServerInstance virtualServer, int stageId, Worker.QueryResponse queryResponse,
public AsyncQueryDispatchResponse(QueryServerInstance serverInstance, @Nullable Worker.QueryResponse queryResponse,
@Nullable Throwable throwable) {
_virtualServer = virtualServer;
_stageId = stageId;
_serverInstance = serverInstance;
_queryResponse = queryResponse;
_throwable = throwable;
}

public QueryServerInstance getVirtualServer() {
return _virtualServer;
}

public int getStageId() {
return _stageId;
public QueryServerInstance getServerInstance() {
return _serverInstance;
}

@Nullable
public Worker.QueryResponse getQueryResponse() {
return _queryResponse;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.query.routing.QueryServerInstance;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
Expand All @@ -37,8 +35,8 @@
* let that take care of pooling. (2) Create a DispatchClient interface and implement pooled/non-pooled versions.
*/
class DispatchClient {
private static final Logger LOGGER = LoggerFactory.getLogger(DispatchClient.class);
private static final StreamObserver<Worker.CancelResponse> NO_OP_CANCEL_STREAM_OBSERVER = new CancelObserver();

private final ManagedChannel _channel;
private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub;

Expand All @@ -51,23 +49,13 @@ public ManagedChannel getChannel() {
return _channel;
}

public void submit(Worker.QueryRequest request, int stageId, QueryServerInstance virtualServer, Deadline deadline,
public void submit(Worker.QueryRequest request, QueryServerInstance virtualServer, Deadline deadline,
Consumer<AsyncQueryDispatchResponse> callback) {
try {
_dispatchStub.withDeadline(deadline).submit(request, new DispatchObserver(stageId, virtualServer, callback));
} catch (Exception e) {
LOGGER.error("Query Dispatch failed at client-side", e);
callback.accept(new AsyncQueryDispatchResponse(
virtualServer, stageId, Worker.QueryResponse.getDefaultInstance(), e));
}
_dispatchStub.withDeadline(deadline).submit(request, new DispatchObserver(virtualServer, callback));
}

public void cancel(long requestId) {
try {
Worker.CancelRequest cancelRequest = Worker.CancelRequest.newBuilder().setRequestId(requestId).build();
_dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
} catch (Exception e) {
LOGGER.error("Query Cancellation failed at client-side", e);
}
Worker.CancelRequest cancelRequest = Worker.CancelRequest.newBuilder().setRequestId(requestId).build();
_dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@
* A {@link StreamObserver} used by {@link DispatchClient} to subscribe to the response of a async Query Dispatch call.
*/
class DispatchObserver implements StreamObserver<Worker.QueryResponse> {
private int _stageId;
private QueryServerInstance _virtualServer;
private Consumer<AsyncQueryDispatchResponse> _callback;
private final QueryServerInstance _serverInstance;
private final Consumer<AsyncQueryDispatchResponse> _callback;

private Worker.QueryResponse _queryResponse;

public DispatchObserver(int stageId, QueryServerInstance virtualServer,
Consumer<AsyncQueryDispatchResponse> callback) {
_stageId = stageId;
_virtualServer = virtualServer;
public DispatchObserver(QueryServerInstance serverInstance, Consumer<AsyncQueryDispatchResponse> callback) {
_serverInstance = serverInstance;
_callback = callback;
}

Expand All @@ -48,12 +46,11 @@ public void onNext(Worker.QueryResponse queryResponse) {
@Override
public void onError(Throwable throwable) {
_callback.accept(
new AsyncQueryDispatchResponse(_virtualServer, _stageId, Worker.QueryResponse.getDefaultInstance(),
throwable));
new AsyncQueryDispatchResponse(_serverInstance, Worker.QueryResponse.getDefaultInstance(), throwable));
}

@Override
public void onCompleted() {
_callback.accept(new AsyncQueryDispatchResponse(_virtualServer, _stageId, _queryResponse, null));
_callback.accept(new AsyncQueryDispatchResponse(_serverInstance, _queryResponse, null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;
import org.apache.calcite.util.Pair;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.proto.Plan;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.common.response.broker.ResultTable;
import org.apache.pinot.common.utils.DataSchema;
Expand All @@ -48,8 +49,10 @@
import org.apache.pinot.query.planner.PlanFragment;
import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
import org.apache.pinot.query.routing.QueryServerInstance;
import org.apache.pinot.query.routing.WorkerMetadata;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
Expand Down Expand Up @@ -107,50 +110,76 @@ public ResultTable submitAndReduce(RequestContext context, DispatchableSubPlan d
void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeoutMs, Map<String, String> queryOptions)
throws Exception {
Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new LinkedBlockingQueue<>();
List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
int numDispatchCalls = 0;
// Do not submit the reduce stage (stage 0)
Set<QueryServerInstance> serverInstances = new HashSet<>();
// TODO: If serialization is slow, consider serializing each stage in parallel
StageInfo[] stageInfoMap = new StageInfo[numStages];
// Ignore the reduce stage (stage 0)
for (int stageId = 1; stageId < numStages; stageId++) {
for (Map.Entry<QueryServerInstance, List<Integer>> entry : stagePlans.get(stageId)
.getServerInstanceToWorkerIdMap().entrySet()) {
QueryServerInstance queryServerInstance = entry.getKey();
Worker.QueryRequest.Builder queryRequestBuilder = Worker.QueryRequest.newBuilder();
queryRequestBuilder.addStagePlan(
QueryPlanSerDeUtils.serialize(dispatchableSubPlan, stageId, queryServerInstance, entry.getValue()));
Worker.QueryRequest queryRequest =
queryRequestBuilder.putMetadata(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID,
String.valueOf(requestId))
.putMetadata(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, String.valueOf(timeoutMs))
.putAllMetadata(queryOptions).build();
DispatchClient client = getOrCreateDispatchClient(queryServerInstance);
int finalStageId = stageId;
_executorService.submit(
() -> client.submit(queryRequest, finalStageId, queryServerInstance, deadline, dispatchCallbacks::offer));
numDispatchCalls++;
}
DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
Plan.StageNode rootNode =
StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot());
List<Worker.WorkerMetadata> workerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan);
stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, stagePlan.getCustomProperties());
}
Map<String, String> requestMetadata = new HashMap<>();
requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, Long.toString(requestId));
requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, Long.toString(timeoutMs));
requestMetadata.putAll(queryOptions);

// Submit the query plan to all servers in parallel
int numServers = serverInstances.size();
BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new ArrayBlockingQueue<>(numServers);
for (QueryServerInstance serverInstance : serverInstances) {
_executorService.submit(() -> {
Comment on lines +135 to +136
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that QueryServer could also apply a similar technique.

try {
distributedStagePlans = QueryPlanSerDeUtils.deserializeStagePlan(request);
} catch (Exception e) {
LOGGER.error("Caught exception while deserializing the request: {}", requestId, e);
responseObserver.onError(Status.INVALID_ARGUMENT.withDescription("Bad request").withCause(e).asException());
return;
}
// 2. Submit distributed stage plans, await response successful or any failure which cancels all other tasks.
int numSubmission = distributedStagePlans.size();
CompletableFuture<?>[] submissionStubs = new CompletableFuture[numSubmission];
for (int i = 0; i < numSubmission; i++) {
DistributedStagePlan distributedStagePlan = distributedStagePlans.get(i);
submissionStubs[i] =
CompletableFuture.runAsync(() -> _queryRunner.processQuery(distributedStagePlan, requestMetadata),
_querySubmissionExecutorService);
}

Deserialization (line 109) can be moved into the runAsync (line 121).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Will do it as a separate PR

try {
Worker.QueryRequest.Builder requestBuilder = Worker.QueryRequest.newBuilder();
for (int stageId = 1; stageId < numStages; stageId++) {
List<Integer> workerIds = stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance);
if (workerIds != null) {
StageInfo stageInfo = stageInfoMap[stageId];
Worker.StageMetadata stageMetadata =
QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, stageInfo._customProperties,
serverInstance, workerIds);
Worker.StagePlan stagePlan =
Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode)
.setStageMetadata(stageMetadata).build();
requestBuilder.addStagePlan(stagePlan);
}
}
requestBuilder.putAllMetadata(requestMetadata);
getOrCreateDispatchClient(serverInstance).submit(requestBuilder.build(), serverInstance, deadline,
dispatchCallbacks::offer);
} catch (Throwable t) {
LOGGER.warn("Caught exception while dispatching query: {} to server: {}", requestId, serverInstance, t);
dispatchCallbacks.offer(new AsyncQueryDispatchResponse(serverInstance, null, t));
}
});
}
int successfulDispatchCalls = 0;

int numSuccessCalls = 0;
// TODO: Cancel all dispatched requests if one of the dispatch errors out or deadline is breached.
while (!deadline.isExpired() && successfulDispatchCalls < numDispatchCalls) {
while (!deadline.isExpired() && numSuccessCalls < numServers) {
AsyncQueryDispatchResponse resp =
dispatchCallbacks.poll(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS);
if (resp != null) {
if (resp.getThrowable() != null) {
throw new RuntimeException(
String.format("Error dispatching query to server=%s stage=%s", resp.getVirtualServer(),
resp.getStageId()), resp.getThrowable());
String.format("Error dispatching query: %d to server: %s", requestId, resp.getServerInstance()),
resp.getThrowable());
} else {
Worker.QueryResponse response = resp.getQueryResponse();
assert response != null;
if (response.containsMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR)) {
throw new RuntimeException(
String.format("Unable to execute query plan at stage %s on server %s: ERROR: %s", resp.getStageId(),
resp.getVirtualServer(),
String.format("Unable to execute query plan for request: %d on server: %s, ERROR: %s", requestId,
resp.getServerInstance(),
response.getMetadataOrDefault(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR,
"null")));
}
successfulDispatchCalls++;
numSuccessCalls++;
}
}
}
Expand All @@ -159,6 +188,19 @@ void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeou
}
}

/**
 * Holder for the per-stage pieces that {@code submit} prepares once per stage and then reads while assembling the
 * {@link Worker.StagePlan} for each server: the serialized root node, the proto worker metadata list and the
 * stage's custom properties.
 *
 * <p>The constructor stores the references it is given as-is; no copies are made.
 */
private static class StageInfo {
// Serialized root plan node of the stage.
final Plan.StageNode _rootNode;
// Proto worker metadata entries of the stage.
final List<Worker.WorkerMetadata> _workerMetadataList;
// Custom properties of the stage.
final Map<String, String> _customProperties;

/**
 * @param rootNode serialized root plan node of the stage
 * @param workerMetadataList proto worker metadata entries of the stage
 * @param customProperties custom properties of the stage
 */
StageInfo(Plan.StageNode rootNode, List<Worker.WorkerMetadata> workerMetadataList,
Map<String, String> customProperties) {
_rootNode = rootNode;
_workerMetadataList = workerMetadataList;
_customProperties = customProperties;
}
}

private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
int numStages = stagePlans.size();
Expand All @@ -168,7 +210,11 @@ private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
serversToCancel.addAll(stagePlans.get(stageId).getServerInstanceToWorkerIdMap().keySet());
}
for (QueryServerInstance queryServerInstance : serversToCancel) {
getOrCreateDispatchClient(queryServerInstance).cancel(requestId);
try {
getOrCreateDispatchClient(queryServerInstance).cancel(requestId);
} catch (Throwable t) {
LOGGER.warn("Caught exception while cancelling query: {} on server: {}", requestId, queryServerInstance, t);
}
}
}

Expand Down
Loading