diff --git a/docs/docs/interactive/assistants.md b/docs/docs/interactive/assistants.md index c926f9ee4..c9ac9d3a9 100644 --- a/docs/docs/interactive/assistants.md +++ b/docs/docs/interactive/assistants.md @@ -171,6 +171,9 @@ Each assistant can be given `instructions` that describe its purpose, personalit !!! success "Result" ![](/assets/images/docs/assistants/instructions.png) +Instructions are rendered as a Jinja template, which means you can use variables and conditionals to customize the assistant's behavior. A special variable, `self_` is provided to the template, which represents the assistant object itself. This allows you to template the assistant's name, tools, or other attributes into the instructions. + + ### Tools Each assistant can be given a list of `tools` that it can use when responding to a message. Tools are a way to extend the assistant's capabilities beyond its default behavior, including giving it access to external systems like the internet, a database, your computer, or any API. diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index c9dfa1192..137b182fe 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -20,6 +20,7 @@ run_async, run_sync, ) +from marvin.utilities.jinja import Environment as JinjaEnvironment from marvin.utilities.logging import get_logger from .threads import Thread @@ -92,7 +93,9 @@ def get_tools(self) -> list[AssistantTool]: ] def get_instructions(self, thread: Thread = None) -> str: - return self.instructions or "" + if self.instructions: + return JinjaEnvironment.render(self.instructions, self_=self) + return "" @expose_sync_method("say") async def say_async( diff --git a/tests/ai/beta/vision/test_extract.py b/tests/ai/beta/vision/test_extract.py index d148a67da..b428af77c 100644 --- a/tests/ai/beta/vision/test_extract.py +++ b/tests/ai/beta/vision/test_extract.py @@ -8,7 +8,7 @@ class Location(BaseModel): city: str - state: str = Field(description="The two letter abbreviation") + state: str = Field(description="The two letter abbreviation for the state") @pytest.mark.flaky(max_runs=2) @@ -17,45 +17,45 @@ def test_ny(self): img = marvin.beta.Image( "https://images.unsplash.com/photo-1568515387631-8b650bbcdb90" ) - result = marvin.beta.extract(img, target=Location) - assert result in ( - [Location(city="New York", state="NY")], - [Location(city="New York City", state="NY")], - ) + locations = marvin.beta.extract(img, target=Location) + assert len(locations) == 1 + location = locations[0] + assert location.city.startswith("New York") or location.city == "Manhattan" + assert location.state == "NY" def test_ny_images_input(self): img = marvin.beta.Image( "https://images.unsplash.com/photo-1568515387631-8b650bbcdb90" ) - result = marvin.beta.extract(data=None, images=[img], target=Location) - assert result in ( - [Location(city="New York", state="NY")], - [Location(city="New York City", state="NY")], - ) + locations = marvin.beta.extract(data=None, images=[img], target=Location) + assert len(locations) == 1 + location = locations[0] + assert location.city.startswith("New York") or location.city == "Manhattan" + assert location.state == "NY" def test_ny_image_input(self): img = marvin.beta.Image( "https://images.unsplash.com/photo-1568515387631-8b650bbcdb90" ) - result = marvin.beta.extract(data=img, target=Location) - assert result in ( - [Location(city="New York", state="NY")], - [Location(city="New York City", state="NY")], - ) + locations = marvin.beta.extract(data=img, target=Location) + assert len(locations) == 1 + location = locations[0] + assert location.city.startswith("New York") or location.city == "Manhattan" + assert location.state == "NY" def test_ny_image_and_text(self): img = marvin.beta.Image( "https://images.unsplash.com/photo-1568515387631-8b650bbcdb90" ) - result = marvin.beta.extract( + locations = marvin.beta.extract( data="I see the empire state building", images=[img], target=Location, ) - assert result in ( - [Location(city="New York", state="NY")], - [Location(city="New York City", state="NY")], - ) + assert len(locations) == 1 + location = locations[0] + assert location.city.startswith("New York") or location.city == "Manhattan" + assert location.state == "NY" @pytest.mark.flaky(max_runs=3) def test_dog(self): @@ -90,11 +90,11 @@ async def test_ny(self): img = marvin.beta.Image( "https://images.unsplash.com/photo-1568515387631-8b650bbcdb90" ) - result = await marvin.beta.extract_async(img, target=Location) - assert result in ( - [Location(city="New York", state="NY")], - [Location(city="New York City", state="NY")], - ) + locations = await marvin.beta.extract_async(img, target=Location) + assert len(locations) == 1 + location = locations[0] + assert location.city.startswith("New York") or location.city == "Manhattan" + assert location.state == "NY" class TestMapping: @@ -105,16 +105,15 @@ def test_map(self): dc = marvin.beta.Image( "https://images.unsplash.com/photo-1617581629397-a72507c3de9e" ) - result = marvin.beta.extract.map([ny, dc], target=Location) - assert isinstance(result, list) - assert result[0][0] in ( - Location(city="New York", state="NY"), - Location(city="New York City", state="NY"), - ) - assert result[1][0] in ( - Location(city="Washington", state="DC"), - Location(city="Washington", state="D.C."), - ) + locations = marvin.beta.extract.map([ny, dc], target=Location) + assert len(locations) == 2 + ny_location, dc_location = locations + + assert ny_location[0].city.startswith("New York") + assert ny_location[0].state == "NY" + + assert dc_location[0].city == "Washington" + assert dc_location[0].state.index("D") < dc_location[0].state.index("C") async def test_async_map(self): ny = marvin.beta.Image( @@ -123,13 +122,12 @@ async def test_async_map(self): dc = marvin.beta.Image( "https://images.unsplash.com/photo-1617581629397-a72507c3de9e" ) - result = await marvin.beta.extract_async.map([ny, dc], target=Location) - assert isinstance(result, list) - assert result[0][0] in ( - Location(city="New York", state="NY"), - Location(city="New York City", state="NY"), - ) - assert result[1][0] in ( - Location(city="Washington", state="DC"), - Location(city="Washington", state="D.C."), - ) + locations = await marvin.beta.extract_async.map([ny, dc], target=Location) + assert len(locations) == 2 + ny_location, dc_location = locations + + assert ny_location[0].city.startswith("New York") + assert ny_location[0].state == "NY" + + assert dc_location[0].city == "Washington" + assert dc_location[0].state.index("D") < dc_location[0].state.index("C")