Skip to content

[mono] Fix vector class retrieval and type checks for binary operand APIs #107388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 8, 2024
209 changes: 109 additions & 100 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -339,111 +339,73 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
int instc0 = -1;
int op = OP_XBINOP;

if (id == SN_BitwiseAnd || id == SN_BitwiseOr || id == SN_Xor ||
id == SN_op_BitwiseAnd || id == SN_op_BitwiseOr || id == SN_op_ExclusiveOr) {
op = OP_XBINOP_FORCEINT;

switch (id) {
switch (id) {
case SN_Add:
case SN_op_Addition: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FADD;
} else {
instc0 = OP_IADD;
}
break;
}
case SN_BitwiseAnd:
case SN_op_BitwiseAnd:
case SN_op_BitwiseAnd: {
op = OP_XBINOP_FORCEINT;
instc0 = XBINOP_FORCEINT_AND;
break;
}
case SN_BitwiseOr:
case SN_op_BitwiseOr:
case SN_op_BitwiseOr: {
op = OP_XBINOP_FORCEINT;
instc0 = XBINOP_FORCEINT_OR;
break;
case SN_op_ExclusiveOr:
case SN_Xor:
instc0 = XBINOP_FORCEINT_XOR;
break;
}
} else {
if (type_enum_is_float (arg_type)) {
switch (id) {
case SN_Add:
case SN_op_Addition:
instc0 = OP_FADD;
break;
case SN_Divide:
case SN_op_Division: {
const char *class_name = m_class_get_name (klass);
if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) {
if (!type_is_simd_vector (fsig->params [1]))
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FDIV);
else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) {
instc0 = OP_FDIV;
break;
} else {
return NULL;
}
}
case SN_Divide:
case SN_op_Division: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FDIV;
break;
}
#ifdef TARGET_ARM64
case SN_Max:
#endif
case SN_MaxNative:
instc0 = OP_FMAX;
break;
#ifdef TARGET_ARM64
case SN_Min:
#endif
case SN_MinNative:
instc0 = OP_FMIN;
break;
case SN_Multiply:
case SN_op_Multiply: {
const char *class_name = m_class_get_name (klass);
if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) {
if (!type_is_simd_vector (fsig->params [1]))
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FMUL);
else if (!type_is_simd_vector (fsig->params [0]))
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_FMUL);
else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) {
instc0 = OP_FMUL;
break;
} else {
return NULL;
}
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) { // vector / scalar
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0);
}
instc0 = OP_FMUL;
break;
}
case SN_Subtract:
case SN_op_Subtraction:
instc0 = OP_FSUB;
break;
default:
g_assert_not_reached ();
}
} else {
switch (id) {
case SN_Add:
case SN_op_Addition:
instc0 = OP_IADD;
break;
case SN_Divide:
case SN_op_Division:
} else {
return NULL;
case SN_Max:
case SN_MaxNative:
}
break;
}
case SN_Max:
case SN_MaxNative: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FMAX;
} else {
instc0 = type_enum_is_unsigned (arg_type) ? OP_IMAX_UN : OP_IMAX;

#ifdef TARGET_AMD64
if (!COMPILE_LLVM (cfg) && instc0 == OP_IMAX_UN)
return NULL;
#endif
break;
case SN_Min:
case SN_MinNative:
}
break;
}
case SN_Min:
case SN_MinNative: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FMIN;
} else {
instc0 = type_enum_is_unsigned (arg_type) ? OP_IMIN_UN : OP_IMIN;

#ifdef TARGET_AMD64
if (!COMPILE_LLVM (cfg) && instc0 == OP_IMIN_UN)
return NULL;
#endif
break;
case SN_Multiply:
case SN_op_Multiply: {
}
break;
}
case SN_Multiply:
case SN_op_Multiply: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FMUL;
} else {
#ifdef TARGET_ARM64
if (!COMPILE_LLVM (cfg) && (arg_type == MONO_TYPE_I8 || arg_type == MONO_TYPE_U8 || arg_type == MONO_TYPE_I || arg_type == MONO_TYPE_U))
return NULL;
Expand All @@ -452,22 +414,34 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
if (!COMPILE_LLVM (cfg))
return NULL;
#endif
if (fsig->params [1]->type != MONO_TYPE_GENERICINST)
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_IMUL);
else if (fsig->params [0]->type != MONO_TYPE_GENERICINST)
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_IMUL);
instc0 = OP_IMUL;
break;
}
case SN_Subtract:
case SN_op_Subtraction:
if (MONO_TYPE_IS_VECTOR_PRIMITIVE(fsig->params [1])) { // vector * scalar
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0);
} else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) { // scalar * vector
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, instc0);
}
break;
}
case SN_Subtract:
case SN_op_Subtraction: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FSUB;
} else {
instc0 = OP_ISUB;
break;
default:
g_assert_not_reached ();
}
break;
}
case SN_Xor:
case SN_op_ExclusiveOr: {
op = OP_XBINOP_FORCEINT;
instc0 = XBINOP_FORCEINT_XOR;
break;
}
default:
g_assert_not_reached ();
}

return emit_simd_ins_for_sig (cfg, klass, op, instc0, arg_type, fsig, args);
}

Expand Down Expand Up @@ -1992,7 +1966,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
return NULL;
#endif

MonoClass* klass = fsig->param_count > 0 ? args[0]->klass : cmethod->klass;
MonoClass *klass = fsig->param_count > 0 ? args [0]->klass : cmethod->klass;
MonoTypeEnum arg0_type = fsig->param_count > 0 ? get_underlying_type (fsig->params [0]) : MONO_TYPE_VOID;

if (cfg->verbose_level > 1) {
Expand Down Expand Up @@ -2057,21 +2031,56 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
case SN_Add:
case SN_BitwiseAnd:
case SN_BitwiseOr:
case SN_Divide:
case SN_Max:
case SN_MaxNative:
case SN_Min:
case SN_MinNative:
case SN_Multiply:
case SN_Subtract:
case SN_Xor:
if (!is_element_type_primitive (fsig->params [0]))
case SN_Xor: {
if (fsig->param_count != 2)
return NULL;

if (!is_element_type_primitive (fsig->params [0]) || !is_element_type_primitive (fsig->params [1]))
return NULL;

#ifndef TARGET_ARM64
if (((id == SN_Max) || (id == SN_Min)) && type_enum_is_float(arg0_type))
return NULL;
#endif

return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id);
}
case SN_Divide: {
if (fsig->param_count != 2)
return NULL;

if (!is_element_type_primitive (fsig->params [0]) ||
!(MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || is_element_type_primitive (fsig->params [1])))
return NULL;

return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id);
}
case SN_Multiply: {
if (fsig->param_count != 2)
return NULL;

MonoTypeEnum vector_inner_type = arg0_type;
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) {
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || !is_element_type_primitive (fsig->params [1]))
return NULL;
// By default, we expect the first argument to be the vector type
// however, for Multiply, the first argument can be scalar. In this case, we need to
// get the vector type from the second argument.
klass = args [1]->klass;
vector_inner_type = get_underlying_type (fsig->params [1]);
} else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) {
if (!is_element_type_primitive (fsig->params [0]))
return NULL;
} else if (!(is_element_type_primitive (fsig->params [0]) && is_element_type_primitive (fsig->params [1])))
return NULL;

return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, vector_inner_type, id);
}
case SN_AndNot: {
if (!is_element_type_primitive (fsig->params [0]))
return NULL;
Expand Down
Loading