Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Gpt2 reranker #4326

Merged
merged 3 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions parlai/agents/reranker/classifier_gpt2_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Classifier Gpt2 Re-Ranker Object.

Provided with a classifier model file, the gpt2 re-ranker provides an API for re-ranking
candidate outputs based on maximizing the probability of a given provided class.
"""
from typing import Optional, List
from parlai.core.agents import create_agent_from_model_file
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser

from parlai.agents.reranker.reranker import AbstractGpt2RerankAgent
from parlai.agents.reranker.classifier_reranker import ClassifierReranker


class ClassifierGpt2Reranker(ClassifierReranker):
pass


class ClassifierGpt2RerankerAgent(AbstractGpt2RerankAgent):
"""
Generative GPT2 Re-rank agent for adding a ClassifierReranker.
"""

@classmethod
def get_reranker_class(cls):
return ClassifierGpt2Reranker
17 changes: 16 additions & 1 deletion parlai/agents/reranker/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from parlai.core.torch_agent import TorchAgent
from parlai.utils.strings import normalize_reply
from parlai.utils.torch import argsort

from parlai.agents.hugging_face.gpt2 import Gpt2Agent
from projects.msc.agents.long_tga import TransformerVariantAgent

RERANKER_STRATEGIES = ['sum_scores', 'hard_choice', 'reranker_score', 'none']
Expand Down Expand Up @@ -574,3 +574,18 @@ def add_cmdline_args(
reranker_class = cls.get_reranker_class() or AbstractReranker
reranker_class.add_cmdline_args(parser, partial_opt=partial_opt)
return parser


class AbstractGpt2RerankAgent(AbstractGeneratorRerankAgentMixin, Gpt2Agent, ABC):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
"""
Add command-line arguments specifically for this agent.
"""
Gpt2Agent.add_cmdline_args(parser, partial_opt=partial_opt)
AbstractGeneratorRerankAgentMixin.add_cmdline_args(parser, partial_opt)
reranker_class = cls.get_reranker_class() or AbstractReranker
reranker_class.add_cmdline_args(parser, partial_opt=partial_opt)
return parser