Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ under the License.
<maven.compiler.target>${target.java.version}</maven.compiler.target>
<spotless.version>2.27.1</spotless.version>
<spotless.skip>false</spotless.skip>
<flink.version>1.20.3</flink.version>
<flink.version>2.3-SNAPSHOT</flink.version>
<kafka.version>4.0.0</kafka.version>
<junit5.version>5.10.1</junit5.version>
<flink.shaded.version>17.0</flink.shaded.version>
Expand Down
14 changes: 12 additions & 2 deletions python/flink_agents/api/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,24 @@ def stop_action(event: ChatResponseEvent, ctx: RunnerContext) -> None:
"""Stop action to output result."""
output = event.response.content
# parse llm response to target schema.
output_schema = ctx.get_action_config_value(key="output_schema")
output_schema_config = ctx.get_action_config_value(key="output_schema")

error_handling_strategy = ctx.config.get(
ReActAgentOptions.ERROR_HANDLING_STRATEGY
)
try:
output_schema = None
if output_schema_config:
# Handle both OutputSchema object and deserialized list/tuple
if hasattr(output_schema_config, "output_schema"):
output_schema = output_schema_config.output_schema
elif isinstance(output_schema_config, (list, tuple)) and len(output_schema_config) == 3:
# Deserialize from [module, class, dict] format
module = importlib.import_module(output_schema_config[0])
clazz = getattr(module, output_schema_config[1])
output_schema_obj = clazz.model_validate(output_schema_config[2])
output_schema = output_schema_obj.output_schema
if output_schema:
output_schema = output_schema.output_schema
output = json.loads(output.strip())
if isinstance(output_schema, type) and issubclass(
output_schema, BaseModel
Expand Down
Binary file not shown.
9 changes: 6 additions & 3 deletions python/flink_agents/plan/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ def __custom_deserialize(self) -> "Action":
self["config"].pop(_CONFIG_TYPE)
for name, value in config.items():
try:
module = importlib.import_module(value[0])
clazz = getattr(module, value[1])
self["config"][name] = clazz.model_validate(value[2])
if isinstance(value, (list, tuple)) and len(value) == 3:
module = importlib.import_module(value[0])
clazz = getattr(module, value[1])
self["config"][name] = clazz.model_validate(value[2])
else:
self["config"][name] = value
except Exception: # noqa : PERF203
self["config"][name] = value
return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.flink.agents.api.InputEvent;
import org.apache.flink.agents.plan.actions.Action;
import org.apache.flink.agents.runtime.python.event.PythonEvent;
import org.apache.flink.shaded.guava31.com.google.common.base.Preconditions;
import org.apache.flink.shaded.guava33.com.google.common.base.Preconditions;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;

import javax.annotation.Nonnull;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
package org.apache.flink.agents.runtime.message;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.serialization.SerializerConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
Expand Down Expand Up @@ -60,8 +60,8 @@ public boolean isKeyType() {
}

@Override
public TypeSerializer<Message> createSerializer(ExecutionConfig executionConfig) {
return new KryoSerializer<>(Message.class, executionConfig);
public TypeSerializer<Message> createSerializer(SerializerConfig serializerConfig) {
return new KryoSerializer<>(Message.class, serializerConfig);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl;
import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor;
import org.apache.flink.types.Row;
Expand Down Expand Up @@ -171,7 +173,7 @@ public ActionExecutionOperator(
this.agentPlan = agentPlan;
this.inputIsJava = inputIsJava;
this.processingTimeService = processingTimeService;
this.chainingStrategy = ChainingStrategy.ALWAYS;
// chainingStrategy is now managed by StreamOperatorFactory in Flink 2.x
this.mailboxExecutor = mailboxExecutor;
this.eventLogger = EventLoggerFactory.createLogger(EventLoggerConfig.builder().build());
this.eventListeners = new ArrayList<>();
Expand All @@ -180,6 +182,17 @@ public ActionExecutionOperator(
this.actionTaskRunnerContexts = new HashMap<>();
}

/**
* Public setup method that delegates to the protected setup in AbstractStreamOperator. This
* allows the factory to properly initialize the operator.
*/
public void setupOperator(
StreamTask<?, ?> containingTask,
StreamConfig config,
Output<StreamRecord<OUT>> output) {
setup(containingTask, config, output);
}

@Override
public void open() throws Exception {
super.open();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
import org.apache.flink.agents.plan.AgentPlan;
import org.apache.flink.agents.runtime.actionstate.ActionStateStore;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;

/** Operator factory for {@link ActionExecutionOperator}. */
public class ActionExecutionOperatorFactory<IN, OUT>
public class ActionExecutionOperatorFactory<IN, OUT> extends AbstractStreamOperatorFactory<OUT>
implements OneInputStreamOperatorFactory<IN, OUT> {

private final AgentPlan agentPlan;
Expand All @@ -45,9 +46,11 @@ protected ActionExecutionOperatorFactory(
this.agentPlan = agentPlan;
this.inputIsJava = inputIsJava;
this.actionStateStore = actionStateStore;
this.chainingStrategy = ChainingStrategy.ALWAYS;
}

@Override
@SuppressWarnings("unchecked")
public <T extends StreamOperator<OUT>> T createStreamOperator(
StreamOperatorParameters<OUT> parameters) {
ActionExecutionOperator<IN, OUT> op =
Expand All @@ -57,21 +60,13 @@ public <T extends StreamOperator<OUT>> T createStreamOperator(
parameters.getProcessingTimeService(),
parameters.getMailboxExecutor(),
actionStateStore);
op.setup(
op.setupOperator(
parameters.getContainingTask(),
parameters.getStreamConfig(),
parameters.getOutput());
return (T) op;
}

@Override
public void setChainingStrategy(ChainingStrategy chainingStrategy) {}

@Override
public ChainingStrategy getChainingStrategy() {
return ChainingStrategy.ALWAYS;
}

@Override
public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
return ActionExecutionOperator.class;
Expand Down
Loading