Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.temporal.internal.worker.WorkflowExecutorCache;
import io.temporal.workflow.CancellationScope;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -90,4 +91,17 @@ static DeterministicRunner newRunner(
/** Creates a new instance of a workflow callback thread. */
@Nonnull
WorkflowThread newCallbackThread(Runnable runnable, @Nullable String name);

/**
* Retrieve data from runner locals. Returns 1. not found (an empty Optional) 2. found but null
* (an Optional of an empty Optional) 3. found and non-null (an Optional of an Optional of a
* value). The type nesting is because Java Optionals cannot understand "Some null" vs "None",
* which is exactly what we need here.
*
* @param key
* @return one of three cases
* @param <T>
*/
@SuppressWarnings("unchecked")
<T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ private boolean areThreadsToBeExecuted() {
* @param <T>
*/
@SuppressWarnings("unchecked")
<T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key) {
public <T> Optional<Optional<T>> getRunnerLocal(RunnerLocalInternal<T> key) {
if (!runnerLocalMap.containsKey(key)) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,33 @@ class QueryDispatcher {

private DynamicQueryHandler dynamicQueryHandler;
private WorkflowInboundCallsInterceptor inboundCallsInterceptor;
private static final ThreadLocal<SyncWorkflowContext> queryHandlerWorkflowContext =
new ThreadLocal<>();

public QueryDispatcher(DataConverter dataConverterWithWorkflowContext) {
this.dataConverterWithWorkflowContext = dataConverterWithWorkflowContext;
}

/**
* @return True if the current thread is executing a query handler.
*/
public static boolean isQueryHandler() {
SyncWorkflowContext value = queryHandlerWorkflowContext.get();
return value != null;
}

/**
* @return The current workflow context if the current thread is executing a query handler.
* @throws IllegalStateException if not in a query handler.
*/
public static SyncWorkflowContext getWorkflowContext() {
SyncWorkflowContext value = queryHandlerWorkflowContext.get();
if (value == null) {
throw new IllegalStateException("Not in a query handler");
}
return value;
}

public void setInboundCallsInterceptor(WorkflowInboundCallsInterceptor inboundCallsInterceptor) {
this.inboundCallsInterceptor = inboundCallsInterceptor;
}
Expand All @@ -51,7 +73,11 @@ public WorkflowInboundCallsInterceptor.QueryOutput handleInterceptedQuery(
return new WorkflowInboundCallsInterceptor.QueryOutput(result);
}

public Optional<Payloads> handleQuery(String queryName, Header header, Optional<Payloads> input) {
public Optional<Payloads> handleQuery(
SyncWorkflowContext replayContext,
String queryName,
Header header,
Optional<Payloads> input) {
WorkflowOutboundCallsInterceptor.RegisterQueryInput handler = queryCallbacks.get(queryName);
Object[] args;
if (queryName.startsWith(TEMPORAL_RESERVED_PREFIX)) {
Expand All @@ -69,11 +95,18 @@ public Optional<Payloads> handleQuery(String queryName, Header header, Optional<
dataConverterWithWorkflowContext.fromPayloads(
input, handler.getArgTypes(), handler.getGenericArgTypes());
}
Object result =
inboundCallsInterceptor
.handleQuery(new WorkflowInboundCallsInterceptor.QueryInput(queryName, header, args))
.getResult();
return dataConverterWithWorkflowContext.toPayloads(result);
try {
replayContext.setReadOnly(true);
queryHandlerWorkflowContext.set(replayContext);
Object result =
inboundCallsInterceptor
.handleQuery(new WorkflowInboundCallsInterceptor.QueryInput(queryName, header, args))
.getResult();
return dataConverterWithWorkflowContext.toPayloads(result);
} finally {
replayContext.setReadOnly(false);
queryHandlerWorkflowContext.set(null);
}
}

public void registerQueryHandlers(WorkflowOutboundCallsInterceptor.RegisterQueryInput request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ public RunnerLocalInternal(boolean useCaching) {
}

public T get(Supplier<? extends T> supplier) {
Optional<Optional<T>> result =
DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
Optional<Optional<T>> result;
// Query handlers are special in that they are executing in a different context
// than the main workflow execution threads. We need to fetch the runner local from the
// correct context based on whether we are in a query handler or not.
if (QueryDispatcher.isQueryHandler()) {
result = QueryDispatcher.getWorkflowContext().getRunner().getRunnerLocal(this);
} else {
result = DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
}
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
if (!result.isPresent() && useCaching) {
if (!result.isPresent() && useCaching && !QueryDispatcher.isQueryHandler()) {
// This is the first time we've tried fetching this, and caching is enabled. Store it.
set(out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ public WorkflowInboundCallsInterceptor.QueryOutput handleInterceptedQuery(
}

public Optional<Payloads> handleQuery(String queryName, Header header, Optional<Payloads> input) {
return queryDispatcher.handleQuery(queryName, header, input);
return queryDispatcher.handleQuery(this, queryName, header, input);
}

public boolean isEveryHandlerFinished() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,13 @@ static WorkflowOutboundCallsInterceptor getWorkflowOutboundInterceptor() {
}

static SyncWorkflowContext getRootWorkflowContext() {
// If we are in a query handler, we need to get the workflow context from the
// QueryDispatcher, otherwise we get it from the current thread's internal context.
// This is necessary because query handlers run in a different context than the main workflow
// threads.
if (QueryDispatcher.isQueryHandler()) {
return QueryDispatcher.getWorkflowContext();
}
return DeterministicRunnerImpl.currentThreadInternal().getWorkflowContext();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public void testQuerySuccess() {

// Invoke functionality under test, expect no exceptions for an existing query.
Optional<Payloads> queryResult =
dispatcher.handleQuery("QueryB", Header.empty(), Optional.empty());
dispatcher.handleQuery(
mock(SyncWorkflowContext.class), "QueryB", Header.empty(), Optional.empty());
assertTrue(queryResult.isPresent());
}

Expand All @@ -61,7 +62,8 @@ public void testQueryDispatcherException() {
assertThrows(
IllegalArgumentException.class,
() -> {
dispatcher.handleQuery("QueryC", Header.empty(), null);
dispatcher.handleQuery(
mock(SyncWorkflowContext.class), "QueryC", Header.empty(), null);
});
assertEquals("Unknown query type: QueryC, knownTypes=[QueryA, QueryB]", exception.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.temporal.workflow.queryTests;

import static org.junit.Assert.assertEquals;

import io.temporal.client.WorkflowClient;
import io.temporal.client.WorkflowStub;
import io.temporal.testing.internal.SDKTestWorkflowRule;
import io.temporal.workflow.Workflow;
import io.temporal.workflow.WorkflowInfo;
import io.temporal.workflow.WorkflowLocal;
import io.temporal.workflow.shared.TestWorkflows;
import java.time.Duration;
import org.junit.Rule;
import org.junit.Test;

public class WorkflowInfoAndLocalInQueryTest {

@Rule
public SDKTestWorkflowRule testWorkflowRule =
SDKTestWorkflowRule.newBuilder().setWorkflowTypes(TestWorkflow.class).build();

@Test
public void queryReturnsInfoAndLocal() {
TestWorkflows.TestWorkflowWithQuery workflowStub =
testWorkflowRule.newWorkflowStub(TestWorkflows.TestWorkflowWithQuery.class);
WorkflowClient.start(workflowStub::execute);

assertEquals("attempt=1 local=42", workflowStub.query());
assertEquals("done", WorkflowStub.fromTyped(workflowStub).getResult(String.class));
}

public static class TestWorkflow implements TestWorkflows.TestWorkflowWithQuery {

private final WorkflowLocal<Integer> local = WorkflowLocal.withCachedInitial(() -> 0);

@Override
public String execute() {
local.set(42);
Workflow.sleep(Duration.ofSeconds(1));
return "done";
}

@Override
public String query() {
WorkflowInfo info = Workflow.getInfo();
return "attempt=" + info.getAttempt() + " local=" + local.get();
}
}
}
Loading