diff --git a/trl/scripts/chat.py b/trl/scripts/chat.py index d1746441ea..2ede78e6cb 100644 --- a/trl/scripts/chat.py +++ b/trl/scripts/chat.py @@ -17,6 +17,7 @@ import copy import json import os +import platform import pwd import re import time @@ -34,6 +35,9 @@ from trl.trainer.utils import get_quantization_config +if platform.system() != "Windows": + import pwd + init_zero_verbose() HELP_STRING = """\ @@ -217,7 +221,10 @@ def print_help(self): def get_username(): - return pwd.getpwuid(os.getuid())[0] + if platform.system() == "Windows": + return os.getlogin() + else: + return pwd.getpwuid(os.getuid()).pw_name def create_default_filename(model_name):