@@ -123,7 +123,7 @@ def main():
     parser.add_argument('--history_cnt', default=0, type=int)
     parser.add_argument('--stream', default=True, type=bool)
     parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
-    parser.add_argument('--model_mode', default=0, type=int,
+    parser.add_argument('--model_mode', default=1, type=int,
                         help="0: pretrained model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model")
     args = parser.parse_args()

@@ -133,6 +133,8 @@ def main():
     test_mode = int(input('[0] automatic test\n[1] manual input\n'))
     messages = []
     for idx, prompt in enumerate(prompts if test_mode == 0 else iter(lambda: input('👶: '), '')):
+        setup_seed(random.randint(0, 2048))
+        # setup_seed(2025)  # switch to a fixed seed here if you need identical output on every run
         if test_mode == 0: print(f'👶: {prompt}')

         messages = messages[-args.history_cnt:] if args.history_cnt else []
@@ -177,6 +179,4 @@ def main():


 if __name__ == "__main__":
-    torch.backends.cudnn.deterministic = True
-    random.seed(random.randint(0, 2048))
     main()
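Note on the change above: the per-prompt call to `setup_seed(random.randint(0, 2048))` replaces the one-off `random.seed(...)` / `cudnn.deterministic` setup that used to sit under `if __name__ == "__main__":`, so every prompt now starts from its own RNG state (or from a fixed one if the commented-out `setup_seed(2025)` line is used instead). The `setup_seed` helper itself is not part of this hunk; the following is only a minimal sketch of what such a helper typically looks like, assuming it seeds the Python, NumPy and PyTorch RNGs:

```python
import random

import numpy as np
import torch


def setup_seed(seed: int) -> None:
    """Seed every RNG the generation path may touch (assumed helper, not shown in this diff)."""
    random.seed(seed)                           # Python's built-in RNG
    np.random.seed(seed)                        # NumPy RNG
    torch.manual_seed(seed)                     # torch CPU RNG
    torch.cuda.manual_seed_all(seed)            # all CUDA devices; no-op without a GPU
    torch.backends.cudnn.deterministic = True   # prefer deterministic cuDNN kernels
```

With a helper like this, passing the same seed reproduces a generation run, while the `random.randint(0, 2048)` call inside the loop deliberately varies the seed from prompt to prompt.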