Skip to content

Commit

Permalink
Fixing new default scoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementRomac committed Nov 21, 2023
1 parent 744ffa3 commit 4b36c0c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .base_module_function import BaseModuleFunction
from .score_module_function import ScoreModuleFunction
from .score_module_function import LogScoringModuleFn
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch

class LogScoringModuleFn(BaseModuleFunction):
def __init__(self, model_type, pre_encoded_input):
def __init__(self, pade_token, model_type, pre_encoded_input):
super().__init__()
self._pad_token = pade_token
self._model_type = model_type
self._pad_token = 0
self._pre_encoded_input = pre_encoded_input

def initialize(self):
Expand Down
5 changes: 3 additions & 2 deletions lamorel/src/lamorel/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .llms import HF_LLM
from .llms.updaters import BaseUpdater
from .llms.module_functions import BaseModuleFunction, ScoreModuleFunction
from .llms.module_functions import BaseModuleFunction, LogScoringModuleFn
from .dispatcher import Dispatcher
from .utils import InstructionsEnum

Expand Down Expand Up @@ -41,7 +41,8 @@ def __init__(self, config, llm_index, llm_group, llm_master, rl_llm_group, rl_ll
self._dispatcher = Dispatcher(self._llm_group, self._rl_llm_group_size - 1, self._llm_group_size,
self._is_main_server, self._master_server_rank, self._index)

custom_module_functions["__score"] = ScoreModuleFunction(self._model.pad_token, config.llm_args.model_type)
custom_module_functions["__score"] = LogScoringModuleFn(self._model.pad_token, config.llm_args.model_type,
config.llm_args.pre_encode_inputs)
for k, _fn in custom_module_functions.items():
assert isinstance(_fn, BaseModuleFunction)
_fn.device = self._model.device
Expand Down

0 comments on commit 4b36c0c

Please sign in to comment.