Skip to content

Commit

Permalink
Add option to deserialize plan without requiring functions (#1652)
Browse files Browse the repository at this point in the history
This commit adds a new parameter to the Plan.FromJson method that allows
deserializing a plan without requiring the functions to be registered in
the skill collection. This is useful for scenarios where the plan is
only used for inspection or analysis, and not for execution. The default
behavior is still to require the functions, and throw an exception if
they are not found. The commit also adds unit tests for both cases, and
updates the JSON serialization options to ignore default values.

Resolves #1631

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->
- [x] The code builds clean without any errors or warnings
- [x] The PR follows SK Contribution Guidelines
(https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
- [x] The code follows the .NET coding conventions
(https://learn.microsoft.com/dotnet/csharp/fundamentals/coding-style/coding-conventions)
verified with `dotnet format`
- [x] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
lemillermicrosoft committed Jun 22, 2023
1 parent f17a0a1 commit a0976af
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,10 @@ public async Task CanStepAndSerializeAndDeserializePlanWithStepsAndContextAsync(
Assert.Contains("\"next_step_index\":2", serializedPlan2, StringComparison.OrdinalIgnoreCase);
}

[Fact]
public void CanDeserializePlan()
[Theory]
[InlineData(false)]
[InlineData(true)]
public void CanDeserializePlan(bool requireFunctions)
{
// Arrange
var goal = "Write a poem or joke and send it in an e-mail to Kai.";
Expand All @@ -516,11 +518,20 @@ public void CanDeserializePlan()
returnContext.Variables.Update(returnContext.Variables.Input + c.Variables.Input))
.Returns(() => Task.FromResult(returnContext));

if (requireFunctions)
{
mockFunction.Setup(x => x.Name).Returns(string.Empty);
ISKFunction? outFunc = mockFunction.Object;
skills.Setup(x => x.TryGetFunction(It.IsAny<string>(), out outFunc)).Returns(true);
skills.Setup(x => x.TryGetFunction(It.IsAny<string>(), It.IsAny<string>(), out outFunc)).Returns(true);
skills.Setup(x => x.GetFunction(It.IsAny<string>(), It.IsAny<string>())).Returns(mockFunction.Object);
}

plan.AddSteps(new Plan("Step1", mockFunction.Object), mockFunction.Object);

// Act
var serializedPlan = plan.ToJson();
var deserializedPlan = Plan.FromJson(serializedPlan, returnContext);
var deserializedPlan = Plan.FromJson(serializedPlan, returnContext, requireFunctions);

// Assert
Assert.NotNull(deserializedPlan);
Expand All @@ -536,4 +547,63 @@ public void CanDeserializePlan()
Assert.Equal(plan.Steps[0].Name, deserializedPlan.Steps[0].Name);
Assert.Equal(plan.Steps[1].Name, deserializedPlan.Steps[1].Name);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void DeserializeWithMissingFunctions(bool requireFunctions)
{
// Arrange
var goal = "Write a poem or joke and send it in an e-mail to Kai.";
var stepOutput = "Output: The input was: ";
var plan = new Plan(goal);

// Arrange
var kernel = new Mock<IKernel>();
var log = new Mock<ILogger>();
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var returnContext = new SKContext(
new ContextVariables(stepOutput),
memory.Object,
skills.Object,
log.Object
);

var mockFunction = new Mock<ISKFunction>();
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null))
.Callback<SKContext, CompleteRequestSettings>((c, s) =>
returnContext.Variables.Update(returnContext.Variables.Input + c.Variables.Input))
.Returns(() => Task.FromResult(returnContext));

plan.AddSteps(new Plan("Step1", mockFunction.Object), mockFunction.Object);

var serializedPlan = plan.ToJson();

if (requireFunctions)
{
// Act + Assert
Assert.Throws<KernelException>(() => Plan.FromJson(serializedPlan, returnContext));
}
else
{
// Act
var deserializedPlan = Plan.FromJson(serializedPlan, returnContext, requireFunctions);

// Assert
Assert.NotNull(deserializedPlan);
Assert.Equal(goal, deserializedPlan.Description);

Assert.Equal(string.Join(",", plan.Outputs),
string.Join(",", deserializedPlan.Outputs));
Assert.Equal(string.Join(",", plan.Parameters.Select(kv => $"{kv.Key}:{kv.Value}")),
string.Join(",", deserializedPlan.Parameters.Select(kv => $"{kv.Key}:{kv.Value}")));
Assert.Equal(string.Join(",", plan.State.Select(kv => $"{kv.Key}:{kv.Value}")),
string.Join(",", deserializedPlan.State.Select(kv => $"{kv.Key}:{kv.Value}")));

Assert.Equal(plan.Steps[0].Name, deserializedPlan.Steps[0].Name);
Assert.Equal(plan.Steps[1].Name, deserializedPlan.Steps[1].Name);
}
}
}
16 changes: 12 additions & 4 deletions dotnet/src/SemanticKernel/Planning/Plan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,16 @@ public Plan(
/// </summary>
/// <param name="json">JSON string representation of a Plan</param>
/// <param name="context">The context to use for function registrations.</param>
/// <param name="requireFunctions">Whether to require functions to be registered. Only used when context is not null.</param>
/// <returns>An instance of a Plan object.</returns>
/// <remarks>If Context is not supplied, plan will not be able to execute.</remarks>
public static Plan FromJson(string json, SKContext? context = null)
public static Plan FromJson(string json, SKContext? context = null, bool requireFunctions = true)
{
var plan = JsonSerializer.Deserialize<Plan>(json, new JsonSerializerOptions { IncludeFields = true }) ?? new Plan(string.Empty);

if (context != null)
{
plan = SetAvailableFunctions(plan, context);
plan = SetAvailableFunctions(plan, context, requireFunctions);
}

return plan;
Expand Down Expand Up @@ -420,8 +421,9 @@ internal string ExpandFromVariables(ContextVariables variables, string input)
/// </summary>
/// <param name="plan">Plan to set functions for.</param>
/// <param name="context">Context to use.</param>
/// <param name="requireFunctions">Whether to throw an exception if a function is not found.</param>
/// <returns>The plan with functions set.</returns>
private static Plan SetAvailableFunctions(Plan plan, SKContext context)
private static Plan SetAvailableFunctions(Plan plan, SKContext context, bool requireFunctions = true)
{
if (plan.Steps.Count == 0)
{
Expand All @@ -436,12 +438,18 @@ private static Plan SetAvailableFunctions(Plan plan, SKContext context)
{
plan.SetFunction(skillFunction);
}
else if (requireFunctions)
{
throw new KernelException(
KernelException.ErrorCodes.FunctionNotAvailable,
$"Function '{plan.SkillName}.{plan.Name}' not found in skill collection");
}
}
else
{
foreach (var step in plan.Steps)
{
SetAvailableFunctions(step, context);
SetAvailableFunctions(step, context, requireFunctions);
}
}

Expand Down

0 comments on commit a0976af

Please sign in to comment.