Skip to content

Commit

Permalink
feat: text-classification
Browse files Browse the repository at this point in the history
  • Loading branch information
jonafeucht committed Jun 11, 2024
1 parent 7e92d2c commit b5e32d7
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ services:
environment:
- DEFAULT_SUMMARIZATION_MODEL_NAME
- DEFAULT_TRANSLATION_MODEL_NAME
- DEFAULT_TEXT_CLASSIFICATION_MODEL_NAME
- ACCESS_TOKEN
- USE_API_KEYS
- API_KEYS
Expand All @@ -42,6 +43,7 @@ services:
environment:
- DEFAULT_SUMMARIZATION_MODEL_NAME
- DEFAULT_TRANSLATION_MODEL_NAME
- DEFAULT_TEXT_CLASSIFICATION_MODEL_NAME
- ACCESS_TOKEN
- DEFAULT_SCORE
- USE_API_KEYS
Expand All @@ -67,6 +69,7 @@ volumes:
```sh
DEFAULT_SUMMARIZATION_MODEL_NAME=Falconsai/text_summarization
DEFAULT_TRANSLATION_MODEL_NAME=google-t5/t5-base
DEFAULT_TEXT_CLASSIFICATION_MODEL_NAME=s-nlp/roberta_toxicity_classifier
ACCESS_TOKEN=
# False == Public Access
Expand All @@ -80,6 +83,7 @@ API_KEYS=abc,123,xyz
## Supported NLP tasks
- [x] [Summarization](https://huggingface.co/tasks/summarization)
- [x] [Translation](https://huggingface.co/tasks/translation)
- [x] [Text Classification](https://huggingface.co/tasks/text-classification)

## Models
Any model designed for above tasks and compatible with huggingface transformers should work.
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ services:
environment:
- DEFAULT_SUMMARIZATION_MODEL_NAME
- DEFAULT_TRANSLATION_MODEL_NAME
- DEFAULT_TEXT_CLASSIFICATION_MODEL_NAME
- ACCESS_TOKEN
- DEFAULT_SCORE
- USE_API_KEYS
Expand All @@ -23,6 +24,7 @@ services:
environment:
- DEFAULT_SUMMARIZATION_MODEL_NAME
- DEFAULT_TRANSLATION_MODEL_NAME
- DEFAULT_TEXT_CLASSIFICATION_MODEL_NAME
- ACCESS_TOKEN
- DEFAULT_SCORE
- USE_API_KEYS
Expand Down
2 changes: 2 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from fastapi import FastAPI
from src.routes.api import summarization
from src.routes.api import translation
from src.routes.api import text_classification

app = FastAPI()
app.include_router(summarization.router)
app.include_router(translation.router)
app.include_router(text_classification.router)


@app.get("/")
Expand Down
25 changes: 25 additions & 0 deletions src/routes/api/text_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from fastapi import APIRouter, Query, Depends
from src.middleware.auth.auth import get_api_key
from src.shared.shared import text_classification_model
import time

router = APIRouter()


@router.post("/api/text-classification", dependencies=[Depends(get_api_key)])
async def text_classification(
text: str,
model_name: str = Query(None),
):
start_time = time.time()
text_classifier = text_classification_model(model_name)
try:
response = text_classifier(text)
return {
"execution_time": time.time() - start_time,
"res": response,
}

except Exception as e:
print("Something went wrong: ", e)
return {"error": str(e)}
20 changes: 20 additions & 0 deletions src/shared/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
"DEFAULT_TRANSLATION_MODEL_NAME", "google-t5/t5-base"
)

default_text_classification_model_name = os.getenv(
"DEFAULT_TEXT_CLASSIFICATION_MODEL_NAME", "s-nlp/roberta_toxicity_classifier"
)

device = 0 if torch.cuda.is_available() else -1

# API KEY
Expand Down Expand Up @@ -54,3 +58,19 @@ def translation_model(model_name, input_language, output_language):
except Exception as e:
print(e)
return {"error": str(e)}


def text_classification_model(model_name):
try:
_model_name = model_name or default_text_classification_model_name

translator = pipeline(
"text-classification",
model=_model_name,
device=device,
)

return translator
except Exception as e:
print(e)
return {"error": str(e)}

0 comments on commit b5e32d7

Please sign in to comment.