8
8
9
9
class QRazor (Function ):
10
10
@staticmethod
def forward(ctx, x, sign, q_bit, r_bit, group):
    """Round each group of ``x`` onto a coarser grid chosen by the group's maximum.

    ``x`` is flattened, zero-padded to a multiple of ``group`` and split into
    rows of ``group`` elements.  For every bit position ``b`` in
    ``[r_bit, q_bit]``, rows whose maximum lies in ``[2**(b-1), 2**b)`` are
    rounded to multiples of ``2**(b + 1 - r_bit)``.  Rows whose maximum falls
    outside ``[2**(r_bit-1), 2**q_bit)`` are left untouched.

    A rounded quotient equal to ``2**(r_bit-1)`` is lowered to
    ``2**(r_bit-1) - 1`` for elements whose ``sign`` entry is positive.

    Args:
        ctx: autograd context; not used by this method.
        x: tensor to quantize.  Presumably holds magnitudes, with the signs
            carried separately in ``sign`` -- TODO confirm against the caller.
        sign: tensor aligned element-wise with the flattened ``x``.  Only
            ``sign > 0`` is tested.  It may be shorter than the padded length
            (it is zero-padded) or longer (it is truncated).
        q_bit: highest bit position examined (inclusive).
        r_bit: lowest bit position examined (inclusive); also sets the
            rounding step and the clamp threshold.
        group: number of consecutive flattened elements per row.

    Returns:
        A tensor shaped like ``x`` holding the rounded values.

    NOTE(review): when no padding is needed and ``torch.reshape`` returns a
    view, the masked assignments below write into the caller's ``x`` in place
    and the result aliases it.  Kept as in the original; confirm this is
    intended for an autograd ``Function`` (no ``ctx.mark_dirty`` is issued).
    """
    raw_x = torch.reshape(x, (-1,))
    org_len = len(raw_x)
    if org_len % group:
        vacant_num = group - org_len % group
        raw_x = F.pad(raw_x, (0, vacant_num), 'constant', 0)
    raw_x = raw_x.view(-1, group)
    max_dim1, _ = raw_x.max(dim=1)

    # Bring ``sign`` to the same (rows, group) layout as ``raw_x``.  The
    # original sliced ``sign`` to the padded length and then called
    # ``.view(-1, group)``, which raises whenever ``x`` had to be padded
    # because ``sign`` still has only ``org_len`` elements.  The padding value
    # is irrelevant: padded data is 0, so it can never hit the clamp below.
    padded_len = raw_x.numel()
    sign_flat = torch.reshape(sign, (-1,))[:padded_len]
    missing = padded_len - sign_flat.numel()
    if missing > 0:
        sign_flat = torch.cat((sign_flat, sign_flat.new_zeros(missing)))
    sign_groups = sign_flat.view(-1, group)

    # Largest quotient allowed for positive-sign elements is ``limit - 1``.
    limit = 2 ** (r_bit - 1)

    for b in range(r_bit, q_bit + 1):
        mul_xth = 2 ** (b - 1)
        round_value = 2 ** (b + 1 - r_bit)
        # Rows whose maximum has its top set bit at position ``b``.  The
        # maxima were taken before the loop, so each row is handled once.
        outlier_id = (max_dim1 >= mul_xth) & (max_dim1 < mul_xth * 2)

        if outlier_id.any():
            rounded_values = torch.round(raw_x[outlier_id] / round_value)
            outlier_signs = sign_groups[outlier_id]
            condition = (rounded_values == limit) & (outlier_signs > 0)
            # Match dtype explicitly so ``torch.where`` never depends on
            # int64/float type promotion.
            clamp_value = torch.tensor(
                limit - 1,
                dtype=rounded_values.dtype,
                device=rounded_values.device,
            )
            rounded_values = torch.where(condition, clamp_value, rounded_values)
            raw_x[outlier_id] = rounded_values * round_value

    raw_x = raw_x.view(-1)
    x = raw_x[:org_len].view_as(x)
    return x
0 commit comments