Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mono] [Arm64] Added SIMD support for vector 2/3/4 methods #98761

Merged
merged 16 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mono/mono/arch/arm64/arm64-codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,7 @@ arm_encode_arith_imm (int imm, guint32 *shift)
// Vector floating-point unary ops and conversions, all emitted through arm_neon_2mvec_opcode.
// The (width, type) forms forward `width` unchanged and pass `0b10 | (type)` as the size field;
// the fcvtn/fcvtn2/fcvtl forms hard-code the width (VREG_LOW / VREG_FULL) and SIZE_2 instead.
#define arm_neon_fabs(p, width, type, rd, rn) arm_neon_2mvec_opcode ((p), (width), 0b0, 0b10 | (type), 0b01111, (rd), (rn))
#define arm_neon_fneg(p, width, type, rd, rn) arm_neon_2mvec_opcode ((p), (width), 0b1, 0b10 | (type), 0b01111, (rd), (rn))
#define arm_neon_fsqrt(p, width, type, rd, rn) arm_neon_2mvec_opcode ((p), (width), 0b1, 0b10 | (type), 0b11111, (rd), (rn))
// FRSQRTE (vector): floating-point reciprocal square root estimate of each lane of rn.
// Differs from arm_neon_fsqrt above only in the opcode field (0b11101 vs 0b11111).
#define arm_neon_frsqrte(p, width, type, rd, rn) arm_neon_2mvec_opcode ((p), (width), 0b1, 0b10 | (type), 0b11101, (rd), (rn))
#define arm_neon_fcvtn(p, rd, rn) arm_neon_2mvec_opcode ((p), VREG_LOW, 0b0, SIZE_2, 0b10110, (rd), (rn))
#define arm_neon_fcvtn2(p, rd, rn) arm_neon_2mvec_opcode ((p), VREG_FULL, 0b0, SIZE_2, 0b10110, (rd), (rn))
#define arm_neon_fcvtl(p, rd, rn) arm_neon_2mvec_opcode ((p), VREG_LOW, 0b0, SIZE_2, 0b10111, (rd), (rn))
Expand Down Expand Up @@ -1845,6 +1846,7 @@ arm_encode_arith_imm (int imm, guint32 *shift)
// Vector floating-point three-operand ops (rd <- op (rn, rm)), all emitted through arm_neon_3svec_opcode.
// Note that fcmgt and frsqrts pass `0b10 | (type)` as the size field while fcmge and faddp pass `(type)` as-is.
#define arm_neon_fcmge(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b11100, (rd), (rn), (rm))
#define arm_neon_fcmgt(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, 0b10 | (type), 0b11100, (rd), (rn), (rm))
#define arm_neon_faddp(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b11010, (rd), (rn), (rm))
// FRSQRTS (vector): floating-point reciprocal square root step. Per the Arm ISA reference this
// yields (3.0 - rn * rm) / 2.0 per lane; it is the correction factor used to refine an FRSQRTE estimate.
#define arm_neon_frsqrts(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b0, 0b10 | (type), 0b11111, (rd), (rn), (rm))

// Generalized macros for bitwise ops:
// width - determines if full register or its lower half is used one of {VREG_LOW, VREG_FULL}
Expand Down
3 changes: 3 additions & 0 deletions src/mono/mono/mini/mini-arm64.c
Original file line number Diff line number Diff line change
Expand Up @@ -4112,6 +4112,9 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb)
case INTRINS_AARCH64_ADV_SIMD_USHL:
arm_neon_ushl (code, get_vector_size_macro (ins), get_type_size_macro (ins->inst_c1), dreg, sreg1, sreg2);
break;
case INTRINS_AARCH64_ADV_SIMD_FRSQRTS:
arm_neon_frsqrts (code, get_vector_size_macro (ins), get_type_size_macro (ins->inst_c1), dreg, sreg1, sreg2);
break;
default:
g_assert_not_reached ();
break;
Expand Down
1 change: 0 additions & 1 deletion src/mono/mono/mini/mini-llvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -8486,7 +8486,6 @@ MONO_RESTORE_WARNING
#endif
break;
}

jkurdek marked this conversation as resolved.
Show resolved Hide resolved
default:
g_assert_not_reached ();
}
Expand Down
2 changes: 2 additions & 0 deletions src/mono/mono/mini/simd-arm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ SIMD_OP (64, OP_XBINOP, OP_FDIV, WTDSS, _UNDEF,
SIMD_OP (64, OP_ARM64_XADDV, INTRINS_AARCH64_ADV_SIMD_FADDV, WTDS, _UNDEF, _UNDEF, _UNDEF, _UNDEF, _SKIP, _UNDEF)
SIMD_OP (64, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FSQRT, WTDS, _UNDEF, _UNDEF, _UNDEF, _UNDEF, arm_neon_fsqrt, _UNDEF)
SIMD_OP (64, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FABS, WTDS, _UNDEF, _UNDEF, _UNDEF, _UNDEF, arm_neon_fabs, _UNDEF)
SIMD_OP (64, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FRSQRTE,WTDS, _UNDEF, _UNDEF, _UNDEF, _UNDEF, arm_neon_frsqrte, _UNDEF)

/* 128-bit vectors */
/* Width Opcode Function Operand config I8 I16 I32 I64 F32 F64 */
Expand Down Expand Up @@ -91,3 +92,4 @@ SIMD_OP (128, OP_XOP_OVR_X_X, INTRINS_SIMD_FLOOR, WTDS, _UNDEF, _U
SIMD_OP (128, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FSQRT, WTDS, _UNDEF, _UNDEF, _UNDEF, _UNDEF, arm_neon_fsqrt, arm_neon_fsqrt)
SIMD_OP (128, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_ABS, WTDS, arm_neon_abs, arm_neon_abs, arm_neon_abs, arm_neon_abs, _UNDEF, _UNDEF)
SIMD_OP (128, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FABS, WTDS, _UNDEF, _UNDEF, _UNDEF, _UNDEF, arm_neon_fabs, arm_neon_fabs)
SIMD_OP (128, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FRSQRTE,WTDS, _UNDEF, _UNDEF, _UNDEF, _UNDEF, arm_neon_frsqrte, arm_neon_frsqrte)
233 changes: 223 additions & 10 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ emit_xconst_v128 (MonoCompile *cfg, MonoClass *klass, guint8 value[16])
ins->type = STACK_VTYPE;
ins->dreg = alloc_xreg (cfg);
ins->inst_p0 = mono_mem_manager_alloc (cfg->mem_manager, size);
ins->klass = klass;
MONO_ADD_INS (cfg->cbb, ins);

memcpy (ins->inst_p0, &value[0], size);
Expand Down Expand Up @@ -688,6 +689,27 @@ emit_sum_vector (MonoCompile *cfg, MonoType *vector_type, MonoTypeEnum element_t
return ins;
}
}

static MonoInst*
emit_sum_sqrt_vector_2_3_4 (MonoCompile *cfg, MonoClass *klass, MonoInst *arg) {
MonoInst *sum = emit_simd_ins (cfg, klass, OP_ARM64_XADDV, arg->dreg, -1);
sum->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FADDV;
sum->inst_c1 = MONO_TYPE_R4;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth noting that Arm has typically pushed us away from using FADDV as it does not perform well on some hardware.

Instead, they had us use a sequence of FADDP (AddPairwise) instructions, which tend to have better perf/throughput: https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/gentree.cpp#L25190-L25210

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing the information, @tannergooding. @jkurdek Feel free to create an issue to address it in a future PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if (COMPILE_LLVM (cfg)) {
sum = emit_simd_ins (cfg, klass, OP_EXPAND_R4, sum->dreg, -1);
sum->inst_c1 = MONO_TYPE_R4;
}

MonoInst* sum_sqrt = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, sum->dreg, -1);
sum_sqrt->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT;
sum_sqrt->inst_c1 = MONO_TYPE_R4;

MonoInst *ins = emit_simd_ins (cfg, klass, OP_EXTRACT_R4, sum_sqrt->dreg, -1);
ins->inst_c0 = 0;
ins->inst_c1 = MONO_TYPE_R4;
return ins;
}
#endif
#ifdef TARGET_WASM
static MonoInst* emit_sum_vector (MonoCompile *cfg, MonoType *vector_type, MonoTypeEnum element_type, MonoInst *arg);
Expand Down Expand Up @@ -1087,6 +1109,63 @@ emit_vector_insert_element (
return ins;
}

#if defined(TARGET_ARM64)
/*
 * emit_normalize_vector_2_3_4:
 *
 *   Emit the IR sequence that multiplies ARG lane-wise by a reciprocal square
 * root computed from the horizontal sum of its squared R4 lanes.
 *
 * Steps, in emission order:
 *   1. squares = arg * arg (lane-wise FMUL).
 *   2. If KLASS is named "Plane", a 0.0f constant is inserted at element 3 of
 *      the squares so that element does not contribute to the sum.
 *   3. total = FADDV (squares); under LLVM the result is additionally passed
 *      through OP_EXPAND_R4.
 *   4. estimate = FRSQRTE (total), then two refinement rounds of
 *      estimate = estimate * FRSQRTS (total, estimate * estimate).
 *   5. result = arg * estimate.
 *
 * Returns the instruction producing the final product. Every instruction is
 * appended to the current basic block as it is created.
 */
static MonoInst*
emit_normalize_vector_2_3_4 (MonoCompile *cfg, MonoClass *klass, MonoInst *arg)
{
	/* Lane-wise arg * arg. */
	MonoInst *squares = emit_simd_ins (cfg, klass, OP_XBINOP, arg->dreg, arg->dreg);
	squares->inst_c0 = OP_FMUL;
	squares->inst_c1 = MONO_TYPE_R4;

	if (strcmp ("Plane", m_class_get_name (klass)) == 0) {
		/* The constant must outlive compilation: the OP_R4CONST keeps a pointer to it. */
		static float zero_r4 = 0;
		int zero_reg = alloc_freg (cfg);
		MonoInst *zero_const;

		MONO_INST_NEW (cfg, zero_const, OP_R4CONST);
		zero_const->inst_p0 = (void*)&zero_r4;
		zero_const->dreg = zero_reg;
		MONO_ADD_INS (cfg->cbb, zero_const);

		/* Overwrite element 3 of the squares with 0 before summing. */
		squares = emit_vector_insert_element (cfg, klass, squares, MONO_TYPE_R4, zero_const, 3, FALSE);
	}

	/* Horizontal add of the squared lanes. */
	MonoInst *total = emit_simd_ins (cfg, klass, OP_ARM64_XADDV, squares->dreg, -1);
	total->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FADDV;
	total->inst_c1 = MONO_TYPE_R4;

	if (COMPILE_LLVM (cfg)) {
		total = emit_simd_ins (cfg, klass, OP_EXPAND_R4, total->dreg, -1);
		total->inst_c1 = MONO_TYPE_R4;
	}

	/* Initial reciprocal square root estimate of the sum. */
	MonoInst *estimate = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, total->dreg, -1);
	estimate->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FRSQRTE;
	estimate->inst_c1 = MONO_TYPE_R4;

	/* Two refinement rounds: estimate *= FRSQRTS (total, estimate^2). */
	for (int rounds_left = 2; rounds_left > 0; rounds_left--) {
		MonoInst *estimate_sq = emit_simd_ins (cfg, klass, OP_XBINOP, estimate->dreg, estimate->dreg);
		estimate_sq->inst_c0 = OP_FMUL;
		estimate_sq->inst_c1 = MONO_TYPE_R4;

		MonoInst *step = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X_X, total->dreg, estimate_sq->dreg);
		step->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FRSQRTS;
		step->inst_c1 = MONO_TYPE_R4;

		estimate = emit_simd_ins (cfg, klass, OP_XBINOP, estimate->dreg, step->dreg);
		estimate->inst_c0 = OP_FMUL;
		estimate->inst_c1 = MONO_TYPE_R4;
	}

	/* Scale the original vector by the refined reciprocal square root. */
	MonoInst *scaled = emit_simd_ins (cfg, klass, OP_XBINOP, arg->dreg, estimate->dreg);
	scaled->inst_c0 = OP_FMUL;
	scaled->inst_c1 = MONO_TYPE_R4;

	return scaled;
}
#endif

static MonoInst *
emit_vector_create_elementwise (
MonoCompile *cfg, MonoMethodSignature *fsig, MonoType *vtype,
Expand Down Expand Up @@ -2926,6 +3005,8 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
value [1] = 1.0f;
value [2] = 1.0f;
value [3] = 1.0f;
if (len == 3)
value [3] = 0.0f;
return emit_xconst_v128 (cfg, klass, (guint8*)value);
}
case SN_set_Item: {
Expand Down Expand Up @@ -3076,9 +3157,44 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
return NULL;
#endif
}
case SN_CopyTo:
// FIXME: https://github.com/dotnet/runtime/issues/91394
return NULL;
case SN_CopyTo: {
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
#if defined(TARGET_ARM64)
if ((fsig->param_count == 1 || fsig->param_count == 2) && (fsig->params [0]->type == MONO_TYPE_SZARRAY)) {
MonoInst *index_ins;
int val_vreg, end_index_reg;
val_vreg = load_simd_vreg (cfg, cmethod, args [0], NULL);

if (fsig->param_count == 2) {
index_ins = args [2];
} else {
EMIT_NEW_ICONST (cfg, index_ins, 0);
}

MonoInst *ldelema_ins;
MonoInst *array_ins = args [1];

/* CopyTo () does complicated argument checks */
mini_emit_bounds_check_offset (cfg, array_ins->dreg, MONO_STRUCT_OFFSET (MonoArray, max_length), index_ins->dreg, "ArgumentOutOfRangeException", FALSE);
end_index_reg = alloc_ireg (cfg);
int len_reg = alloc_ireg (cfg);
MONO_EMIT_NEW_LOAD_MEMBASE_OP_FLAGS (cfg, OP_LOADI4_MEMBASE, len_reg, array_ins->dreg, MONO_STRUCT_OFFSET (MonoArray, max_length), MONO_INST_INVARIANT_LOAD);
EMIT_NEW_BIALU (cfg, ins, OP_ISUB, end_index_reg, len_reg, index_ins->dreg);
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_COMPARE_IMM, -1, end_index_reg, len);
MONO_EMIT_NEW_COND_EXC (cfg, LT, "ArgumentException");

/* Store the simd reg into the array slice */
ldelema_ins = mini_emit_ldelema_1_ins (cfg, mono_class_from_mono_type_internal (etype), array_ins, index_ins, FALSE, FALSE);
EMIT_NEW_STORE_MEMBASE (cfg, ins, OP_STOREX_MEMBASE, ldelema_ins->dreg, 0, val_vreg);
ins->klass = cmethod->klass;
return ins;
} else {
// CopyTo(Span<Single>)
// Not intrinsified on coreclr
return NULL;
}
#endif
}
break;
case SN_Clamp: {
if (!(!fsig->hasthis && fsig->param_count == 3 && mono_metadata_type_equal (fsig->ret, type) && mono_metadata_type_equal (fsig->params [0], type) && mono_metadata_type_equal (fsig->params [1], type) && mono_metadata_type_equal (fsig->params [2], type)))
return NULL;
Expand All @@ -3093,16 +3209,113 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f

return min;
}
case SN_Conjugate:
case SN_Distance:
case SN_DistanceSquared:
case SN_Distance:
case SN_DistanceSquared: {
#if defined(TARGET_ARM64)
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
MonoInst *diffs = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FSUB, MONO_TYPE_R4, fsig, args);
MonoInst *diffs_squared = emit_simd_ins (cfg, klass, OP_XBINOP, diffs->dreg, diffs->dreg);
diffs_squared->inst_c0 = OP_FMUL;
diffs_squared->inst_c1 = MONO_TYPE_R4;

switch (id) {
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
case SN_Distance:
return emit_sum_sqrt_vector_2_3_4 (cfg, klass, diffs_squared);
case SN_DistanceSquared:
return emit_sum_vector (cfg, fsig->params [0], MONO_TYPE_R4, diffs_squared);
default:
g_assert_not_reached ();
}
#endif
}
break;
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
case SN_Length:
case SN_LengthSquared:
case SN_Lerp:
case SN_LengthSquared: {
#if defined (TARGET_ARM64)
int src1 = load_simd_vreg (cfg, cmethod, args [0], NULL);

MonoInst *vec_squared = emit_simd_ins (cfg, klass, OP_XBINOP, src1, src1);
vec_squared->inst_c0 = OP_FMUL;
vec_squared->inst_c1 = MONO_TYPE_R4;

switch (id) {
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
case SN_Length:
return emit_sum_sqrt_vector_2_3_4 (cfg, klass, vec_squared);
case SN_LengthSquared:
return emit_sum_vector (cfg, type, MONO_TYPE_R4, vec_squared);
default:
g_assert_not_reached ();
}
#endif
}
break;
case SN_Lerp: {
#if defined (TARGET_ARM64)
MonoInst* v1 = args [1];
if (!strcmp ("Quaternion", m_class_get_name (klass))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quaternion.Lerp is not marked as intrinsic in the libraries:

/// <summary>Performs a linear interpolation between two quaternions based on a value that specifies the weighting of the second quaternion.</summary>
/// <param name="quaternion1">The first quaternion.</param>
/// <param name="quaternion2">The second quaternion.</param>
/// <param name="amount">The relative weight of <paramref name="quaternion2" /> in the interpolation.</param>
/// <returns>The interpolated quaternion.</returns>
public static Quaternion Lerp(Quaternion quaternion1, Quaternion quaternion2, float amount)

However, if this implementation generates better codegen we should probably keep it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intrinsified version is around 40% faster on my machine

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the intrinsified version faster here? Is it fundamentally doing something differently from the managed implementation or is there potentially a missing JIT optimization?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or perhaps a change is simply missing on the managed side, so it's using scalar logic rather than any actual vectorization, and a better fix is to update the managed impl?


We've typically tried to keep a clear separation between intrinsic functionality and more complex methods.

APIs like operator + or Sqrt are generally mapped to exactly 1 hardware instruction and this is the case for most platforms.

APIs like DotProduct or even Create may be mapped to exactly 1 hardware instruction on some platforms and are fairly "core" to the general throughput considerations of many platforms.

APIs like Quaternion.Lerp or CopyTo are more complex functions which use multiple instructions on all platforms and which may even require branching or masking logic. So, we've typically tried to keep them in managed and have them use the intrinsic APIs instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we should align Mono's behavior with CoreCLR, not intrinsifying Quaternion.Lerp or CopyTo for mono either.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the intrinsified version faster here? Is it fundamentally doing something differently from the managed implementation or is there potentially a missing JIT optimization?

In general, Mono's mini JIT doesn't have as comprehensive optimizations as CoreCLR's RyuJIT.

MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FMUL, MONO_TYPE_R4, fsig, args);
pairwise_multiply->sreg3 = -1;
MonoInst *dot = emit_simd_ins (cfg, klass, OP_ARM64_XADDV, pairwise_multiply->dreg, -1);
dot->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FADDV;
dot->inst_c1 = MONO_TYPE_R4;

if (COMPILE_LLVM (cfg)) {
dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1);
dot->inst_c1 = MONO_TYPE_R4;
}

MonoInst* zeros = emit_xzero (cfg, klass);

MonoInst* ge_0 = emit_simd_ins (cfg, klass, OP_XCOMPARE_FP, dot->dreg, zeros->dreg);
ge_0->inst_c0 = CMP_GE;
ge_0->inst_c1 = MONO_TYPE_R4;

MonoInst* negated_v1 = emit_simd_ins (cfg, klass, OP_NEGATION, args [1]->dreg, -1);
negated_v1->inst_c1 = MONO_TYPE_R4;

v1 = emit_simd_ins (cfg, klass, OP_BSL, ge_0->dreg, args [1]->dreg);
v1->sreg3 = negated_v1->dreg;
v1->inst_c1 = MONO_TYPE_R4;
}

MonoInst *diffs = emit_simd_ins (cfg, klass, OP_XBINOP, v1->dreg, args [0]->dreg);
diffs->inst_c0 = OP_FSUB;
diffs->inst_c1 = MONO_TYPE_R4;

MonoInst *scaled_diffs = handle_mul_div_by_scalar (cfg, klass, MONO_TYPE_R4, args [2]->dreg, diffs->dreg, OP_FMUL);

MonoInst *result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, scaled_diffs->dreg);
result->inst_c0 = OP_FADD;
result->inst_c1 = MONO_TYPE_R4;

if (!strcmp ("Quaternion", m_class_get_name (klass))) {
return emit_normalize_vector_2_3_4 (cfg, klass, result);
}

return result;
#endif
}
break;
case SN_Normalize: {
// FIXME: https://github.com/dotnet/runtime/issues/91394
return NULL;
#if defined (TARGET_ARM64)
return emit_normalize_vector_2_3_4 (cfg, klass, args[0]);
#endif
}
break;
case SN_Conjugate: {
#if defined (TARGET_ARM64)
float value[4];
value [0] = -1.0f;
value [1] = -1.0f;
value [2] = -1.0f;
value [3] = 1.0f;
MonoInst* r = emit_xconst_v128 (cfg, klass, (guint8*)value);
MonoInst* result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, r->dreg);
result->inst_c0 = OP_FMUL;
result->inst_c1 = MONO_TYPE_R4;
return result;
#endif
}
break;
default:
g_assert_not_reached ();
}
Expand Down