Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support returning values from PromptFunction fn's #718

Merged
merged 5 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/marvin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .settings import settings

# legacy
from .components import ai_fn, ai_model, ai_classifier
from .components.prompt.fn import prompt_fn

Expand Down
4 changes: 2 additions & 2 deletions src/marvin/beta/applications/state/state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
import json
import textwrap
from typing import Optional, Union

from jsonpatch import JsonPatch
Expand Down Expand Up @@ -71,7 +71,7 @@ def as_tool(self, name: str = None) -> "Tool":
name = "state"
schema = self.get_schema()
if schema:
description = textwrap.dedent(
description = inspect.cleandoc(
f"Update the {name} object using JSON Patch documents. Updates will"
" fail if they do not comply with the following"
" schema:\n\n```json\n{schema}\n```"
Expand Down
2 changes: 2 additions & 0 deletions src/marvin/components/ai_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin):
- {{ arg }}: {{ value }}
{% endfor %}



What is its output?
"""
)
Expand Down
27 changes: 15 additions & 12 deletions src/marvin/components/ai_image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import inspect
from functools import partial, wraps
from typing import (
Any,
Expand All @@ -24,9 +25,15 @@
)

T = TypeVar("T")

P = ParamSpec("P")

DEFAULT_PROMPT = inspect.cleandoc(
"""
{{_doc}}
{{_return_value}}
"""
)


class AIImageKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
Expand All @@ -38,7 +45,7 @@ class AIImageKwargs(TypedDict):
class AIImageKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = None
prompt: Optional[str] = DEFAULT_PROMPT
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None

Expand All @@ -47,7 +54,7 @@ class AIImage(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=None)
prompt: Optional[str] = Field(default=DEFAULT_PROMPT)
client: Client = Field(default_factory=lambda: MarvinClient().client)
aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)

Expand All @@ -73,16 +80,12 @@ def as_prompt(
*args: P.args,
**kwargs: P.kwargs,
) -> str:
return (
PromptFunction[BaseModel]
.as_tool_call(
fn=self.fn,
environment=self.environment,
prompt=self.prompt,
)(*args, **kwargs)
.messages[0]
.content
tool_call = PromptFunction[BaseModel].as_tool_call(
fn=self.fn,
environment=self.environment,
prompt=self.prompt,
)
return tool_call(*args, **kwargs).messages[0].content

@overload
@classmethod
Expand Down
78 changes: 46 additions & 32 deletions src/marvin/components/prompt/fn.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a pattern I like of keeping type mappings in a mapping layer

I can move this to our mapping layer in the future, but colocating all these mappings is 🔥

Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@
U = TypeVar("U", bound=BaseModel)


def fn_to_messages(
    fn: Callable,
    fn_args: tuple,
    fn_kwargs: dict,
    prompt: Optional[str] = None,
    render_kwargs: Optional[dict] = None,
    call_fn: bool = True,
) -> list[Message]:
    """Render a function invocation into a list of chat ``Message`` objects.

    Binds ``fn_args``/``fn_kwargs`` against ``fn``'s signature (applying
    defaults), optionally calls ``fn`` to capture its return value, and
    renders the prompt template with the bound arguments plus a set of
    underscore-prefixed template variables (``_doc``, ``_return_value``,
    ``_return_annotation``, ``_source_code``, ``_arguments``).

    Args:
        fn: The function whose call is being rendered.
        fn_args: Positional arguments for ``fn``.
        fn_kwargs: Keyword arguments for ``fn``.
        prompt: Template to render; falls back to ``fn.__doc__``, then ``""``.
        render_kwargs: Extra template variables merged into the render context.
        call_fn: If True, actually invoke ``fn`` so its return value is
            available as ``_return_value``; if False, ``_return_value`` is None.

    Returns:
        The rendered list of messages produced by ``Transcript.render_to_messages``.
    """
    prompt = prompt or fn.__doc__ or ""

    signature = inspect.signature(fn)
    params = signature.bind(*fn_args, **fn_kwargs)
    params.apply_defaults()
    # Reuse the signature computed above rather than re-inspecting fn.
    return_annotation = signature.return_annotation
    # Only invoke fn when requested: some callers want the rendered prompt
    # without triggering fn's side effects.
    return_value = fn(*fn_args, **fn_kwargs) if call_fn else None

    messages = Transcript(content=prompt).render_to_messages(
        # Bound arguments win over raw kwargs so defaults are included.
        **fn_kwargs | params.arguments,
        _arguments=params.arguments,
        _doc=inspect.getdoc(fn),
        _return_value=return_value,
        _return_annotation=return_annotation,
        # Strip everything before the first `def` (e.g. decorators) from the
        # retrieved source, keeping the `def` token itself.
        _source_code=("\ndef" + "def".join(re.split("def", inspect.getsource(fn))[1:])),
        **(render_kwargs or {}),
    )
    return messages


class PromptFunction(Prompt[U]):
model_config = pydantic.ConfigDict(
extra="allow",
Expand Down Expand Up @@ -99,35 +127,24 @@ def as_grammar(
Callable[[Callable[P, Any]], Callable[P, Self]],
Callable[P, Self],
]:
def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
# Get the signature of the function
signature = inspect.signature(func)
params = signature.bind(*args, **kwargs)
params.apply_defaults()

def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self:
vocabulary = create_vocabulary_from_type(
inspect.signature(func).return_annotation
)

messages = fn_to_messages(
fn=fn,
fn_args=args,
fn_kwargs=kwargs_,
prompt=prompt,
render_kwargs=dict(_options=vocabulary),
)
grammar = create_grammar_from_vocabulary(
vocabulary=vocabulary,
encoder=encoder,
_enumerate=enumerate,
max_tokens=max_tokens,
)

messages = Transcript(
content=prompt or func.__doc__ or ""
).render_to_messages(
**kwargs | params.arguments,
_arguments=params.arguments,
_options=vocabulary,
_doc=func.__doc__,
_source_code=(
"\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:])
),
)

return cls(
messages=messages,
temperature=temperature,
Expand All @@ -154,6 +171,7 @@ def as_tool_call(
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
render_kwargs: Optional[dict[str, Any]] = None,
) -> Callable[[Callable[P, Any]], Callable[P, Self]]:
pass

Expand All @@ -169,6 +187,7 @@ def as_tool_call(
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
render_kwargs: Optional[dict[str, Any]] = None,
) -> Callable[P, Self]:
pass

Expand All @@ -183,15 +202,13 @@ def as_tool_call(
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
render_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Union[
Callable[[Callable[P, Any]], Callable[P, Self]],
Callable[P, Self],
]:
def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self:
signature = inspect.signature(func)
params = signature.bind(*args, **kwargs_)
params.apply_defaults()
_type = inspect.signature(func).return_annotation
if _type is inspect._empty:
_type = str
Expand All @@ -204,16 +221,13 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self:
field_description=field_description,
)

messages = Transcript(
content=prompt or func.__doc__ or ""
).render_to_messages(
**kwargs_ | params.arguments,
_doc=func.__doc__,
_arguments=params.arguments,
_response_model=toolset.tools[0], # type: ignore
_source_code=(
"\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:])
),
messages = fn_to_messages(
fn=fn,
fn_args=args,
fn_kwargs=kwargs_,
prompt=prompt,
render_kwargs=(render_kwargs or {})
| dict(_response_model=toolset.tools[0]),
)

return cls(
Expand Down