Skip to content

Commit

Permalink
oops
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Mar 25, 2024
1 parent ae95a6d commit d52c07b
Showing 1 changed file with 44 additions and 45 deletions.
89 changes: 44 additions & 45 deletions tests/ai/beta/vision/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class Location(BaseModel):
city: str = Field(description="Official city name, no boroughs or neighborhoods")
city: str
state: str = Field(description="The two letter abbreviation for the state")


Expand All @@ -17,44 +17,45 @@ def test_ny(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(img, target=Location)

assert result.city.startswith("New York")
assert result.state == "NY"
locations = marvin.beta.extract(img, target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

def test_ny_images_input(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(data=None, images=[img], target=Location)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
locations = marvin.beta.extract(data=None, images=[img], target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

def test_ny_image_input(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(data=img, target=Location)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
locations = marvin.beta.extract(data=img, target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

def test_ny_image_and_text(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = marvin.beta.extract(
locations = marvin.beta.extract(
data="I see the empire state building",
images=[img],
target=Location,
)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"

@pytest.mark.flaky(max_runs=3)
def test_dog(self):
Expand Down Expand Up @@ -89,11 +90,11 @@ async def test_ny(self):
img = marvin.beta.Image(
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
)
result = await marvin.beta.extract_async(img, target=Location)
assert result in (
[Location(city="New York", state="NY")],
[Location(city="New York City", state="NY")],
)
locations = await marvin.beta.extract_async(img, target=Location)
assert len(locations) == 1
location = locations[0]
assert location.city.startswith("New York") or location.city == "Manhattan"
assert location.state == "NY"


class TestMapping:
Expand All @@ -104,16 +105,15 @@ def test_map(self):
dc = marvin.beta.Image(
"https://images.unsplash.com/photo-1617581629397-a72507c3de9e"
)
result = marvin.beta.extract.map([ny, dc], target=Location)
assert isinstance(result, list)
assert result[0][0] in (
Location(city="New York", state="NY"),
Location(city="New York City", state="NY"),
)
assert result[1][0] in (
Location(city="Washington", state="DC"),
Location(city="Washington", state="D.C."),
)
locations = marvin.beta.extract.map([ny, dc], target=Location)
assert len(locations) == 2
ny_location, dc_location = locations

assert ny_location[0].city.startswith("New York")
assert ny_location[0].state == "NY"

assert dc_location[0].city == "Washington"
assert dc_location[0].state.index("D") < dc_location[0].state.index("C")

async def test_async_map(self):
ny = marvin.beta.Image(
Expand All @@ -122,13 +122,12 @@ async def test_async_map(self):
dc = marvin.beta.Image(
"https://images.unsplash.com/photo-1617581629397-a72507c3de9e"
)
result = await marvin.beta.extract_async.map([ny, dc], target=Location)
assert isinstance(result, list)
assert result[0][0] in (
Location(city="New York", state="NY"),
Location(city="New York City", state="NY"),
)
assert result[1][0] in (
Location(city="Washington", state="DC"),
Location(city="Washington", state="D.C."),
)
locations = await marvin.beta.extract_async.map([ny, dc], target=Location)
assert len(locations) == 2
ny_location, dc_location = locations

assert ny_location[0].city.startswith("New York")
assert ny_location[0].state == "NY"

assert dc_location[0].city == "Washington"
assert dc_location[0].state.index("D") < dc_location[0].state.index("C")

0 comments on commit d52c07b

Please sign in to comment.