Skip to content

Commit

Permalink
Merge pull request #881 from PrefectHQ/instructions
Browse files Browse the repository at this point in the history
Allow assistant instructions to be jinja and self-referential
  • Loading branch information
zzstoatzz authored Mar 25, 2024
2 parents 0da9680 + d52c07b commit d3e8e16
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 47 deletions.
3 changes: 3 additions & 0 deletions docs/docs/interactive/assistants.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ Each assistant can be given `instructions` that describe its purpose, personalit
!!! success "Result"
![](/assets/images/docs/assistants/instructions.png)

Instructions are rendered as a Jinja template, which means you can use variables and conditionals to customize the assistant's behavior. A special variable, `self_` is provided to the template, which represents the assistant object itself. This allows you to template the assistant's name, tools, or other attributes into the instructions.


### Tools

Each assistant can be given a list of `tools` that it can use when responding to a message. Tools are a way to extend the assistant's capabilities beyond its default behavior, including giving it access to external systems like the internet, a database, your computer, or any API.
Expand Down
5 changes: 4 additions & 1 deletion src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
run_async,
run_sync,
)
from marvin.utilities.jinja import Environment as JinjaEnvironment
from marvin.utilities.logging import get_logger

from .threads import Thread
Expand Down Expand Up @@ -92,7 +93,9 @@ def get_tools(self) -> list[AssistantTool]:
]

def get_instructions(self, thread: Thread = None) -> str:
return self.instructions or ""
if self.instructions:
return JinjaEnvironment.render(self.instructions, self_=self)
return ""

@expose_sync_method("say")
async def say_async(
Expand Down
90 changes: 44 additions & 46 deletions tests/ai/beta/vision/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class Location(BaseModel):
city: str
state: str = Field(description="The two letter abbreviation")
state: str = Field(description="The two letter abbreviation for the state")


@pytest.mark.flaky(max_runs=2)
Expand All @@ -17,45 +17,45 @@ def test_ny(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(img, target=Location)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
locations = marvin.beta.extract(img, target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

def test_ny_images_input(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(data=None, images=[img], target=Location)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
locations = marvin.beta.extract(data=None, images=[img], target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

def test_ny_image_input(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(data=img, target=Location)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
locations = marvin.beta.extract(data=img, target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

def test_ny_image_and_text(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(
locations = marvin.beta.extract(
data="I see the empire state building",
images=[img],
target=Location,
)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

@pytest.mark.flaky(max_runs=3)
def test_dog(self):
Expand Down Expand Up @@ -90,11 +90,11 @@ async def test_ny(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = await marvin.beta.extract_async(img, target=Location)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
locations = await marvin.beta.extract_async(img, target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"


class TestMapping:
Expand All @@ -105,16 +105,15 @@ def test_map(self):
dc = marvin.beta.Image(
"https://images.unsplash.com/photo-1617581629397-a72507c3de9e"
)
result = marvin.beta.extract.map([ny, dc], target=Location)
assert isinstance(result, list)
assert result[0][0] in (
Location(city="New York", state="NY"),
Location(city="New York City", state="NY"),
)
assert result[1][0] in (
Location(city="Washington", state="DC"),
Location(city="Washington", state="D.C."),
)
locations = marvin.beta.extract.map([ny, dc], target=Location)
assert len(locations) == 2
ny_location, dc_location = locations

assert ny_location[0].city.startswith("New York")
assert ny_location[0].state == "NY"

assert dc_location[0].city == "Washington"
assert dc_location[0].state.index("D") < dc_location[0].state.index("C")

async def test_async_map(self):
ny = marvin.beta.Image(
Expand All @@ -123,13 +122,12 @@ async def test_async_map(self):
dc = marvin.beta.Image(
"https://images.unsplash.com/photo-1617581629397-a72507c3de9e"
)
result = await marvin.beta.extract_async.map([ny, dc], target=Location)
assert isinstance(result, list)
assert result[0][0] in (
Location(city="New York", state="NY"),
Location(city="New York City", state="NY"),
)
assert result[1][0] in (
Location(city="Washington", state="DC"),
Location(city="Washington", state="D.C."),
)
locations = await marvin.beta.extract_async.map([ny, dc], target=Location)
assert len(locations) == 2
ny_location, dc_location = locations

assert ny_location[0].city.startswith("New York")
assert ny_location[0].state == "NY"

assert dc_location[0].city == "Washington"
assert dc_location[0].state.index("D") < dc_location[0].state.index("C")

0 comments on commit d3e8e16

Please sign in to comment.