Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSTR-376 : [AI] FastAPI로 마이그레이션 #13

Merged
merged 16 commits into from
Dec 17, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
♻️ Change bentoml test code to FastAPI test code
change bentoml test code to fastapi test code
  • Loading branch information
ekzm8523 committed Nov 18, 2022
commit 53eb4a2a0b77c617c19308c59370d3e89b2dd06e
48 changes: 31 additions & 17 deletions app/api/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sentence_transformers import SentenceTransformer

from app.config import (
CONTENT_LOCAL_MODEL_PATH,
CONTENT_MODEL_NAME,
CONTENT_MODEL_PATH,
CONTENT_MODEL_S3_PATH,
Expand All @@ -26,29 +27,36 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
keyword_model: Optional[SentenceTransformer] = None
content_model: Optional[PromptForClassification] = None
keyword_controller: Optional[KeywordController] = None
content_controller: Optional[ContentController] = None


def init_model() -> None:
global keyword_model, content_model

if os.path.isfile(KEYWORD_LOCAL_MODEL_PATH):
keyword_model = SentenceTransformer()
keyword_model.load_state_dict(torch.load(KEYWORD_LOCAL_MODEL_PATH))
if os.path.exists(KEYWORD_LOCAL_MODEL_PATH):
keyword_model = torch.load(KEYWORD_LOCAL_MODEL_PATH)
else:
keyword_model = SentenceTransformer(KEYWORD_MODEL_PATH)
torch.save(keyword_model.state_dict(), KEYWORD_LOCAL_MODEL_PATH)
torch.save(keyword_model, KEYWORD_LOCAL_MODEL_PATH)
keyword_model.eval()

model_class = get_model_class(plm_type=CONTENT_MODEL_NAME)
plm = model_class.model.from_pretrained(CONTENT_MODEL_PATH)
tokenizer = model_class.tokenizer.from_pretrained(CONTENT_MODEL_PATH)
template_text = get_template_text()
template = ManualTemplate(tokenizer=tokenizer, text=template_text)
verbalizer = ManualVerbalizer(tokenizer=tokenizer, num_classes=2, label_words=[["yes"], ["no"]])
if os.path.exists(CONTENT_LOCAL_MODEL_PATH):
content_model = torch.load(CONTENT_LOCAL_MODEL_PATH)
else:
model_class = get_model_class(plm_type=CONTENT_MODEL_NAME)
plm = model_class.model.from_pretrained(CONTENT_MODEL_PATH)
tokenizer = model_class.tokenizer.from_pretrained(CONTENT_MODEL_PATH)
template_text = get_template_text()
template = ManualTemplate(tokenizer=tokenizer, text=template_text)
verbalizer = ManualVerbalizer(tokenizer=tokenizer, num_classes=2, label_words=[["yes"], ["no"]])

content_model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer)
content_model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer)

model_path = s3.download(url=CONTENT_MODEL_S3_PATH, local_dir=".cache")
content_model.load_state_dict(torch.load(model_path, map_location=device))
model_path = s3.download(url=CONTENT_MODEL_S3_PATH, local_dir=".cache")
content_model.load_state_dict(torch.load(model_path, map_location=device))
torch.save(content_model, CONTENT_LOCAL_MODEL_PATH)
content_model.eval()


def get_keyword_grading_model() -> SentenceTransformer:
Expand All @@ -59,12 +67,18 @@ def get_content_grading_model() -> PromptForClassification:
return content_model


def get_content_controller(model: PromptForClassification = Depends(get_content_grading_model)) -> ContentController:
return ContentController(model)
def get_keyword_controller(model: SentenceTransformer = Depends(get_keyword_grading_model)) -> KeywordController:
    """FastAPI dependency: return the process-wide KeywordController.

    Lazily constructs the controller on first request and caches it in the
    module-level global, so every request shares one instance (and one model).
    """
    global keyword_controller
    if keyword_controller is None:
        # First call: wrap the injected model; later calls ignore `model`.
        keyword_controller = KeywordController(model)
    return keyword_controller


def get_keyword_controller(model: SentenceTransformer = Depends(get_keyword_grading_model)) -> KeywordController:
return KeywordController(model)
def get_content_controller(model: PromptForClassification = Depends(get_content_grading_model)) -> ContentController:
    """FastAPI dependency: return the process-wide ContentController.

    Lazily constructs the controller on first request and caches it in the
    module-level global, so every request shares one instance (and one model).
    """
    global content_controller
    if content_controller is None:
        # First call: wrap the injected model; later calls ignore `model`.
        content_controller = ContentController(model)
    return content_controller


init_model()
8 changes: 3 additions & 5 deletions app/api/v1/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def keyword_predict(
keyword_grading_req: KeywordGradingRequest = Body(...),
keyword_controller: KeywordController = Depends(get_keyword_controller),
) -> KeywordGradingResponse:
return await keyword_controller.is_correct_keyword(keyword_grading_req)
return await keyword_controller.grading(keyword_grading_req)


@router.post("/integrate")
Expand All @@ -37,11 +37,9 @@ async def integrate_predict(
)

keyword_grading_result, content_grading_result = await asyncio.gather(
keyword_controller.is_correct_keyword(keyword_predict_input),
content_controller.is_correct_content(content_predict_input),
keyword_controller.grading(keyword_predict_input),
content_controller.grading(content_predict_input),
)
# keyword_grading_result = keyword_controller.is_correct_keyword(keyword_predict_input)
# content_grading_result = content_controller.is_correct_content(content_predict_input)
return IntegratedGradingResponse(
problem_id=keyword_grading_result.problem_id,
correct_keywords=keyword_grading_result.correct_keywords,
Expand Down
4 changes: 2 additions & 2 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def get_secret():
CONTENT_MODEL_NAME = "t5"
CONTENT_MODEL_PATH = "google/mt5-base"
KEYWORD_MODEL_PATH = "Huffon/sentence-klue-roberta-base"
KEYWORD_LOCAL_MODEL_PATH = os.path.join(root, f"app/static/{STAGE}_keyword_model.pth")
CONTENT_LOCAL_MODEL_PATH = os.path.join(root, f"app/static/{STAGE}_content_model.pth")
KEYWORD_LOCAL_MODEL_PATH = os.path.join(root, f"app/static/{STAGE}_keyword_model")
CONTENT_LOCAL_MODEL_PATH = os.path.join(root, f"app/static/{STAGE}_content_model")
KEYWORD_MODEL_S3_PATH = os.getenv("KEYWORD_MODEL_S3_PATH")
CONTENT_MODEL_S3_PATH = os.getenv("CONTENT_MODEL_S3_PATH")
STOPWORD_FILE_PATH = os.path.join(root, "app/static/stopwords.txt")
Expand Down
9 changes: 9 additions & 0 deletions app/controller/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

from app.schemas import UserAnswer


class BaseController(ABC):
    """Abstract base class for grading controllers (keyword/content)."""

    @abstractmethod
    async def grading(self, input_data: UserAnswer):
        """Grade *input_data* and return a grading response; subclasses implement."""
        pass
5 changes: 3 additions & 2 deletions app/controller/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from openprompt.data_utils import InputExample
from openprompt.plms import T5TokenizerWrapper

from app.controller.base import BaseController
from app.schemas import ContentGradingRequest, ContentGradingResponse, ContentResponse

log = logging.getLogger("__main__")


class ContentController:
class ContentController(BaseController):
def __init__(self, model: PromptForClassification):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f"content predict model is running on : {self.device}")
Expand All @@ -28,7 +29,7 @@ def __init__(self, model: PromptForClassification):
def is_correct(predict) -> bool:
return predict == 1

async def is_correct_content(self, input_data: ContentGradingRequest) -> ContentGradingResponse:
async def grading(self, input_data: ContentGradingRequest) -> ContentGradingResponse:
log.info(pformat(input_data.__dict__))
user_answer = input_data.user_answer.strip()
input_data_list = [
Expand Down
5 changes: 3 additions & 2 deletions app/controller/keyeword.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.metrics.pairwise import cosine_similarity

from app.config import MECAB_DIC_PATH, OS
from app.controller.base import BaseController
from app.schemas import (
KeywordGradingRequest,
KeywordGradingResponse,
Expand All @@ -23,7 +24,7 @@
log = logging.getLogger("__main__")


class KeywordController:
class KeywordController(BaseController):
def __init__(self, model: SentenceTransformer, problem_dict: Optional[dict] = None):
self.problem_dict = problem_dict if problem_dict else {}
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_predicted_keyword_position(
end_idx = user_answer.find(last_word) + len(last_word)
return start_idx, end_idx

async def is_correct_keyword(self, input_data: KeywordGradingRequest) -> KeywordGradingResponse:
async def grading(self, input_data: KeywordGradingRequest) -> KeywordGradingResponse:
log.info(pformat(input_data.__dict__))
self.synchronize_keywords(input_data)

Expand Down
135 changes: 72 additions & 63 deletions app/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,72 @@
# import pandas as pd
# import pytest
# from openprompt import PromptForClassification
# from sentence_transformers import SentenceTransformer
#
# from app.model import get_content_grading_model, get_keyword_grading_model
# from app.runnable import KeywordPredictRunnable
# from app.schemas import ContentGradingRequest, KeywordGradingRequest
# from app.tests.factory import ContentDataFactory, KeywordDataFactory
#
#
# @pytest.fixture(scope="session")
# def keyword_model() -> SentenceTransformer:
# return get_keyword_grading_model()
#
#
# @pytest.fixture(scope="session")
# def content_model() -> PromptForClassification:
# return get_content_grading_model()
#
#
# @pytest.fixture(scope="session")
# def user_answer_df(
# path: str = "app/static/changed_user_answer.csv",
# ) -> pd.DataFrame:
# return pd.read_csv(path)
#
#
# @pytest.fixture(scope="session")
# def keyword_runnable() -> KeywordPredictRunnable:
# return KeywordPredictRunnable()
#
#
# @pytest.fixture(scope="session")
# def keyword_data_factory(keyword_model: SentenceTransformer) -> KeywordDataFactory:
# keyword_data_factory = KeywordDataFactory()
# keyword_data_factory.set_problem_dict(keyword_model)
# return keyword_data_factory
#
#
# @pytest.fixture(scope="session")
# def content_data_factory() -> ContentDataFactory:
# return ContentDataFactory()
#
#
# @pytest.fixture
# def random_multi_candidate_keyword_data(keyword_data_factory: KeywordDataFactory) -> KeywordGradingRequest:
# return keyword_data_factory.get_multi_candidate_keyword_request_data()
#
#
# @pytest.fixture
# def random_content_data(content_data_factory: ContentDataFactory) -> ContentGradingRequest:
# return content_data_factory.get_request_data()
#
#
# @pytest.fixture
# def random_keyword_data(keyword_data_factory: KeywordDataFactory) -> KeywordGradingRequest:
# return keyword_data_factory.get_request_data()
#
#
# @pytest.fixture
# def problem_dict(keyword_data_factory: KeywordDataFactory) -> dict:
# return keyword_data_factory.get_problem_dict()
import pandas as pd
import pytest
from openprompt import PromptForClassification
from sentence_transformers import SentenceTransformer

from app.api.dependency import (
get_content_controller,
get_content_grading_model,
get_keyword_controller,
get_keyword_grading_model,
)
from app.controller.content import ContentController
from app.controller.keyeword import KeywordController
from app.schemas import ContentGradingRequest, KeywordGradingRequest
from app.tests.factory import ContentDataFactory, KeywordDataFactory


@pytest.fixture(scope="session")
def user_answer_df() -> pd.DataFrame:
    """Session-scoped DataFrame of canned user answers for grading tests.

    Bug fix: pytest resolves every parameter of a fixture function as another
    fixture request, so the previous `path: str = "..."` default argument was
    never used — pytest would fail with "fixture 'path' not found". The CSV
    path is therefore inlined here.
    """
    return pd.read_csv("app/static/changed_user_answer.csv")


@pytest.fixture(scope="session")
def keyword_model() -> SentenceTransformer:
    """Session-wide keyword grading model, loaded once via the app dependency."""
    model = get_keyword_grading_model()
    return model


@pytest.fixture(scope="session")
def content_model() -> PromptForClassification:
    """Session-wide content grading model, loaded once via the app dependency."""
    return get_content_grading_model()


@pytest.fixture(scope="session")
def keyword_controller(keyword_model: SentenceTransformer) -> KeywordController:
    # NOTE(review): get_keyword_controller caches a module-level singleton in
    # app.api.dependency, so all tests share one controller instance.
    return get_keyword_controller(keyword_model)


@pytest.fixture(scope="session")
def content_controller(content_model: PromptForClassification) -> ContentController:
    """Shared ContentController built through the app's dependency factory."""
    controller = get_content_controller(content_model)
    return controller


@pytest.fixture(scope="session")
def keyword_data_factory(keyword_model: SentenceTransformer) -> KeywordDataFactory:
    """Factory for keyword grading request data, primed with the shared model."""
    factory = KeywordDataFactory()
    factory.set_problem_dict(keyword_model)
    return factory


@pytest.fixture(scope="session")
def content_data_factory() -> ContentDataFactory:
    """Factory producing content grading request payloads."""
    factory = ContentDataFactory()
    return factory


@pytest.fixture
def random_multi_candidate_keyword_data(keyword_data_factory: KeywordDataFactory) -> KeywordGradingRequest:
    """Keyword grading request whose keywords carry multiple candidates."""
    return keyword_data_factory.get_multi_candidate_keyword_request_data()


@pytest.fixture
def random_content_data(content_data_factory: ContentDataFactory) -> ContentGradingRequest:
    """Randomly generated content grading request."""
    request = content_data_factory.get_request_data()
    return request


@pytest.fixture
def random_keyword_data(keyword_data_factory: KeywordDataFactory) -> KeywordGradingRequest:
    """Randomly generated keyword grading request."""
    return keyword_data_factory.get_request_data()


@pytest.fixture
def problem_dict(keyword_data_factory: KeywordDataFactory) -> dict:
    """Problem mapping held by the keyword data factory."""
    mapping = keyword_data_factory.get_problem_dict()
    return mapping
Loading