Skip to content

Commit

Permalink
hotfix(example): fix ag_news example evaluation typo (#2505)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Jul 14, 2023
1 parent 11a2637 commit 801d4c6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
7 changes: 5 additions & 2 deletions example/text_cls_AG_NEWS/tcan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import gradio
from torchtext.data.utils import get_tokenizer, ngrams_iterator

from starwhale import PipelineHandler, multi_classification
from starwhale import Text, PipelineHandler, multi_classification
from starwhale.api.service import api

from .model import TextClassificationModel
Expand All @@ -24,7 +24,10 @@ def __init__(self) -> None:

@torch.no_grad()
def ppl(self, data: dict, **kw):
ngrams = list(ngrams_iterator(self.tokenizer(data["text"]), 2))
content = (
data["text"].content if isinstance(data["text"], Text) else data["text"]
)
ngrams = list(ngrams_iterator(self.tokenizer(content), 2))
tensor = torch.tensor(self.vocab(ngrams)).to(self.device)
output = self.model(tensor, torch.tensor([0]).to(self.device))
pred_value = output.argmax(1).item()
Expand Down
9 changes: 6 additions & 3 deletions example/transformer/ag_news/code/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

from starwhale import PipelineHandler, multi_classification
from starwhale import Text, PipelineHandler, multi_classification

ROOTDIR = Path(__file__).parent.parent
_LABEL_NAMES = ["LABEL_0", "LABEL_1", "LABEL_2", "LABEL_3"]
Expand All @@ -15,12 +15,15 @@ def __init__(self) -> None:
model = AutoModelForSequenceClassification.from_pretrained(
str(ROOTDIR / "models")
)
self.mode = pipeline(
self.model = pipeline(
task="text-classification", model=model, tokenizer=tokenizer
)

def ppl(self, data):
_r = self.mode(data["text"])
content = (
data["text"].content if isinstance(data["text"], Text) else data["text"]
)
_r = self.model(content)
return _LABEL_NAMES.index(_r[0]["label"])

@multi_classification(
Expand Down

0 comments on commit 801d4c6

Please sign in to comment.