1
- # paddle.pdist设计文档
1
+ # paddle.bitwise_left_shift/paddle.bitwise_right_shift设计文档
2
2
3
3
| API 名称 | paddle.bitwise_right_shift<br />paddle.bitwise_left_shift |
4
4
| ------------ | --------------------------------------------------------- |
@@ -53,14 +53,104 @@ shift_right_arithmetic = _make_elementwise_binary_prim(
53
53
shift_right_logical = _not_impl # 可见pytorch中仅支持算数位移
54
54
```
55
55
56
- 具体元素尺度的实现,[ 代码位置 ] ( https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/common.py#L401-L405 ) :
56
+ 具体元素尺度的实现,
57
57
58
- ``` python
59
- @ staticmethod
60
- def bitwise_right_shift (a , b ):
61
- return f " decltype( { a} )( { a} >> { b} ) "
58
+ [ 左移 cpu kernel] ( https://github.com/pytorch/pytorch/blob/3747aca49a39479c2c5e223b91369db5bd339cdf/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L423-L437 ) :
59
+
60
+ ``` cpp
61
+ void lshift_kernel (TensorIteratorBase& iter) {
62
+ AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [ &] ( ) {
63
+ cpu_kernel_vec(
64
+ iter,
65
+ [ ] (scalar_t a, scalar_t b) -> scalar_t {
66
+ constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
67
+ if ((static_cast< std::make_signed_t<scalar_t > >(b) < 0) ||
68
+ (b >= max_shift)) {
69
+ return 0;
70
+ }
71
+ return static_cast< std::make_unsigned_t<scalar_t > >(a) << b;
72
+ },
73
+ [ ] (Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a << b; });
74
+ });
75
+ }
76
+ ```
77
+
78
+ [左移 cuda kernel](https://github.com/pytorch/pytorch/blob/6e1ba79b7fdf3d66db8fb69462fb502e5006e5e7/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu#L14-L25)
79
+
80
+ ```cpp
81
+ void lshift_kernel_cuda(TensorIteratorBase& iter) {
82
+ AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() {
83
+ gpu_kernel_with_scalars(iter,
84
+ []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
85
+ constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
86
+ if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
87
+ return 0;
88
+ }
89
+ return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
90
+ });
91
+ });
92
+ }
62
93
```
63
94
95
+ + 可以发现,在算术左移时,kernel需要针对[ 两类情况进行处理] ( https://wiki.sei.cmu.edu/confluence/display/c/INT34-C.+Do+not+shift+an+expression+by+a+negative+number+of+bits+or+by+greater+than+or+equal+to+the+number+of+bits+that+exist+in+the+operand ) :
96
+
97
+ + ` b ` 移动的距离大于等于当前类型的位数时(例如对int16左移16位),则直接返回0(若进行移动,编译器会在此时发生取模优化,例如左移1000位时,实际上会移动1000%16=8位,但实际上需要返回0,表示溢出)
98
+ + ` b ` 为负数时,在C语言标准中为"未定义行为",认为等效于左移了无穷位,直接返回0;
99
+
100
+ 另外,kernel中用` std::make_signed_t<scalar_t>>(b) ` 把` b ` 强转为有符号数,若` b ` 原本就是有符号数,无影响;若` b ` 原本是无符号数,且最高位为0,无影响;若` b ` 原本是无符号数,而且较大,最高位为` 1 ` ,强转后为负数,小于0。(不过感觉即使不强转,最高位为1的无符号数应该也会令` (b >= max_shift) ` 为true)
101
+
102
+
103
+
104
+ [ 右移 cpu kernel] ( https://github.com/pytorch/pytorch/blob/3747aca49a39479c2c5e223b91369db5bd339cdf/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L494-L511 )
105
+
106
+ ``` cpp
107
+ void rshift_kernel (TensorIteratorBase& iter) {
108
+ AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [ &] ( ) {
109
+ cpu_kernel_vec(
110
+ iter,
111
+ [ ] (scalar_t a, scalar_t b) -> scalar_t {
112
+ // right shift value to retain sign bit for signed and no bits for
113
+ // unsigned
114
+ constexpr scalar_t max_shift =
115
+ sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
116
+ if ((static_cast< std::make_signed_t<scalar_t > >(b) < 0) ||
117
+ (b >= max_shift)) {
118
+ return a >> max_shift;
119
+ }
120
+ return a >> b;
121
+ },
122
+ [ ] (Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a >> b; });
123
+ });
124
+ }
125
+ ```
126
+
127
+ [右移 cuda kernel](https://github.com/pytorch/pytorch/blob/6e1ba79b7fdf3d66db8fb69462fb502e5006e5e7/aten/src/ATen/native/cuda/BinaryShiftOpsKernels.cu#L27-L39C2)
128
+
129
+ ```cpp
130
+ void rshift_kernel_cuda(TensorIteratorBase& iter) {
131
+ AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() {
132
+ gpu_kernel_with_scalars(iter,
133
+ []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
134
+ // right shift value to retain sign bit for signed and no bits for unsigned
135
+ constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
136
+ if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
137
+ return a >> max_shift;
138
+ }
139
+ return a >> b;
140
+ });
141
+ });
142
+ }
143
+ ```
144
+
145
+ + 算术右移时,` max_shift ` 需要考虑最大的移动距离,有符号数最高位为符号位,故表示数值的位数实际上会少一位。
146
+ + 有符号数时,例如` int8 x=-100 ` ,补码为` 1001,1100 ` ,最高位为符号位,仅需要右移7位,所有的` int8 ` 就都会变成` 1111,1111 ` ,即` -1 ` ;
147
+ + 无符号数时候,例如` uint8 x=200 ` ,存储为` 1100,1000 ` ,八位均表示数值大小,需要右移8位才可以将所有的` uint8 ` 变为` 0000,0000 ` ,即` 0 ` ;
148
+ + 当` b ` 为负数这一未定义行为时,同样等效于右移无穷位,与移动` max_shift ` 等效(从代码中可以看到,` b<0 ` 和` b>=max_shift ` 是在同一个` if ` 判断中),只要满足两个条件中任意一个,则使得有符号数变为` -1 ` ,无符号数变为` 0 ` 。
149
+
150
+ ** 在paddle API的设计过程中,也按照这样的方式来实现,当` b ` 为负数或者移动超过最大值,则使得有符号数变为` -1 ` ,无符号数变为` 0 ` **
151
+
152
+
153
+
64
154
65
155
66
156
## Numpy
@@ -125,6 +215,27 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_right_shift)
125
215
}
126
216
```
127
217
218
+ `npy_lshift`[相关调用](https://github.com/numpy/numpy/blob/0032ede015c9b06f88cc7f9b07138ce35f4357ae/numpy/_core/src/npymath/npy_math_internal.h.src#L653-L662):
219
+
220
+ ```cpp
221
+ NPY_INPLACE npy_@u@@type@
222
+ npy_lshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
223
+ {
224
+ if (NPY_LIKELY((size_t)b < sizeof(a) * CHAR_BIT)) {
225
+ return a << b;
226
+ }
227
+ else {
228
+ return 0;
229
+ }
230
+ }
231
+ ```
232
+
233
+ + 在左移时,为了防止编译器对位移的自动取模优化(例如int16类型左移100位,实际上被自动优化成左移` 100%16=4 ` 位),导致结果不为0(溢出);
234
+
235
+ 而且这里将` b ` 转为` size_t ` ,而` size_t ` 是unsigned类型,所以当` b ` 为有符号负数时,由于补码最高位的符号位为1,所以会被转换成一个很大的正数,必然超过` sizeof(a) * CHAR_BIT ` 的大小,所以直接走else返回0,这里应该与` b < 0 ` 实现了同样的效果。
236
+
237
+
238
+
128
239
` npy_rshift ` 相关调用
129
240
130
241
``` cpp
@@ -145,6 +256,14 @@ npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
145
256
}
146
257
```
147
258
259
+ + 在右移时,右移的最大位数限制需要区分有符号数和无符号数:
260
+
261
+ ** 此处实现与pytorch中的实现略有不同,不过结果还是等效的:pytorch中认为,有符号数最大右移位数为` n_bit-1 ` ,而无符号数最大右移位数为` n_bit ` ,例如(int16最多右移15位,uint16最多右移16位,否则触发溢出,全置为符号位);numpy中没有刻意限定符号数和无符号数的最大位移位数(例如int16和uint16的最大位移位数都是16位,都是16位才出发溢出),由于对于有符号数例如int16来说,“(pytorch)右移15位触发溢出,全部置为符号位”与“(numpy)右移15位”,两者结果是一样的,只是前者直接走溢出的else,后者真正去做了位运算而已,所以还是等效**
262
+
263
+
264
+
265
+ 这里的` NPY_LIKELY((size_t)b ` 与左移一样,隐含了` b ` 需要大于0。若` b ` 小于0,则转unsigned之后大小必然大于` sizeof(a) * CHAR_BIT ` 溢出,而后又根据` a ` 的符号位作为返回(负数溢出补码为` 1111,1111,...1111 ` ,也就是-1,正数和无符号数溢出为0)。
266
+
148
267
149
268
150
269
## Jax
@@ -157,8 +276,8 @@ npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
157
276
158
277
Parameters
159
278
160
- - ** x** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
161
- - ** y** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
279
+ - ** x** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
280
+ - ** y** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
162
281
163
282
Return type[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array )
164
283
@@ -190,8 +309,8 @@ def _shift_right_arithmetic_raw(x, y):
190
309
191
310
Elementwise logical right shift: x ≫ y.Parameters
192
311
193
- - ** x** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
194
- - ** y** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
312
+ - ** x** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
313
+ - ** y** ([ ` Union ` ] ( https://docs.python.org/3/library/typing.html#typing.Union ) [[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array ) , [ ` ndarray ` ] ( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray ) , [ ` bool_ ` ] ( https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bool_ ) , [ ` number ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.number.html#jax.numpy.number ) , [ ` bool ` ] ( https://docs.python.org/3/library/functions.html#bool ) , [ ` int ` ] ( https://docs.python.org/3/library/functions.html#int ) , [ ` float ` ] ( https://docs.python.org/3/library/functions.html#float ) , [ ` complex ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.complex.html#jax.lax.complex )] )
195
314
196
315
Return type[ ` Array ` ] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array )
197
316
@@ -225,7 +344,9 @@ PyTorch是将算子注册到element wise系列中,Numpy也类似地`BINARY_LOO
225
344
226
345
同时,PyTorch与Numpy中都仅支持算术位移,不支持逻辑位移,而JAX中实现了算术位移和逻辑位移。
227
346
347
+ PyTorch和numpy的处理基本一致,在前面Numpy的调研部分详细说明了两者略微的差异,这个差异不影响最终结果,两者处理的思路都是一致的。
228
348
349
+ 面对第二个参数` b ` 为负数的时候,都是将其等效为位移无穷大的距离(这两个判断条件在同一个` if ` 中,用“或”逻辑连接),处理方式都是使有符号数时变为` -1 ` ,无符号数时变为` 0 ` 。
229
350
230
351
# 五、设计思路与实现方案
231
352
@@ -235,7 +356,84 @@ API的设计为`paddle.bitwise_right_shift(x, y, is_arithmetic=True)`,其余
235
356
236
357
## API实现方案
237
358
238
- 参考` PyTorch ` 、` Numpy ` 、` JAX ` 中的设计,组合已有API实现功能
359
+
360
+ 由于python层相关API的类型支持需求不合理(例如jax中的设计,unsigned转signed会溢出),考虑下沉到cpp层实现。以右移为例,python层接口为` paddle.bitwise_right_shift ` ,通过参数` is_arithmetic ` 的设置来调用算术位移或逻辑位移的kernel,若为算术位移,则调用` _C_ops.bitwise_left_shift_arithmetic_(x, y) ` ,若为逻辑位移,则调用` _C_ops.bitwise_left_shift_logic_(x, y) `
361
+
362
+ cpp的kernel实现主要通过elementwise的方法,与` bitwise_and ` 等bitwise op设计类似,复用elementwise相关代码以支持broadcast、具体Functor的调用等。
363
+
364
+
365
+
366
+ 具体行为定义:(` n_bits ` 表示数据类型存储位数,例如int8的` n_bits ` 为8,uint16的` n_bits ` 为16;当` y ` 小于0时为“未定义行为”,等效于位移超过最大位数溢出)
367
+
368
+ + 算术位移
369
+
370
+ + 算术左移:当` y ` 小于0,或者` y ` 大于等于` n_bits ` 时候溢出,返回0;否则正常位移;
371
+ + 算术右移:
372
+ + 有符号数时:当` y ` 小于0,或者` y ` 大于等于` n_bits ` 时候溢出,返回符号位(` a>>(n_bits-1)&1 ` );否则正常位移;
373
+ + 无符号数时:当` y ` 小于0,或者` y ` 大于等于` n_bits ` 时候溢出,返回0;否则正常位移;
374
+
375
+ + 逻辑位移
376
+
377
+ + 逻辑左移:当` y ` 小于0,或者` y ` 大于等于` n_bits ` 时候溢出,返回0;否则正常位移;
378
+
379
+ + 逻辑右移:
380
+
381
+ + 有符号数时:当` y ` 小于0,或者` y ` 大于等于` n_bits ` 时候溢出,返回0;否则特殊位移:
382
+
383
+ ``` cpp
384
+ template <typename T>
385
+ HOSTDEVICE T logic_shift_func (const T a, const T b) {
386
+ if (b < static_cast<T >(0) || b >= static_cast<T >(sizeof(T) * 8))
387
+ return static_cast<T >(0);
388
+ T t = static_cast<T >(sizeof(T) * 8 - 1);
389
+ T mask = (((a >> t) << t) >> b) << 1;
390
+ return (a >> b) ^ mask;
391
+ }
392
+ ```
393
+
394
+ 在`T mask = (((a >> t) << t) >> b) << 1;`中,先`(a >> t)`取符号位,然后`<< t`回到原位,再右移`b`后左移一位,最后与`a>>b`的结果做亦或,下面举两个例子:
395
+
396
+ ```
397
+ example1:
398
+ a = 1001,1010 b = 3, 有t=7
399
+ ((a>>t)<<t) = 1000,0000
400
+ mask=(((a>>t)<<t)>>b)<<1 = 1110,0000
401
+ a>>b = 1111,0011
402
+ 所以 (a>>b) ^ mask = 0001,0011
403
+
404
+ example2:
405
+ a = 0001,1010 b = 3, 有t=7
406
+ ((a>>t)<<t) = 0000,0000
407
+ mask=(((a>>t)<<t)>>b)<<1 = 0000,0000
408
+ a>>b = 0000,0011
409
+ 所以 (a>>b) ^ mask = 0000,0011
410
+ ```
411
+
412
+ + 无符号数时:当`y`小于0,或者`y`大于等于`n_bits`时候溢出,返回0;否则正常位移;
413
+
414
+ 以上行为中,算术位移与numpy、pytorch的实现对齐;由于numpy和pytorch不支持逻辑位移,所以逻辑位移参考jax的实现思路,用numpy来进行间接实现和验证。
415
+
416
+
417
+ + 关于** 有符号数的符号位** 在不同情景下的行为:
418
+ 1 . 算术左移时,符号位同其他位一样,一起左移,右边补0;
419
+ 2 . 逻辑左移时,符号位同其他位一样,一起左移,右边补0;
420
+ 3 . 算术右移时,符号位同其他位一样,一起右移,左边补符号位;
421
+ 4 . 逻辑右移时,符号位同其他位一样,一起右移,左边补0;
422
+
423
+ 注意:当有符号数左移发生溢出时,其值不可控,可能会在左移时突然变号,这是因为在左移时,有符号数的符号位同样进行左移,会导致符号位右侧的值不断成为符号位,例如
424
+ ```
425
+ example1:
426
+ int8_t x = -45; // 补码为 1101,0011 表示-45
427
+ int8_t y = x << 2; //补码为 0100,1100 表示76
428
+ int8_t z = x << 3; //补码为 1001,1000 表示-104
429
+
430
+ example2:
431
+ int8_t x = -86; // 补码为 1010,1010 表示-86
432
+ int8_t y = x << 1; //补码为 0101,0100 表示84
433
+ int8_t z = x << 2; //补码为 1010,1000 表示-88
434
+ ```
435
+ 以上为溢出导致的符号突变。
436
+
239
437
240
438
# 六、测试和验收的考量
241
439
@@ -261,4 +459,4 @@ API的设计为`paddle.bitwise_right_shift(x, y, is_arithmetic=True)`,其余
261
459
262
460
[ PyTorch文档] ( https://pytorch.org/docs/stable/generated/torch.bitwise_right_shift.html?highlight=bitwise_right_shift#torch.bitwise_right_shift )
263
461
264
- [ Numpy文档] ( https://numpy.org/doc/stable/reference/generated/numpy.right_shift.html#numpy.right_shift )
462
+ [ Numpy文档] ( https://numpy.org/doc/stable/reference/generated/numpy.right_shift.html#numpy.right_shift )
0 commit comments