Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions PolyPilot.Tests/AgentModeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using System.Text.Json;
using PolyPilot.Models;

namespace PolyPilot.Tests;

public class AgentModeTests
{
[Fact]
public void SendMessagePayload_AgentMode_DefaultsToNull()
{
var payload = new SendMessagePayload { SessionName = "s1", Message = "hello" };
Assert.Null(payload.AgentMode);
}

[Theory]
[InlineData("autopilot")]
[InlineData("plan")]
[InlineData("interactive")]
[InlineData("shell")]
public void SendMessagePayload_AgentMode_RoundTrips(string mode)
{
var payload = new SendMessagePayload
{
SessionName = "test",
Message = "do something",
AgentMode = mode
};

var msg = BridgeMessage.Create(BridgeMessageTypes.SendMessage, payload);
var json = JsonSerializer.Serialize(msg);
var deserialized = JsonSerializer.Deserialize<BridgeMessage>(json);

Assert.NotNull(deserialized);
var restored = deserialized!.GetPayload<SendMessagePayload>();
Assert.NotNull(restored);
Assert.Equal(mode, restored!.AgentMode);
Assert.Equal("test", restored.SessionName);
Assert.Equal("do something", restored.Message);
}

[Fact]
public void SendMessagePayload_NullAgentMode_OmittedInJson()
{
var payload = new SendMessagePayload
{
SessionName = "s1",
Message = "hello"
};

var json = JsonSerializer.Serialize(payload);
// Null properties should still deserialize cleanly
var restored = JsonSerializer.Deserialize<SendMessagePayload>(json);
Assert.NotNull(restored);
Assert.Null(restored!.AgentMode);
}

[Fact]
public void SendMessagePayload_AgentMode_BackwardCompatible()
{
// Old clients send JSON without AgentMode field - should deserialize as null
var json = """{"SessionName":"s1","Message":"hello"}""";
var payload = JsonSerializer.Deserialize<SendMessagePayload>(json);
Assert.NotNull(payload);
Assert.Null(payload!.AgentMode);
Assert.Equal("s1", payload.SessionName);
Assert.Equal("hello", payload.Message);
}
}
4 changes: 2 additions & 2 deletions PolyPilot.Tests/TestStubs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public Task RequestSessionsAsync(CancellationToken ct = default)
return Task.CompletedTask;
}
public Task RequestHistoryAsync(string sessionName, int? limit = null, CancellationToken ct = default) => Task.CompletedTask;
public Task SendMessageAsync(string sessionName, string message, CancellationToken ct = default) => Task.CompletedTask;
public Task SendMessageAsync(string sessionName, string message, string? agentMode = null, CancellationToken ct = default) => Task.CompletedTask;
public Task CreateSessionAsync(string name, string? model = null, string? workingDirectory = null, CancellationToken ct = default) => Task.CompletedTask;
public string? LastSwitchedSession { get; private set; }
public int SwitchSessionCallCount { get; private set; }
Expand All @@ -101,7 +101,7 @@ public Task SwitchSessionAsync(string name, CancellationToken ct = default)
}

public void FireOnStateChanged() => OnStateChanged?.Invoke();
public Task QueueMessageAsync(string sessionName, string message, CancellationToken ct = default) => Task.CompletedTask;
public Task QueueMessageAsync(string sessionName, string message, string? agentMode = null, CancellationToken ct = default) => Task.CompletedTask;
public Task ResumeSessionAsync(string sessionId, string? displayName = null, CancellationToken ct = default) => Task.CompletedTask;
public Task CloseSessionAsync(string name, CancellationToken ct = default) => Task.CompletedTask;
public Task AbortSessionAsync(string sessionName, CancellationToken ct = default) => Task.CompletedTask;
Expand Down
6 changes: 3 additions & 3 deletions PolyPilot.Tests/WsBridgeIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ public async Task SendMessage_AddsUserMessageToServerHistory()
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
var client = await ConnectClientAsync(cts.Token);

await client.SendMessageAsync("msg-test", "Hello from mobile", cts.Token);
await client.SendMessageAsync("msg-test", "Hello from mobile", ct: cts.Token);

var session = _copilot.GetSession("msg-test");
Assert.NotNull(session);
Expand All @@ -317,7 +317,7 @@ public async Task SendMessage_TriggersContentDelta_OnClient()
};
await client.ConnectAsync($"ws://localhost:{_port}/", null, cts.Token);

await client.SendMessageAsync("delta-test", "Tell me a joke", cts.Token);
await client.SendMessageAsync("delta-test", "Tell me a joke", ct: cts.Token);

// Demo mode sends a simulated response with content deltas
var content = await contentReceived.Task.WaitAsync(TimeSpan.FromSeconds(5));
Expand All @@ -334,7 +334,7 @@ public async Task QueueMessage_EnqueuesOnServer()
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
var client = await ConnectClientAsync(cts.Token);

await client.QueueMessageAsync("queue-test", "queued msg", cts.Token);
await client.QueueMessageAsync("queue-test", "queued msg", ct: cts.Token);

var session = _copilot.GetSession("queue-test");
Assert.NotNull(session);
Expand Down
15 changes: 9 additions & 6 deletions PolyPilot/Components/Pages/Dashboard.razor
Original file line number Diff line number Diff line change
Expand Up @@ -1152,10 +1152,13 @@
}

var inputMode = GetInputMode(sessionName);
if (inputMode == "plan")
finalPrompt = $"[[PLAN]] {finalPrompt}";
else if (inputMode == "autopilot")
finalPrompt = $"[[AUTOPILOT]] {finalPrompt}";
// Map UI input mode to SDK agent mode
string? agentMode = inputMode switch
{
"plan" => "plan",
"autopilot" => "autopilot",
_ => null
};

var dispatch = await FiestaService.DispatchMentionedWorkAsync(sessionName, finalPrompt);
if (dispatch.MentionsFound)
Expand Down Expand Up @@ -1186,7 +1189,7 @@
queueImagePaths = pendingSendImages!.Select(i => i.TempPath).ToList();
pendingImagesBySession.Remove(sessionName);
}
CopilotService.EnqueueMessage(sessionName, finalPrompt, queueImagePaths);
CopilotService.EnqueueMessage(sessionName, finalPrompt, queueImagePaths, agentMode);
return;
}

Expand Down Expand Up @@ -1215,7 +1218,7 @@

try
{
_ = CopilotService.SendPromptAsync(sessionName, finalPrompt, imagePaths).ContinueWith(t =>
_ = CopilotService.SendPromptAsync(sessionName, finalPrompt, imagePaths, agentMode: agentMode).ContinueWith(t =>
{
if (t.IsFaulted)
{
Expand Down
3 changes: 3 additions & 0 deletions PolyPilot/Models/BridgeMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ public class SendMessagePayload
{
public string SessionName { get; set; } = "";
public string Message { get; set; } = "";
/// <summary>SDK agent mode: "interactive", "plan", "autopilot", "shell". Null = default (interactive).</summary>
public string? AgentMode { get; set; }
}

public class CreateSessionPayload
Expand All @@ -262,6 +264,7 @@ public class QueueMessagePayload
{
public string SessionName { get; set; } = "";
public string Message { get; set; } = "";
public string? AgentMode { get; set; }
}

public class PersistedSessionsPayload
Expand Down
44 changes: 41 additions & 3 deletions PolyPilot/Services/CopilotService.Events.cs
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,18 @@ private void CompleteResponse(SessionState state, long? expectedGeneration = nul
_queuedImagePaths.TryRemove(state.Info.Name, out _);
}
}
// Retrieve any queued agent mode for this message
string? nextAgentMode = null;
lock (_imageQueueLock)
{
if (_queuedAgentModes.TryGetValue(state.Info.Name, out var modeQueue) && modeQueue.Count > 0)
{
nextAgentMode = modeQueue[0];
modeQueue.RemoveAt(0);
if (modeQueue.Count == 0)
_queuedAgentModes.TryRemove(state.Info.Name, out _);
}
}

var skipHistory = state.Info.ReflectionCycle is { IsActive: true } &&
ReflectionCycle.IsReflectionFollowUpPrompt(nextPrompt);
Expand All @@ -817,7 +829,7 @@ private void CompleteResponse(SessionState state, long? expectedGeneration = nul
{
try
{
await SendPromptAsync(state.Info.Name, nextPrompt, imagePaths: nextImagePaths, skipHistoryMessage: skipHistory);
await SendPromptAsync(state.Info.Name, nextPrompt, imagePaths: nextImagePaths, skipHistoryMessage: skipHistory, agentMode: nextAgentMode);
tcs.TrySetResult();
}
catch (Exception ex)
Expand All @@ -829,7 +841,7 @@ private void CompleteResponse(SessionState state, long? expectedGeneration = nul
}
else
{
await SendPromptAsync(state.Info.Name, nextPrompt, imagePaths: nextImagePaths, skipHistoryMessage: skipHistory);
await SendPromptAsync(state.Info.Name, nextPrompt, imagePaths: nextImagePaths, skipHistoryMessage: skipHistory, agentMode: nextAgentMode);
}
}
catch (Exception ex)
Expand All @@ -846,6 +858,19 @@ private void CompleteResponse(SessionState state, long? expectedGeneration = nul
images.Insert(0, nextImagePaths);
}
}
// Re-queue the agent mode too (always re-insert to maintain alignment)
lock (_imageQueueLock)
{
if (_queuedAgentModes.TryGetValue(state.Info.Name, out var existingModes))
{
existingModes.Insert(0, nextAgentMode);
}
else if (nextAgentMode != null)
{
var modes = _queuedAgentModes.GetOrAdd(state.Info.Name, _ => new List<string?>());
modes.Insert(0, nextAgentMode);
}
}
});
}
});
Expand Down Expand Up @@ -1056,6 +1081,19 @@ private void HandleReflectionAdvanceResult(SessionState state, string response,
var nextPrompt = state.Info.MessageQueue[0];
state.Info.MessageQueue.RemoveAt(0);

// Consume any queued agent mode to keep alignment
string? nextAgentMode2 = null;
lock (_imageQueueLock)
{
if (_queuedAgentModes.TryGetValue(state.Info.Name, out var modeQueue2) && modeQueue2.Count > 0)
{
nextAgentMode2 = modeQueue2[0];
modeQueue2.RemoveAt(0);
if (modeQueue2.Count == 0)
_queuedAgentModes.TryRemove(state.Info.Name, out _);
}
}

var skipHistory = state.Info.ReflectionCycle is { IsActive: true } &&
ReflectionCycle.IsReflectionFollowUpPrompt(nextPrompt);

Expand All @@ -1071,7 +1109,7 @@ private void HandleReflectionAdvanceResult(SessionState state, string response,
{
try
{
await SendPromptAsync(state.Info.Name, nextPrompt, skipHistoryMessage: skipHistory);
await SendPromptAsync(state.Info.Name, nextPrompt, skipHistoryMessage: skipHistory, agentMode: nextAgentMode2);
tcs.TrySetResult();
}
catch (Exception ex)
Expand Down
Loading