diff --git a/src/asm/bint-x64-amd64.S b/src/asm/bint-x64-amd64.S index 53929501..b6136d32 100644 --- a/src/asm/bint-x64-amd64.S +++ b/src/asm/bint-x64-amd64.S @@ -127,6 +127,100 @@ vpcmpgtq %zmm2, %zmm1, %k1 vzeroupper ret SIZE(mcl_c5_vsubPre) +.global PRE(mcl_c5_vadd) +PRE(mcl_c5_vadd): +TYPE(mcl_c5_vadd) +mov $4503599627370495, %rax +vpbroadcastq %rax, %zmm16 +vmovdqa64 (%rsi), %zmm0 +vpaddq (%rdx), %zmm0, %zmm0 +vpsrlq $52, %zmm0, %zmm17 +vpandq %zmm16, %zmm0, %zmm0 +vmovdqa64 64(%rsi), %zmm1 +vpaddq 64(%rdx), %zmm1, %zmm1 +vpaddq %zmm17, %zmm1, %zmm1 +vpsrlq $52, %zmm1, %zmm17 +vpandq %zmm16, %zmm1, %zmm1 +vmovdqa64 128(%rsi), %zmm2 +vpaddq 128(%rdx), %zmm2, %zmm2 +vpaddq %zmm17, %zmm2, %zmm2 +vpsrlq $52, %zmm2, %zmm17 +vpandq %zmm16, %zmm2, %zmm2 +vmovdqa64 192(%rsi), %zmm3 +vpaddq 192(%rdx), %zmm3, %zmm3 +vpaddq %zmm17, %zmm3, %zmm3 +vpsrlq $52, %zmm3, %zmm17 +vpandq %zmm16, %zmm3, %zmm3 +vmovdqa64 256(%rsi), %zmm4 +vpaddq 256(%rdx), %zmm4, %zmm4 +vpaddq %zmm17, %zmm4, %zmm4 +vpsrlq $52, %zmm4, %zmm17 +vpandq %zmm16, %zmm4, %zmm4 +vmovdqa64 320(%rsi), %zmm5 +vpaddq 320(%rdx), %zmm5, %zmm5 +vpaddq %zmm17, %zmm5, %zmm5 +vpsrlq $52, %zmm5, %zmm17 +vpandq %zmm16, %zmm5, %zmm5 +vmovdqa64 384(%rsi), %zmm6 +vpaddq 384(%rdx), %zmm6, %zmm6 +vpaddq %zmm17, %zmm6, %zmm6 +vpsrlq $52, %zmm6, %zmm17 +vpandq %zmm16, %zmm6, %zmm6 +vmovdqa64 448(%rsi), %zmm7 +vpaddq 448(%rdx), %zmm7, %zmm7 +vpaddq %zmm17, %zmm7, %zmm7 +vpsubq PRE(p)(%rip){1to8}, %zmm0, %zmm8 +vpsrlq $63, %zmm8, %zmm17 +vpandq %zmm16, %zmm8, %zmm8 +vpsubq PRE(p)+8(%rip){1to8}, %zmm1, %zmm9 +vpsubq %zmm17, %zmm9, %zmm9 +vpsrlq $63, %zmm9, %zmm17 +vpandq %zmm16, %zmm9, %zmm9 +vpsubq PRE(p)+16(%rip){1to8}, %zmm2, %zmm10 +vpsubq %zmm17, %zmm10, %zmm10 +vpsrlq $63, %zmm10, %zmm17 +vpandq %zmm16, %zmm10, %zmm10 +vpsubq PRE(p)+24(%rip){1to8}, %zmm3, %zmm11 +vpsubq %zmm17, %zmm11, %zmm11 +vpsrlq $63, %zmm11, %zmm17 +vpandq %zmm16, %zmm11, %zmm11 +vpsubq PRE(p)+32(%rip){1to8}, %zmm4, %zmm12 +vpsubq %zmm17, %zmm12, %zmm12 +vpsrlq $63, %zmm12, %zmm17 +vpandq %zmm16, %zmm12, %zmm12 +vpsubq PRE(p)+40(%rip){1to8}, %zmm5, %zmm13 +vpsubq %zmm17, %zmm13, %zmm13 +vpsrlq $63, %zmm13, %zmm17 +vpandq %zmm16, %zmm13, %zmm13 +vpsubq PRE(p)+48(%rip){1to8}, %zmm6, %zmm14 +vpsubq %zmm17, %zmm14, %zmm14 +vpsrlq $63, %zmm14, %zmm17 +vpandq %zmm16, %zmm14, %zmm14 +vpsubq PRE(p)+56(%rip){1to8}, %zmm7, %zmm15 +vpsubq %zmm17, %zmm15, %zmm15 +vpsrlq $63, %zmm15, %zmm17 +vpandq %zmm16, %zmm15, %zmm15 +vpxorq %zmm16, %zmm16, %zmm16 +vpcmpgtq %zmm16, %zmm17, %k1 +vmovdqa64 %zmm0, %zmm8{%k1} +vmovdqa64 %zmm1, %zmm9{%k1} +vmovdqa64 %zmm2, %zmm10{%k1} +vmovdqa64 %zmm3, %zmm11{%k1} +vmovdqa64 %zmm4, %zmm12{%k1} +vmovdqa64 %zmm5, %zmm13{%k1} +vmovdqa64 %zmm6, %zmm14{%k1} +vmovdqa64 %zmm7, %zmm15{%k1} +vmovdqa64 %zmm8, (%rdi) +vmovdqa64 %zmm9, 64(%rdi) +vmovdqa64 %zmm10, 128(%rdi) +vmovdqa64 %zmm11, 192(%rdi) +vmovdqa64 %zmm12, 256(%rdi) +vmovdqa64 %zmm13, 320(%rdi) +vmovdqa64 %zmm14, 384(%rdi) +vmovdqa64 %zmm15, 448(%rdi) +vzeroupper +ret +SIZE(mcl_c5_vadd) .global PRE(mcl_c5_vaddPreA) PRE(mcl_c5_vaddPreA): TYPE(mcl_c5_vaddPreA) @@ -330,79 +424,80 @@ vpcmpgtq %zmm3, %zmm2, %k2 vzeroupper ret SIZE(mcl_c5_vsubPreA) -.global PRE(mcl_c5_vadd) -PRE(mcl_c5_vadd): -TYPE(mcl_c5_vadd) +.global PRE(mcl_c5_vaddA) +PRE(mcl_c5_vaddA): +TYPE(mcl_c5_vaddA) mov $4503599627370495, %rax vpbroadcastq %rax, %zmm16 +mov $2, %ecx +.L1: vmovdqa64 (%rsi), %zmm0 -vmovdqa64 64(%rsi), %zmm1 -vmovdqa64 128(%rsi), %zmm2 -vmovdqa64 192(%rsi), %zmm3 -vmovdqa64 256(%rsi), %zmm4 -vmovdqa64 320(%rsi), %zmm5 -vmovdqa64 384(%rsi), %zmm6 -vmovdqa64 448(%rsi), %zmm7 vpaddq (%rdx), %zmm0, %zmm0 -vpaddq 64(%rdx), %zmm1, %zmm1 -vpaddq 128(%rdx), %zmm2, %zmm2 -vpaddq 192(%rdx), %zmm3, %zmm3 -vpaddq 256(%rdx), %zmm4, %zmm4 -vpaddq 320(%rdx), %zmm5, %zmm5 -vpaddq 384(%rdx), %zmm6, %zmm6 -vpaddq 448(%rdx), %zmm7, %zmm7 vpsrlq $52, %zmm0, %zmm17 +vpandq %zmm16, %zmm0, %zmm0 +vmovdqa64 64(%rsi), %zmm1 +vpaddq 64(%rdx), %zmm1, %zmm1 vpaddq %zmm17, %zmm1, %zmm1 vpsrlq $52, %zmm1, %zmm17 +vpandq %zmm16, %zmm1, %zmm1 +vmovdqa64 128(%rsi), %zmm2 +vpaddq 128(%rdx), %zmm2, %zmm2 vpaddq %zmm17, %zmm2, %zmm2 vpsrlq $52, %zmm2, %zmm17 +vpandq %zmm16, %zmm2, %zmm2 +vmovdqa64 192(%rsi), %zmm3 +vpaddq 192(%rdx), %zmm3, %zmm3 vpaddq %zmm17, %zmm3, %zmm3 vpsrlq $52, %zmm3, %zmm17 +vpandq %zmm16, %zmm3, %zmm3 +vmovdqa64 256(%rsi), %zmm4 +vpaddq 256(%rdx), %zmm4, %zmm4 vpaddq %zmm17, %zmm4, %zmm4 vpsrlq $52, %zmm4, %zmm17 +vpandq %zmm16, %zmm4, %zmm4 +vmovdqa64 320(%rsi), %zmm5 +vpaddq 320(%rdx), %zmm5, %zmm5 vpaddq %zmm17, %zmm5, %zmm5 vpsrlq $52, %zmm5, %zmm17 +vpandq %zmm16, %zmm5, %zmm5 +vmovdqa64 384(%rsi), %zmm6 +vpaddq 384(%rdx), %zmm6, %zmm6 vpaddq %zmm17, %zmm6, %zmm6 vpsrlq $52, %zmm6, %zmm17 -vpaddq %zmm17, %zmm7, %zmm7 -vpandq %zmm16, %zmm0, %zmm0 -vpandq %zmm16, %zmm1, %zmm1 -vpandq %zmm16, %zmm2, %zmm2 -vpandq %zmm16, %zmm3, %zmm3 -vpandq %zmm16, %zmm4, %zmm4 -vpandq %zmm16, %zmm5, %zmm5 vpandq %zmm16, %zmm6, %zmm6 -vpandq %zmm16, %zmm7, %zmm7 +vmovdqa64 448(%rsi), %zmm7 +vpaddq 448(%rdx), %zmm7, %zmm7 +vpaddq %zmm17, %zmm7, %zmm7 vpsubq PRE(p)(%rip){1to8}, %zmm0, %zmm8 -vpsubq PRE(p)+8(%rip){1to8}, %zmm1, %zmm9 -vpsubq PRE(p)+16(%rip){1to8}, %zmm2, %zmm10 -vpsubq PRE(p)+24(%rip){1to8}, %zmm3, %zmm11 -vpsubq PRE(p)+32(%rip){1to8}, %zmm4, %zmm12 -vpsubq PRE(p)+40(%rip){1to8}, %zmm5, %zmm13 -vpsubq PRE(p)+48(%rip){1to8}, %zmm6, %zmm14 -vpsubq PRE(p)+56(%rip){1to8}, %zmm7, %zmm15 vpsrlq $63, %zmm8, %zmm17 +vpandq %zmm16, %zmm8, %zmm8 +vpsubq PRE(p)+8(%rip){1to8}, %zmm1, %zmm9 vpsubq %zmm17, %zmm9, %zmm9 vpsrlq $63, %zmm9, %zmm17 +vpandq %zmm16, %zmm9, %zmm9 +vpsubq PRE(p)+16(%rip){1to8}, %zmm2, %zmm10 vpsubq %zmm17, %zmm10, %zmm10 vpsrlq $63, %zmm10, %zmm17 +vpandq %zmm16, %zmm10, %zmm10 +vpsubq PRE(p)+24(%rip){1to8}, %zmm3, %zmm11 vpsubq %zmm17, %zmm11, %zmm11 vpsrlq $63, %zmm11, %zmm17 +vpandq %zmm16, %zmm11, %zmm11 +vpsubq PRE(p)+32(%rip){1to8}, %zmm4, %zmm12 vpsubq %zmm17, %zmm12, %zmm12 vpsrlq $63, %zmm12, %zmm17 +vpandq %zmm16, %zmm12, %zmm12 +vpsubq PRE(p)+40(%rip){1to8}, %zmm5, %zmm13 vpsubq %zmm17, %zmm13, %zmm13 vpsrlq $63, %zmm13, %zmm17 +vpandq %zmm16, %zmm13, %zmm13 +vpsubq PRE(p)+48(%rip){1to8}, %zmm6, %zmm14 vpsubq %zmm17, %zmm14, %zmm14 vpsrlq $63, %zmm14, %zmm17 +vpandq %zmm16, %zmm14, %zmm14 +vpsubq PRE(p)+56(%rip){1to8}, %zmm7, %zmm15 vpsubq %zmm17, %zmm15, %zmm15 vpsrlq $63, %zmm15, %zmm17 -vpandq %zmm16, %zmm8, %zmm8 -vpandq %zmm16, %zmm9, %zmm9 -vpandq %zmm16, %zmm10, %zmm10 -vpandq %zmm16, %zmm11, %zmm11 -vpandq %zmm16, %zmm12, %zmm12 -vpandq %zmm16, %zmm13, %zmm13 -vpandq %zmm16, %zmm14, %zmm14 vpandq %zmm16, %zmm15, %zmm15 vpxorq %zmm16, %zmm16, %zmm16 vpcmpgtq %zmm16, %zmm17, %k1 @@ -422,9 +517,14 @@ vmovdqa64 %zmm12, 256(%rdi) vmovdqa64 %zmm13, 320(%rdi) vmovdqa64 %zmm14, 384(%rdi) vmovdqa64 %zmm15, 448(%rdi) +add $64, %rsi +add $64, %rdx +add $64, %rdi +sub $1, %ecx +jnz .L1 vzeroupper ret -SIZE(mcl_c5_vadd) +SIZE(mcl_c5_vaddA) .align 16 .global PRE(mclb_add1) PRE(mclb_add1): diff --git a/src/gen_bint_x64.py b/src/gen_bint_x64.py index f456ba86..9b20f8f0 100644 --- a/src/gen_bint_x64.py +++ b/src/gen_bint_x64.py @@ -106,9 +106,10 @@ def gen_vsubPre(mont, vN=1): vpxorq(t[0], t[0], t[0]) un(vpcmpgtq)([k1, k2], c, t[0]) -def gen_vadd(mont): - with FuncProc(MSM_PRE+'vadd'): - with StackFrame(3, 0, vNum=mont.N*2+2, vType=T_ZMM) as sf: +def gen_vadd(mont, vN=1): + SUF = 'A' if vN == 2 else '' + with FuncProc(MSM_PRE+'vadd'+SUF): + with StackFrame(3, 0, useRCX=True, vNum=mont.N*2+2, vType=T_ZMM) as sf: regs = list(reversed(sf.v)) W = mont.W N = mont.N @@ -130,6 +131,11 @@ def gen_vadd(mont): un = genUnrollFunc() + if vN == 2: + mov(ecx, 2) + lpL = Label() + L(lpL) + if False: unb = genUnrollFunc(addrOffset=8) un(vmovdqa64)(s, ptr(x)) @@ -166,12 +172,18 @@ def gen_vadd(mont): vpxorq(vmask, vmask, vmask) vpcmpgtq(k1, c, vmask) # k1 = t<0 - # z = select(k1, s, t) for i in range(N): vmovdqa64(t[i]|k1, s[i]) un(vmovdqa64)(ptr(z), t) + if vN == 2: + add(x, 64) + add(y, 64) + add(z, 64) + sub(ecx, 1) + jnz(lpL) + def msm_data(mont): makeLabel(C_p) dq_(', '.join(map(hex, mont.toArray(mont.p)))) @@ -180,7 +192,7 @@ def msm_code(mont): for vN in [1, 2]: gen_vaddPre(mont, vN) gen_vsubPre(mont, vN) - gen_vadd(mont) + gen_vadd(mont, vN) SUF='_fast' param=None diff --git a/src/msm_avx.cpp b/src/msm_avx.cpp index b3d686b4..1013dccf 100644 --- a/src/msm_avx.cpp +++ b/src/msm_avx.cpp @@ -30,6 +30,7 @@ void mcl_c5_vsubPre(Vec *, const Vec *, const Vec *); void mcl_c5_vsubPreA(VecA *, const VecA *, const VecA *); void mcl_c5_vadd(Vec *, const Vec *, const Vec *); +void mcl_c5_vaddA(VecA *, const VecA *, const VecA *); } @@ -166,6 +167,13 @@ inline void vadd(Vec *z, const Vec *x, const Vec *y) { mcl_c5_vadd(z, x, y); } +#if 0 +template<> +inline void vadd(VecA *z, const VecA *x, const VecA *y) +{ + mcl_c5_vaddA(z, x, y); +} +#endif #endif template @@ -1693,6 +1701,7 @@ CYBOZU_TEST_AUTO(vaddPre) CYBOZU_BENCH_C("asm vaddPreA", C, mcl_c5_vaddPreA, za.v, za.v, xa.v); CYBOZU_BENCH_C("asm vsubPreA", C, mcl_c5_vsubPreA, za.v, za.v, xa.v); CYBOZU_BENCH_C("asm vadd", C, mcl_c5_vadd, z[0].v, z[0].v, x[0].v); + CYBOZU_BENCH_C("asm vaddA", C, mcl_c5_vaddA, za.v, za.v, xa.v); #endif CYBOZU_BENCH_C("vadd::Vec", C, vadd, z[0].v, z[0].v, x[0].v); CYBOZU_BENCH_C("vsub::Vec", C, vsub, z[0].v, z[0].v, x[0].v);