Skip to content

Commit 3ac026a

Browse files
authored
Update qrazor.py
1 parent cdddde9 commit 3ac026a

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

lib/qrazor/qrazor.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,27 @@
88

99
class QRazor(Function):
1010
@staticmethod
11-
def forward(ctx, x, q_bit, r_bit, group):
11+
def forward(ctx, x, sign, q_bit, r_bit, group):
1212
raw_x = torch.reshape(x, (-1,))
1313
org_len = len(raw_x)
1414
if org_len % group:
1515
vacant_num = group - org_len % group
1616
raw_x = F.pad(raw_x, (0, vacant_num), 'constant', 0)
1717
raw_x = raw_x.view(-1, group)
1818
max_dim1, _ = raw_x.max(dim=1)
19-
19+
2020
for b in range(r_bit, q_bit+1):
2121
mul_xth = 2 ** (b - 1)
2222
round_value = 2 ** (b + 1 - r_bit)
2323
outlier_id = (max_dim1 >= mul_xth) & (max_dim1 < mul_xth * 2)
24-
cond2 = max_dim1 >= (2 * mul_xth - 2 ** (b - 4))
25-
24+
2625
if outlier_id.any():
27-
threshold = math.floor(2 * mul_xth - 2 ** (b - 4))
28-
indices_both = outlier_id & cond2
29-
indices_round = outlier_id & (~cond2)
30-
31-
if indices_both.any():
32-
selected_raw_x = raw_x[indices_both]
33-
element_floor = selected_raw_x >= threshold
34-
element_round = selected_raw_x < threshold
35-
selected_raw_x[element_floor] = torch.floor(selected_raw_x[element_floor] / round_value) * round_value
36-
selected_raw_x[element_round] = torch.round(selected_raw_x[element_round] / round_value) * round_value
37-
raw_x[indices_both] = selected_raw_x
38-
39-
if indices_round.any():
40-
raw_x[indices_round] = torch.round(raw_x[indices_round] / round_value) * round_value
26+
rounded_values = torch.round(raw_x[outlier_id] / round_value)
27+
outlier_signs = sign.view(-1)[:len(raw_x.view(-1))].view(-1, group)[outlier_id]
28+
condition = (rounded_values == (2**(r_bit-1))) & (outlier_signs > 0)
29+
rounded_values = torch.where(condition, torch.tensor(2**(r_bit-1)-1, device=rounded_values.device), rounded_values)
30+
raw_x[outlier_id] = rounded_values * round_value
31+
4132
raw_x = raw_x.view(-1)
4233
x = raw_x[:org_len].view_as(x)
43-
44-
return x
34+
return x

0 commit comments

Comments
 (0)