From 516f5c4b928a69da85e1cc404c6142d90395f000 Mon Sep 17 00:00:00 2001 From: Jeremi Kurdek <59935235+jkurdek@users.noreply.github.com> Date: Fri, 15 Mar 2024 12:58:57 +0100 Subject: [PATCH] [Mono] [Arm64] Added SIMD support for vector 2/3/4 methods (#98761) * [Mono] [Arm64] Added multiple vector intrinsics * Added LLVM support * fix build errors on x64 * Added Quaternion.Conjugate * Changed frsqrts codegen * fixed whitespace changes * Removed reciprocal sqrt estimation from normalize * Extracted dot method into separate function * Refactored code to use exposed dot function * Fixed x64 build error * Fixed more build errors * Removed intrinsics for methods not intrinsified on coreclr side * Cleaned up code * Replaced break with return null * Removed trailing whitespaces --- src/mono/mono/mini/simd-intrinsics.c | 292 ++++++++++++++++++--------- 1 file changed, 199 insertions(+), 93 deletions(-) diff --git a/src/mono/mono/mini/simd-intrinsics.c b/src/mono/mono/mini/simd-intrinsics.c index 2a218c7d4a365..e003d1247892c 100644 --- a/src/mono/mono/mini/simd-intrinsics.c +++ b/src/mono/mono/mini/simd-intrinsics.c @@ -632,6 +632,7 @@ emit_xconst_v128 (MonoCompile *cfg, MonoClass *klass, guint8 value[16]) ins->type = STACK_VTYPE; ins->dreg = alloc_xreg (cfg); ins->inst_p0 = mono_mem_manager_alloc (cfg->mem_manager, size); + ins->klass = klass; MONO_ADD_INS (cfg->cbb, ins); memcpy (ins->inst_p0, &value[0], size); @@ -1390,6 +1391,76 @@ emit_msb_shift_vector_constant (MonoCompile *cfg, MonoClass *arg_class, MonoType } #endif +static MonoInst* +emit_dot (MonoCompile *cfg, MonoClass *klass, MonoType *vector_type, MonoTypeEnum arg0_type, int sreg1, int sreg2) { + if (!is_element_type_primitive (vector_type)) + return NULL; +#if defined(TARGET_WASM) + if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8)) + return NULL; +#elif defined(TARGET_ARM64) + if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8 || arg0_type == 
MONO_TYPE_I || arg0_type == MONO_TYPE_U)) + return NULL; +#endif + +#if defined(TARGET_ARM64) || defined(TARGET_WASM) + MonoInst *pairwise_multiply = emit_simd_ins (cfg, klass, OP_XBINOP, sreg1, sreg2); + pairwise_multiply->inst_c0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL; + pairwise_multiply->inst_c1 = arg0_type; + return emit_sum_vector (cfg, vector_type, arg0_type, pairwise_multiply); +#elif defined(TARGET_AMD64) + int instc =-1; + if (type_enum_is_float (arg0_type)) { + if (is_SIMD_feature_supported (cfg, MONO_CPU_X86_SSE41)) { + int mask_val = -1; + switch (arg0_type) { + case MONO_TYPE_R4: + instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPS : OP_SSE41_DPPS_IMM; + mask_val = 0xf1; // 0xf1 ... 0b11110001 + break; + case MONO_TYPE_R8: + instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPD : OP_SSE41_DPPD_IMM; + mask_val = 0x31; // 0x31 ... 0b00110001 + break; + default: + return NULL; + } + + MonoInst *dot; + if (COMPILE_LLVM (cfg)) { + int mask_reg = alloc_ireg (cfg); + MONO_EMIT_NEW_ICONST (cfg, mask_reg, mask_val); + + dot = emit_simd_ins (cfg, klass, instc, sreg1, sreg2); + dot->sreg3 = mask_reg; + } else { + dot = emit_simd_ins (cfg, klass, instc, sreg1, sreg2); + dot->inst_c0 = mask_val; + } + return extract_first_element (cfg, klass, arg0_type, dot->dreg); + } else { + instc = OP_FMUL; + } + } else { + if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1) + return NULL; // We don't support sum vector for byte, sbyte types yet + + // FIXME: + if (!COMPILE_LLVM (cfg)) + return NULL; + + instc = OP_IMUL; + } + MonoInst *pairwise_multiply = emit_simd_ins (cfg, klass, OP_XBINOP, sreg1, sreg2); + pairwise_multiply->inst_c0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL; + pairwise_multiply->inst_c1 = arg0_type; + + return emit_sum_vector (cfg, vector_type, arg0_type, pairwise_multiply); +#else + return NULL; +#endif +} + /* * Emit intrinsics in System.Numerics.Vector and System.Runtime.Intrinsics.Vector64/128/256/512. 
* If the intrinsic is not supported for some reasons, return NULL, and fall back to the c# @@ -1768,70 +1839,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi } } case SN_Dot: { - if (!is_element_type_primitive (fsig->params [0])) - return NULL; -#if defined(TARGET_WASM) - if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8)) - return NULL; -#elif defined(TARGET_ARM64) - if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8 || arg0_type == MONO_TYPE_I || arg0_type == MONO_TYPE_U)) - return NULL; -#endif - -#if defined(TARGET_ARM64) || defined(TARGET_WASM) - int instc0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL; - MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc0, arg0_type, fsig, args); - return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply); -#elif defined(TARGET_AMD64) - int instc =-1; - if (type_enum_is_float (arg0_type)) { - if (is_SIMD_feature_supported (cfg, MONO_CPU_X86_SSE41)) { - int mask_val = -1; - switch (arg0_type) { - case MONO_TYPE_R4: - instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPS : OP_SSE41_DPPS_IMM; - mask_val = 0xf1; // 0xf1 ... 0b11110001 - break; - case MONO_TYPE_R8: - instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPD : OP_SSE41_DPPD_IMM; - mask_val = 0x31; // 0x31 ... 
0b00110001 - break; - default: - return NULL; - } - - MonoInst *dot; - if (COMPILE_LLVM (cfg)) { - int mask_reg = alloc_ireg (cfg); - MONO_EMIT_NEW_ICONST (cfg, mask_reg, mask_val); - - dot = emit_simd_ins (cfg, klass, instc, args [0]->dreg, args [1]->dreg); - dot->sreg3 = mask_reg; - } else { - dot = emit_simd_ins (cfg, klass, instc, args [0]->dreg, args [1]->dreg); - dot->inst_c0 = mask_val; - } - - return extract_first_element (cfg, klass, arg0_type, dot->dreg); - } else { - instc = OP_FMUL; - } - } else { - if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1) - return NULL; // We don't support sum vector for byte, sbyte types yet - - // FIXME: - if (!COMPILE_LLVM (cfg)) - return NULL; - - instc = OP_IMUL; - } - MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc, arg0_type, fsig, args); - - return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply); -#else - return NULL; -#endif + return emit_dot (cfg, klass, fsig->params [0], arg0_type, args [0]->dreg, args [1]->dreg); } case SN_Equals: case SN_EqualsAll: @@ -2910,6 +2918,8 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f value [1] = 1.0f; value [2] = 1.0f; value [3] = 1.0f; + if (len == 3) + value [3] = 0.0f; return emit_xconst_v128 (cfg, klass, (guint8*)value); } case SN_set_Item: { @@ -2988,28 +2998,7 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, MONO_TYPE_R4, id); } case SN_Dot: { -#if defined(TARGET_ARM64) || defined(TARGET_WASM) - MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FMUL, MONO_TYPE_R4, fsig, args); - return emit_sum_vector (cfg, fsig->params [0], MONO_TYPE_R4, pairwise_multiply); -#elif defined(TARGET_AMD64) - if (!(mini_get_cpu_features (cfg) & MONO_CPU_X86_SSE41)) - return NULL; - - int mask_reg = alloc_ireg (cfg); - MONO_EMIT_NEW_ICONST (cfg, mask_reg, 0xf1); - MonoInst *dot 
= emit_simd_ins (cfg, klass, OP_SSE41_DPPS, args [0]->dreg, args [1]->dreg); - dot->sreg3 = mask_reg; - - MONO_INST_NEW (cfg, ins, OP_EXTRACT_R4); - ins->dreg = alloc_freg (cfg); - ins->sreg1 = dot->dreg; - ins->inst_c0 = 0; - ins->inst_c1 = MONO_TYPE_R4; - MONO_ADD_INS (cfg->cbb, ins); - return ins; -#else - return NULL; -#endif + return emit_dot (cfg, klass, fsig->params [0], MONO_TYPE_R4, args [0]->dreg, args [1]->dreg); } case SN_Negate: case SN_op_UnaryNegation: { @@ -3061,7 +3050,6 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f #endif } case SN_CopyTo: - // FIXME: https://github.com/dotnet/runtime/issues/91394 return NULL; case SN_Clamp: { if (!(!fsig->hasthis && fsig->param_count == 3 && mono_metadata_type_equal (fsig->ret, type) && mono_metadata_type_equal (fsig->params [0], type) && mono_metadata_type_equal (fsig->params [1], type) && mono_metadata_type_equal (fsig->params [2], type))) @@ -3077,15 +3065,133 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f return min; } - case SN_Conjugate: - case SN_Distance: - case SN_DistanceSquared: + case SN_Distance: + case SN_DistanceSquared: { +#if defined(TARGET_ARM64) + MonoInst *diffs = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FSUB, MONO_TYPE_R4, fsig, args); + MonoInst *dot = emit_dot(cfg, klass, fsig->params [0], MONO_TYPE_R4, diffs->dreg, diffs->dreg); + + switch (id) { + case SN_Distance: { + dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1); + dot->inst_c1 = MONO_TYPE_R4; + + MonoInst *sqrt = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1); + sqrt->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT; + sqrt->inst_c1 = MONO_TYPE_R4; + + MonoInst *distance = emit_simd_ins (cfg, klass, OP_EXTRACT_R4, sqrt->dreg, -1); + distance->inst_c0 = 0; + distance->inst_c1 = MONO_TYPE_R4; + return distance; + } + case SN_DistanceSquared: + return dot; + default: + g_assert_not_reached (); + } +#else + return NULL; +#endif + } case 
SN_Length: - case SN_LengthSquared: - case SN_Lerp: + case SN_LengthSquared: { +#if defined (TARGET_ARM64) + int src1 = load_simd_vreg (cfg, cmethod, args [0], NULL); + MonoInst *dot = emit_dot(cfg, klass, type, MONO_TYPE_R4, src1, src1); + + switch (id) { + case SN_Length: { + dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1); + dot->inst_c1 = MONO_TYPE_R4; + + MonoInst *sqrt = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1); + sqrt->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT; + sqrt->inst_c1 = MONO_TYPE_R4; + + MonoInst *length = emit_simd_ins (cfg, klass, OP_EXTRACT_R4, sqrt->dreg, -1); + length->inst_c0 = 0; + length->inst_c1 = MONO_TYPE_R4; + return length; + } + case SN_LengthSquared: + return dot; + default: + g_assert_not_reached (); + } +#else + return NULL; +#endif + } + case SN_Lerp: { +#if defined (TARGET_ARM64) + MonoInst* v1 = args [1]; + if (!strcmp ("Quaternion", m_class_get_name (klass))) + return NULL; + + + MonoInst *diffs = emit_simd_ins (cfg, klass, OP_XBINOP, v1->dreg, args [0]->dreg); + diffs->inst_c0 = OP_FSUB; + diffs->inst_c1 = MONO_TYPE_R4; + + MonoInst *scaled_diffs = handle_mul_div_by_scalar (cfg, klass, MONO_TYPE_R4, args [2]->dreg, diffs->dreg, OP_FMUL); + + MonoInst *result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, scaled_diffs->dreg); + result->inst_c0 = OP_FADD; + result->inst_c1 = MONO_TYPE_R4; + + return result; +#else + return NULL; +#endif + } case SN_Normalize: { - // FIXME: https://github.com/dotnet/runtime/issues/91394 +#if defined (TARGET_ARM64) + MonoInst* vec = args[0]; + const char *class_name = m_class_get_name (klass); + if (!strcmp ("Plane", class_name)) { + static float r4_0 = 0; + MonoInst *zero; + int zero_dreg = alloc_freg (cfg); + MONO_INST_NEW (cfg, zero, OP_R4CONST); + zero->inst_p0 = (void*)&r4_0; + zero->dreg = zero_dreg; + MONO_ADD_INS (cfg->cbb, zero); + vec = emit_vector_insert_element (cfg, klass, vec, MONO_TYPE_R4, zero, 3, FALSE); + } + + MonoInst *dot = emit_dot(cfg, 
klass, type, MONO_TYPE_R4, vec->dreg, vec->dreg); + dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1); + dot->inst_c1 = MONO_TYPE_R4; + + MonoInst *sqrt_vec = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1); + sqrt_vec->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT; + sqrt_vec->inst_c1 = MONO_TYPE_R4; + + MonoInst *normalized_vec = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, sqrt_vec->dreg); + normalized_vec->inst_c0 = OP_FDIV; + normalized_vec->inst_c1 = MONO_TYPE_R4; + + return normalized_vec; +#else return NULL; +#endif + } + case SN_Conjugate: { +#if defined (TARGET_ARM64) + float value[4]; + value [0] = -1.0f; + value [1] = -1.0f; + value [2] = -1.0f; + value [3] = 1.0f; + MonoInst* r = emit_xconst_v128 (cfg, klass, (guint8*)value); + MonoInst* result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, r->dreg); + result->inst_c0 = OP_FMUL; + result->inst_c1 = MONO_TYPE_R4; + return result; +#else + return NULL; +#endif } default: g_assert_not_reached ();