Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Remove unnecessary 'type' property from tool type definition #574

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ class CreateToolRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
type: Literal["function", "integration", "system", "api_call"] = "function"
"""
Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now)
"""
name: Annotated[str, Field(max_length=40, pattern="^[^\\W0-9]\\w*$")]
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
Expand Down Expand Up @@ -168,10 +164,6 @@ class NamedToolChoice(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
type: Literal["function", "integration", "system", "api_call"]
"""
Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now)
"""
function: FunctionCallOption | None = None


Expand All @@ -183,10 +175,6 @@ class PatchToolRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
type: Literal["function", "integration", "system", "api_call"] = "function"
"""
Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now)
"""
name: Annotated[str | None, Field(None, max_length=40, pattern="^[^\\W0-9]\\w*$")]
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
Expand Down Expand Up @@ -247,10 +235,6 @@ class Tool(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
type: Literal["function", "integration", "system", "api_call"] = "function"
"""
Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now)
"""
name: Annotated[str, Field(max_length=40, pattern="^[^\\W0-9]\\w*$")]
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
Expand Down Expand Up @@ -291,10 +275,6 @@ class UpdateToolRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
type: Literal["function", "integration", "system", "api_call"] = "function"
"""
Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now)
"""
name: Annotated[str, Field(max_length=40, pattern="^[^\\W0-9]\\w*$")]
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
Expand All @@ -316,14 +296,3 @@ class ChosenFunctionCall(ChosenToolCall):
"""
The function to call
"""


class NamedFunctionChoice(NamedToolChoice):
model_config = ConfigDict(
populate_by_name=True,
)
type: Literal["function"] = "function"
function: FunctionCallOption
"""
The function to call
"""
97 changes: 22 additions & 75 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,88 +77,35 @@ class InputChatMLMessage(Message):
pass


# Custom types (not generated correctly)
# --------------------------------------

ChatMLContent = (
list[ChatMLTextContentPart | ChatMLImageContentPart]
| Tool
| ChosenToolCall
| str
| ToolResponse
| list[
list[ChatMLTextContentPart | ChatMLImageContentPart]
| Tool
| ChosenToolCall
| str
| ToolResponse
]
)

# Extract ChatMLRole
ChatMLRole = BaseEntry.model_fields["role"].annotation

# Extract ChatMLSource
ChatMLSource = BaseEntry.model_fields["source"].annotation

# Extract ExecutionStatus
ExecutionStatus = Execution.model_fields["status"].annotation

# Extract TransitionType
TransitionType = Transition.model_fields["type"].annotation

# Assertions to ensure consistency (optional, but recommended for runtime checks)
assert ChatMLRole == BaseEntry.model_fields["role"].annotation
assert ChatMLSource == BaseEntry.model_fields["source"].annotation
assert ExecutionStatus == Execution.model_fields["status"].annotation
assert TransitionType == Transition.model_fields["type"].annotation


# Create models
# -------------
# Patches
# -------


class CreateTransitionRequest(Transition):
# The following fields are optional in this
def type_property(self: BaseModel) -> str:
return (
"function"
if self.function
else "integration"
if self.integration
else "system"
if self.system
else "api_call"
if self.api_call
else None
)

id: UUID | None = None
execution_id: UUID | None = None
created_at: AwareDatetime | None = None
updated_at: AwareDatetime | None = None
metadata: dict[str, Any] | None = None
task_token: str | None = None

# Patch original Tool class to add 'type' property
TaskTool.type = computed_field(property(type_property))

class CreateEntryRequest(BaseEntry):
timestamp: Annotated[
float, Field(ge=0.0, default_factory=lambda: utcnow().timestamp())
]
# Patch original Tool class to add 'type' property
Tool.type = computed_field(property(type_property))

@classmethod
def from_model_input(
cls: Type[Self],
model: str,
*,
role: ChatMLRole,
content: ChatMLContent,
name: str | None = None,
source: ChatMLSource,
**kwargs: dict,
) -> Self:
tokenizer: dict = select_tokenizer(model=model)
token_count = token_counter(
model=model, messages=[{"role": role, "content": content, "name": name}]
)
# Patch original UpdateToolRequest class to add 'type' property
UpdateToolRequest.type = computed_field(property(type_property))

return cls(
role=role,
content=content,
name=name,
source=source,
tokenizer=tokenizer["type"],
token_count=token_count,
**kwargs,
)
# Patch original PatchToolRequest class to add 'type' property
PatchToolRequest.type = computed_field(property(type_property))


# Patch Task Workflow Steps
Expand Down
27 changes: 18 additions & 9 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,19 +218,28 @@ def task_to_spec(
task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
) -> TaskSpecDef | PartialTaskSpecDef:
task_data = task.model_dump(**model_opts)
main = task_data.pop("main")

workflows = [Workflow(name="main", steps=main)]
if "tools" in task_data:
del task_data["tools"]

for k in list(task_data.keys()):
if k in TaskSpec.model_fields.keys():
continue
tools = []
for tool in task.tools:
tool_spec = getattr(tool, tool.type)

steps = task_data.pop(k)
workflows.append(Workflow(name=k, steps=steps))
tools.append(
TaskToolDef(
type=tool.type,
spec=tool_spec.model_dump(),
**tool.model_dump(exclude={"type"}),
)
)

tools = task_data.pop("tools", [])
tools = [TaskToolDef(spec=tool.pop(tool["type"]), **tool) for tool in tools]
workflows = [Workflow(name="main", steps=task_data.pop("main"))]

for key, steps in list(task_data.items()):
if key not in TaskSpec.model_fields:
workflows.append(Workflow(name=key, steps=steps))
del task_data[key]

cls = PartialTaskSpecDef if isinstance(task, PatchTaskRequest) else TaskSpecDef

Expand Down
14 changes: 0 additions & 14 deletions typespec/tools/models.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ model SystemDef {

// TODO: We should use this model for all tools, not just functions and discriminate on the type
model Tool {
/** Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now) */
type: ToolType = ToolType.function;

/** Name of the tool (must be unique for this agent and a valid python identifier string )*/
name: validPythonIdentifier;

Expand All @@ -95,24 +92,13 @@ model FunctionCallOption {
name: string;
}

@discriminator("type")
model NamedToolChoice {
/** Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now) */
type: ToolType;

function?: FunctionCallOption;
integration?: never; // TODO: Implement
system?: never; // TODO: Implement
api_call?: never; // TODO: Implement
}

model NamedFunctionChoice extends NamedToolChoice {
type: ToolType.function;

/** The function to call */
function: FunctionCallOption;
}

model ToolResponse {
@key id: uuid;

Expand Down
Loading