@@ -36,22 +36,23 @@ def llm():
     cleanup_dist_env_and_memory()


-@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
-    def get_outputs(softmax):
+    def get_outputs(activation):
         outputs = llm.reward(
-            prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False
+            prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
         )
         return torch.cat([x.outputs.data for x in outputs])

-    default = get_outputs(softmax=None)
-    w_softmax = get_outputs(softmax=True)
-    wo_softmax = get_outputs(softmax=False)
+    default = get_outputs(activation=None)
+    w_activation = get_outputs(activation=True)
+    wo_activation = get_outputs(activation=False)

-    assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax."
-    assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), (
-        "wo_softmax should not use softmax."
+    assert torch.allclose(default, w_activation, atol=1e-2), (
+        "Default should use activation."
     )
-    assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), (
-        "w_softmax should be close to softmax(wo_softmax)."
+    assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
+        "wo_activation should not use activation."
+    )
+    assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), (
+        "w_activation should be close to activation(wo_activation)."
     )
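
For context, a minimal standalone sketch of the relationship this test asserts. It is not part of the PR: the model name, prompt, and the choice of torch.softmax over the last dim are assumptions standing in for the test's fixtures and its softmax helper; only llm.reward() and PoolingParams(activation=...) are taken from the diff above.

# Hedged sketch, assuming a pooling-capable reward model and the
# PoolingParams(activation=...) flag introduced in this diff.
import torch
from vllm import LLM, PoolingParams

# Hypothetical model/prompt choices; the real test supplies these via fixtures.
llm = LLM(model="internlm/internlm2-1_8b-reward", trust_remote_code=True)
prompts = ["vLLM is a high-throughput inference engine."]

def scores(activation):
    outs = llm.reward(
        prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
    )
    return torch.cat([o.outputs.data for o in outs])

raw = scores(activation=False)       # raw reward scores, no activation applied
activated = scores(activation=True)  # same scores with the model's activation

# Applying softmax to the raw scores should reproduce the activated output
# (assuming softmax over the last dim matches the test's softmax helper).
assert torch.allclose(torch.softmax(raw, dim=-1), activated, atol=1e-2)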