Skip to content

Commit

Permalink
ifx
Browse files Browse the repository at this point in the history
  • Loading branch information
jonafeucht committed Jun 12, 2024
1 parent 40c8ec4 commit 4579167
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
10 changes: 6 additions & 4 deletions src/routes/api/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ async def image_classification(
):

try:
classifier = check_model(model_name)

# Read the file as bytes
contents = await file.read()
# Check if the image is in fact an image
Expand All @@ -38,7 +40,6 @@ async def image_classification(
# Check if the image is a GIF and if it's animated
if img.format.lower() == "gif":
try:
classifier = check_model(model_name)

results = []
with ThreadPoolExecutor() as executor:
Expand Down Expand Up @@ -77,7 +78,7 @@ async def image_classification(
)
finally:
img.close()
del res2
del classifier
torch.cuda.empty_cache()
else:
return HTTPException(
Expand All @@ -103,6 +104,8 @@ async def multi_image_classification(

for index, file in enumerate(files):
try:
classifier = check_model(model_name)

# Read the file as bytes
contents = await file.read()

Expand All @@ -115,7 +118,6 @@ async def multi_image_classification(
# Check if the image is a GIF and if it's animated
if img.format.lower() == "gif":
try:
classifier = check_model(model_name)

results = []
with ThreadPoolExecutor() as executor:
Expand Down Expand Up @@ -151,7 +153,7 @@ async def multi_image_classification(
)
finally:
img.close()
del res2
del classifier
torch.cuda.empty_cache()

else:
Expand Down
11 changes: 6 additions & 5 deletions src/routes/api/image_query_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ async def image_query_classification(
# Read the file as bytes
contents = await file.read()

classifier = check_model(model_name)

for model_name in model_names:
_score = score or default_score

Expand All @@ -101,7 +103,6 @@ async def image_query_classification(
if img.format.lower() == "gif":

try:
classifier = check_model(model_name)

res = await asyncio.get_event_loop().run_in_executor(
executor,
Expand Down Expand Up @@ -156,7 +157,7 @@ async def image_query_classification(

finally:
img.close()
del res2
del classifier
torch.cuda.empty_cache()

except Exception as e:
Expand Down Expand Up @@ -198,6 +199,8 @@ async def multi_image_query_classification(
# Read the file as bytes
image_list = []

classifier = check_model(model_name)

for model_name in model_names:
try:
contents = await file.read()
Expand All @@ -218,8 +221,6 @@ async def multi_image_query_classification(
# Check if the image is a GIF and if it's animated
if img.format.lower() == "gif":
try:
classifier = check_model(model_name)

res = await asyncio.get_event_loop().run_in_executor(
executor,
process_image,
Expand Down Expand Up @@ -273,7 +274,7 @@ async def multi_image_query_classification(
)
finally:
img.close()
del res2
del classifier
torch.cuda.empty_cache()

except Exception as e:
Expand Down

0 comments on commit 4579167

Please sign in to comment.