Skip to content

Commit

Permalink
[Mono] [Arm64] Added SIMD support for vector 2/3/4 methods (#98761)
Browse files Browse the repository at this point in the history
* [Mono] [Arm64] Added multiple vector intrinsics

* Added LLVM support

* fix build errors on x64

* Added Quaternion.Conjugate

* Changed frsqrts codegen

* fixed whitespace changes

* Removed reciprocal sqrt estimation from normalize

* Extracted dot method into separate function

* Refactored code to use exposed dot function

* Fixed x64 build error

* Fixed more build errors

* Removed intrinsics for methods not intrinsified on coreclr side

* Cleaned up code

* Replaced break with return null

* Removed trailing whitespaces
  • Loading branch information
jkurdek authored Mar 15, 2024
1 parent 334020d commit 516f5c4
Showing 1 changed file with 199 additions and 93 deletions.
292 changes: 199 additions & 93 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ emit_xconst_v128 (MonoCompile *cfg, MonoClass *klass, guint8 value[16])
ins->type = STACK_VTYPE;
ins->dreg = alloc_xreg (cfg);
ins->inst_p0 = mono_mem_manager_alloc (cfg->mem_manager, size);
ins->klass = klass;
MONO_ADD_INS (cfg->cbb, ins);

memcpy (ins->inst_p0, &value[0], size);
Expand Down Expand Up @@ -1390,6 +1391,76 @@ emit_msb_shift_vector_constant (MonoCompile *cfg, MonoClass *arg_class, MonoType
}
#endif

/*
 * emit_dot:
 *
 *   Emit the IR for a dot product of the two SIMD values held in SREG1 and
 * SREG2, whose element type is ARG0_TYPE. KLASS is the class attached to the
 * emitted SIMD instructions and VECTOR_TYPE is the vector type forwarded to
 * the horizontal sum.
 *
 *   Returns the instruction produced by the final reduction step, or NULL when
 * the element type/target/backend combination is not handled here. No
 * instruction is emitted on any of the NULL paths.
 */
static MonoInst*
emit_dot (MonoCompile *cfg, MonoClass *klass, MonoType *vector_type, MonoTypeEnum arg0_type, int sreg1, int sreg2) {
	if (!is_element_type_primitive (vector_type))
		return NULL;

	/* Element types rejected by the non-LLVM backend on each target */
#if defined(TARGET_WASM)
	if (!COMPILE_LLVM (cfg)) {
		if (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8)
			return NULL;
	}
#elif defined(TARGET_ARM64)
	if (!COMPILE_LLVM (cfg)) {
		if (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8)
			return NULL;
		if (arg0_type == MONO_TYPE_I || arg0_type == MONO_TYPE_U)
			return NULL;
	}
#endif

#if defined(TARGET_ARM64) || defined(TARGET_WASM)
	/* Element-wise product followed by a horizontal sum */
	int mul_op = OP_IMUL;
	if (type_enum_is_float (arg0_type))
		mul_op = OP_FMUL;

	MonoInst *products = emit_simd_ins (cfg, klass, OP_XBINOP, sreg1, sreg2);
	products->inst_c0 = mul_op;
	products->inst_c1 = arg0_type;
	return emit_sum_vector (cfg, vector_type, arg0_type, products);
#elif defined(TARGET_AMD64)
	int mul_op;

	if (type_enum_is_float (arg0_type)) {
		if (is_SIMD_feature_supported (cfg, MONO_CPU_X86_SSE41)) {
			/* SSE4.1 has dedicated dot product instructions for float/double */
			const int use_llvm = COMPILE_LLVM (cfg) ? 1 : 0;
			int dp_op;
			int dp_mask;

			if (arg0_type == MONO_TYPE_R4) {
				dp_op = use_llvm ? OP_SSE41_DPPS : OP_SSE41_DPPS_IMM;
				dp_mask = 0xf1; // 0b11110001
			} else if (arg0_type == MONO_TYPE_R8) {
				dp_op = use_llvm ? OP_SSE41_DPPD : OP_SSE41_DPPD_IMM;
				dp_mask = 0x31; // 0b00110001
			} else {
				return NULL;
			}

			MonoInst *dp;
			if (use_llvm) {
				/* The LLVM opcodes take the mask in a register (sreg3) */
				int mask_reg = alloc_ireg (cfg);
				MONO_EMIT_NEW_ICONST (cfg, mask_reg, dp_mask);

				dp = emit_simd_ins (cfg, klass, dp_op, sreg1, sreg2);
				dp->sreg3 = mask_reg;
			} else {
				/* The _IMM opcodes carry the mask as an immediate */
				dp = emit_simd_ins (cfg, klass, dp_op, sreg1, sreg2);
				dp->inst_c0 = dp_mask;
			}

			/* The result is read back from the first element of the destination */
			return extract_first_element (cfg, klass, arg0_type, dp->dreg);
		}

		/* No SSE4.1: fall through to multiply + horizontal sum */
		mul_op = OP_FMUL;
	} else {
		// We don't support sum vector for byte, sbyte types yet
		if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1)
			return NULL;

		// FIXME: integer element types are only handled when compiling with LLVM
		if (!COMPILE_LLVM (cfg))
			return NULL;

		mul_op = OP_IMUL;
	}

	MonoInst *products = emit_simd_ins (cfg, klass, OP_XBINOP, sreg1, sreg2);
	products->inst_c0 = mul_op;
	products->inst_c1 = arg0_type;

	return emit_sum_vector (cfg, vector_type, arg0_type, products);
#else
	/* Not intrinsified on other targets */
	return NULL;
#endif
}

/*
* Emit intrinsics in System.Numerics.Vector and System.Runtime.Intrinsics.Vector64/128/256/512.
* If the intrinsic is not supported for some reason, return NULL and fall back to the C#
Expand Down Expand Up @@ -1768,70 +1839,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
}
}
case SN_Dot: {
if (!is_element_type_primitive (fsig->params [0]))
return NULL;
#if defined(TARGET_WASM)
if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8))
return NULL;
#elif defined(TARGET_ARM64)
if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8 || arg0_type == MONO_TYPE_I || arg0_type == MONO_TYPE_U))
return NULL;
#endif

#if defined(TARGET_ARM64) || defined(TARGET_WASM)
int instc0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL;
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc0, arg0_type, fsig, args);
return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply);
#elif defined(TARGET_AMD64)
int instc =-1;
if (type_enum_is_float (arg0_type)) {
if (is_SIMD_feature_supported (cfg, MONO_CPU_X86_SSE41)) {
int mask_val = -1;
switch (arg0_type) {
case MONO_TYPE_R4:
instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPS : OP_SSE41_DPPS_IMM;
mask_val = 0xf1; // 0xf1 ... 0b11110001
break;
case MONO_TYPE_R8:
instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPD : OP_SSE41_DPPD_IMM;
mask_val = 0x31; // 0x31 ... 0b00110001
break;
default:
return NULL;
}

MonoInst *dot;
if (COMPILE_LLVM (cfg)) {
int mask_reg = alloc_ireg (cfg);
MONO_EMIT_NEW_ICONST (cfg, mask_reg, mask_val);

dot = emit_simd_ins (cfg, klass, instc, args [0]->dreg, args [1]->dreg);
dot->sreg3 = mask_reg;
} else {
dot = emit_simd_ins (cfg, klass, instc, args [0]->dreg, args [1]->dreg);
dot->inst_c0 = mask_val;
}

return extract_first_element (cfg, klass, arg0_type, dot->dreg);
} else {
instc = OP_FMUL;
}
} else {
if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1)
return NULL; // We don't support sum vector for byte, sbyte types yet

// FIXME:
if (!COMPILE_LLVM (cfg))
return NULL;

instc = OP_IMUL;
}
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc, arg0_type, fsig, args);

return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply);
#else
return NULL;
#endif
return emit_dot (cfg, klass, fsig->params [0], arg0_type, args [0]->dreg, args [1]->dreg);
}
case SN_Equals:
case SN_EqualsAll:
Expand Down Expand Up @@ -2910,6 +2918,8 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
value [1] = 1.0f;
value [2] = 1.0f;
value [3] = 1.0f;
if (len == 3)
value [3] = 0.0f;
return emit_xconst_v128 (cfg, klass, (guint8*)value);
}
case SN_set_Item: {
Expand Down Expand Up @@ -2988,28 +2998,7 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, MONO_TYPE_R4, id);
}
case SN_Dot: {
#if defined(TARGET_ARM64) || defined(TARGET_WASM)
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FMUL, MONO_TYPE_R4, fsig, args);
return emit_sum_vector (cfg, fsig->params [0], MONO_TYPE_R4, pairwise_multiply);
#elif defined(TARGET_AMD64)
if (!(mini_get_cpu_features (cfg) & MONO_CPU_X86_SSE41))
return NULL;

int mask_reg = alloc_ireg (cfg);
MONO_EMIT_NEW_ICONST (cfg, mask_reg, 0xf1);
MonoInst *dot = emit_simd_ins (cfg, klass, OP_SSE41_DPPS, args [0]->dreg, args [1]->dreg);
dot->sreg3 = mask_reg;

MONO_INST_NEW (cfg, ins, OP_EXTRACT_R4);
ins->dreg = alloc_freg (cfg);
ins->sreg1 = dot->dreg;
ins->inst_c0 = 0;
ins->inst_c1 = MONO_TYPE_R4;
MONO_ADD_INS (cfg->cbb, ins);
return ins;
#else
return NULL;
#endif
return emit_dot (cfg, klass, fsig->params [0], MONO_TYPE_R4, args [0]->dreg, args [1]->dreg);
}
case SN_Negate:
case SN_op_UnaryNegation: {
Expand Down Expand Up @@ -3061,7 +3050,6 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
#endif
}
case SN_CopyTo:
// FIXME: https://github.com/dotnet/runtime/issues/91394
return NULL;
case SN_Clamp: {
if (!(!fsig->hasthis && fsig->param_count == 3 && mono_metadata_type_equal (fsig->ret, type) && mono_metadata_type_equal (fsig->params [0], type) && mono_metadata_type_equal (fsig->params [1], type) && mono_metadata_type_equal (fsig->params [2], type)))
Expand All @@ -3077,15 +3065,133 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f

return min;
}
case SN_Conjugate:
case SN_Distance:
case SN_DistanceSquared:
case SN_Distance:
case SN_DistanceSquared: {
#if defined(TARGET_ARM64)
MonoInst *diffs = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FSUB, MONO_TYPE_R4, fsig, args);
MonoInst *dot = emit_dot(cfg, klass, fsig->params [0], MONO_TYPE_R4, diffs->dreg, diffs->dreg);

switch (id) {
case SN_Distance: {
dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1);
dot->inst_c1 = MONO_TYPE_R4;

MonoInst *sqrt = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1);
sqrt->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT;
sqrt->inst_c1 = MONO_TYPE_R4;

MonoInst *distance = emit_simd_ins (cfg, klass, OP_EXTRACT_R4, sqrt->dreg, -1);
distance->inst_c0 = 0;
distance->inst_c1 = MONO_TYPE_R4;
return distance;
}
case SN_DistanceSquared:
return dot;
default:
g_assert_not_reached ();
}
#else
return NULL;
#endif
}
case SN_Length:
case SN_LengthSquared:
case SN_Lerp:
case SN_LengthSquared: {
#if defined (TARGET_ARM64)
int src1 = load_simd_vreg (cfg, cmethod, args [0], NULL);
MonoInst *dot = emit_dot(cfg, klass, type, MONO_TYPE_R4, src1, src1);

switch (id) {
case SN_Length: {
dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1);
dot->inst_c1 = MONO_TYPE_R4;

MonoInst *sqrt = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1);
sqrt->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT;
sqrt->inst_c1 = MONO_TYPE_R4;

MonoInst *length = emit_simd_ins (cfg, klass, OP_EXTRACT_R4, sqrt->dreg, -1);
length->inst_c0 = 0;
length->inst_c1 = MONO_TYPE_R4;
return length;
}
case SN_LengthSquared:
return dot;
default:
g_assert_not_reached ();
}
#else
return NULL;
#endif
}
case SN_Lerp: {
#if defined (TARGET_ARM64)
MonoInst* v1 = args [1];
if (!strcmp ("Quaternion", m_class_get_name (klass)))
return NULL;


MonoInst *diffs = emit_simd_ins (cfg, klass, OP_XBINOP, v1->dreg, args [0]->dreg);
diffs->inst_c0 = OP_FSUB;
diffs->inst_c1 = MONO_TYPE_R4;

MonoInst *scaled_diffs = handle_mul_div_by_scalar (cfg, klass, MONO_TYPE_R4, args [2]->dreg, diffs->dreg, OP_FMUL);

MonoInst *result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, scaled_diffs->dreg);
result->inst_c0 = OP_FADD;
result->inst_c1 = MONO_TYPE_R4;

return result;
#else
return NULL;
#endif
}
case SN_Normalize: {
// FIXME: https://github.com/dotnet/runtime/issues/91394
#if defined (TARGET_ARM64)
MonoInst* vec = args[0];
const char *class_name = m_class_get_name (klass);
if (!strcmp ("Plane", class_name)) {
static float r4_0 = 0;
MonoInst *zero;
int zero_dreg = alloc_freg (cfg);
MONO_INST_NEW (cfg, zero, OP_R4CONST);
zero->inst_p0 = (void*)&r4_0;
zero->dreg = zero_dreg;
MONO_ADD_INS (cfg->cbb, zero);
vec = emit_vector_insert_element (cfg, klass, vec, MONO_TYPE_R4, zero, 3, FALSE);
}

MonoInst *dot = emit_dot(cfg, klass, type, MONO_TYPE_R4, vec->dreg, vec->dreg);
dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1);
dot->inst_c1 = MONO_TYPE_R4;

MonoInst *sqrt_vec = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1);
sqrt_vec->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT;
sqrt_vec->inst_c1 = MONO_TYPE_R4;

MonoInst *normalized_vec = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, sqrt_vec->dreg);
normalized_vec->inst_c0 = OP_FDIV;
normalized_vec->inst_c1 = MONO_TYPE_R4;

return normalized_vec;
#else
return NULL;
#endif
}
case SN_Conjugate: {
#if defined (TARGET_ARM64)
float value[4];
value [0] = -1.0f;
value [1] = -1.0f;
value [2] = -1.0f;
value [3] = 1.0f;
MonoInst* r = emit_xconst_v128 (cfg, klass, (guint8*)value);
MonoInst* result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, r->dreg);
result->inst_c0 = OP_FMUL;
result->inst_c1 = MONO_TYPE_R4;
return result;
#else
return NULL;
#endif
}
default:
g_assert_not_reached ();
Expand Down

0 comments on commit 516f5c4

Please sign in to comment.