@@ -389,7 +389,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
389389    """ 
390390    dtypeC  =  "bfloat16" 
391391    B  =  torch_convert_bit_twiddling (qB )
392-     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  scale_size )])
392+     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  32 )])
393393    C  =  torch .matmul (A .to (torch .float ), B .T .to (torch .float ))
394394    C  =  C .to (torch .__getattribute__ (dtypeC ))
395395    return  C 
@@ -412,7 +412,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
412412    """ 
413413    dtypeC  =  "bfloat16" 
414414    B  =  torch_convert_bit_twiddling (qB )
415-     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  scale_size )])
415+     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  32 )])
416416    C  =  torch .matmul (A .to (torch .float ), B .T .to (torch .float )) +  Bias 
417417    C  =  C .to (torch .__getattribute__ (dtypeC ))
418418    return  C 
@@ -436,7 +436,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
436436    """ 
437437    dtypeC  =  "bfloat16" 
438438    B  =  torch_convert (qB )
439-     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  scale_size )])
439+     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  32 )])
440440    C  =  torch .matmul (A .to (torch .float ), B .T .to (torch .float ))
441441    C  =  C .to (torch .__getattribute__ (dtypeC ))
442442    return  C 
@@ -464,7 +464,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
464464    """ 
465465    dtypeC  =  "bfloat16" 
466466    B  =  torch_convert (qB )
467-     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  scale_size )])
467+     B  *=  2 ** (Scale [:, (torch .arange (B .shape [1 ], device = B .device ) //  32 )])
468468    C  =  torch .matmul (A .to (torch .float ), B .T .to (torch .float )) +  Bias 
469469    C  =  C .to (torch .__getattribute__ (dtypeC ))
470470    return  C 
0 commit comments