From 801d4c62067ea1284fc5c0f6477750bc2e28130c Mon Sep 17 00:00:00 2001
From: tianwei
Date: Fri, 14 Jul 2023 14:24:21 +0800
Subject: [PATCH] hotfix(example): fix ag_news example evaluation typo (#2505)

---
 example/text_cls_AG_NEWS/tcan/evaluator.py    | 7 +++++--
 example/transformer/ag_news/code/evaluator.py | 9 ++++++---
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/example/text_cls_AG_NEWS/tcan/evaluator.py b/example/text_cls_AG_NEWS/tcan/evaluator.py
index 6470e1dcb5..2fe7a1dd9b 100644
--- a/example/text_cls_AG_NEWS/tcan/evaluator.py
+++ b/example/text_cls_AG_NEWS/tcan/evaluator.py
@@ -4,7 +4,7 @@
 import gradio
 from torchtext.data.utils import get_tokenizer, ngrams_iterator
 
-from starwhale import PipelineHandler, multi_classification
+from starwhale import Text, PipelineHandler, multi_classification
 from starwhale.api.service import api
 
 from .model import TextClassificationModel
@@ -24,7 +24,10 @@ def __init__(self) -> None:
 
     @torch.no_grad()
     def ppl(self, data: dict, **kw):
-        ngrams = list(ngrams_iterator(self.tokenizer(data["text"]), 2))
+        content = (
+            data["text"].content if isinstance(data["text"], Text) else data["text"]
+        )
+        ngrams = list(ngrams_iterator(self.tokenizer(content), 2))
         tensor = torch.tensor(self.vocab(ngrams)).to(self.device)
         output = self.model(tensor, torch.tensor([0]).to(self.device))
         pred_value = output.argmax(1).item()
diff --git a/example/transformer/ag_news/code/evaluator.py b/example/transformer/ag_news/code/evaluator.py
index f408c45d07..5be074f5d8 100644
--- a/example/transformer/ag_news/code/evaluator.py
+++ b/example/transformer/ag_news/code/evaluator.py
@@ -2,7 +2,7 @@
 
 from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
 
-from starwhale import PipelineHandler, multi_classification
+from starwhale import Text, PipelineHandler, multi_classification
 
 ROOTDIR = Path(__file__).parent.parent
 _LABEL_NAMES = ["LABEL_0", "LABEL_1", "LABEL_2", "LABEL_3"]
@@ -15,12 +15,15 @@ def __init__(self) -> None:
         model = AutoModelForSequenceClassification.from_pretrained(
             str(ROOTDIR / "models")
         )
-        self.mode = pipeline(
+        self.model = pipeline(
             task="text-classification", model=model, tokenizer=tokenizer
         )
 
     def ppl(self, data):
-        _r = self.mode(data["text"])
+        content = (
+            data["text"].content if isinstance(data["text"], Text) else data["text"]
+        )
+        _r = self.model(content)
         return _LABEL_NAMES.index(_r[0]["label"])
 
     @multi_classification(
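
Note (not part of the patch): both hunks apply the same normalization pattern, sketched below as a standalone helper. It assumes only what the diff itself shows, namely that starwhale's Text wrapper exposes the raw string via its .content attribute while older datasets may yield a plain str; the helper name _as_str is hypothetical and does not exist in the repository.

    from starwhale import Text

    def _as_str(value) -> str:
        # A dataset row may carry either a starwhale Text object or a plain str;
        # normalize to str before tokenization or pipeline inference.
        return value.content if isinstance(value, Text) else value

    # Usage, mirroring the evaluators' ppl() methods:
    #   content = _as_str(data["text"])
    #   ngrams = list(ngrams_iterator(self.tokenizer(content), 2))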