Skip to content

Commit 8000000

Browse files
matouskozakpull[bot]
authored andcommitted
[mono] Fix vector class retrieval and type checks for binary operand APIs (#107388)
- change the function to be split by the OP code rather than the type of the operands - add type checks to the callsite to ensure that the operands are of the correct type
1 parent bbe99fd commit 8000000

File tree

1 file changed

+109
-100
lines changed

1 file changed

+109
-100
lines changed

src/mono/mono/mini/simd-intrinsics.c

Lines changed: 109 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -339,111 +339,73 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
339339
int instc0 = -1;
340340
int op = OP_XBINOP;
341341

342-
if (id == SN_BitwiseAnd || id == SN_BitwiseOr || id == SN_Xor ||
343-
id == SN_op_BitwiseAnd || id == SN_op_BitwiseOr || id == SN_op_ExclusiveOr) {
344-
op = OP_XBINOP_FORCEINT;
345-
346-
switch (id) {
342+
switch (id) {
343+
case SN_Add:
344+
case SN_op_Addition: {
345+
if (type_enum_is_float (arg_type)) {
346+
instc0 = OP_FADD;
347+
} else {
348+
instc0 = OP_IADD;
349+
}
350+
break;
351+
}
347352
case SN_BitwiseAnd:
348-
case SN_op_BitwiseAnd:
353+
case SN_op_BitwiseAnd: {
354+
op = OP_XBINOP_FORCEINT;
349355
instc0 = XBINOP_FORCEINT_AND;
350356
break;
357+
}
351358
case SN_BitwiseOr:
352-
case SN_op_BitwiseOr:
359+
case SN_op_BitwiseOr: {
360+
op = OP_XBINOP_FORCEINT;
353361
instc0 = XBINOP_FORCEINT_OR;
354362
break;
355-
case SN_op_ExclusiveOr:
356-
case SN_Xor:
357-
instc0 = XBINOP_FORCEINT_XOR;
358-
break;
359363
}
360-
} else {
361-
if (type_enum_is_float (arg_type)) {
362-
switch (id) {
363-
case SN_Add:
364-
case SN_op_Addition:
365-
instc0 = OP_FADD;
366-
break;
367-
case SN_Divide:
368-
case SN_op_Division: {
369-
const char *class_name = m_class_get_name (klass);
370-
if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) {
371-
if (!type_is_simd_vector (fsig->params [1]))
372-
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FDIV);
373-
else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) {
374-
instc0 = OP_FDIV;
375-
break;
376-
} else {
377-
return NULL;
378-
}
379-
}
364+
case SN_Divide:
365+
case SN_op_Division: {
366+
if (type_enum_is_float (arg_type)) {
380367
instc0 = OP_FDIV;
381-
break;
382-
}
383-
#ifdef TARGET_ARM64
384-
case SN_Max:
385-
#endif
386-
case SN_MaxNative:
387-
instc0 = OP_FMAX;
388-
break;
389-
#ifdef TARGET_ARM64
390-
case SN_Min:
391-
#endif
392-
case SN_MinNative:
393-
instc0 = OP_FMIN;
394-
break;
395-
case SN_Multiply:
396-
case SN_op_Multiply: {
397-
const char *class_name = m_class_get_name (klass);
398-
if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) {
399-
if (!type_is_simd_vector (fsig->params [1]))
400-
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FMUL);
401-
else if (!type_is_simd_vector (fsig->params [0]))
402-
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_FMUL);
403-
else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) {
404-
instc0 = OP_FMUL;
405-
break;
406-
} else {
407-
return NULL;
408-
}
368+
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) { // vector / scalar
369+
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0);
409370
}
410-
instc0 = OP_FMUL;
411-
break;
412-
}
413-
case SN_Subtract:
414-
case SN_op_Subtraction:
415-
instc0 = OP_FSUB;
416-
break;
417-
default:
418-
g_assert_not_reached ();
419-
}
420-
} else {
421-
switch (id) {
422-
case SN_Add:
423-
case SN_op_Addition:
424-
instc0 = OP_IADD;
425-
break;
426-
case SN_Divide:
427-
case SN_op_Division:
371+
} else {
428372
return NULL;
429-
case SN_Max:
430-
case SN_MaxNative:
373+
}
374+
break;
375+
}
376+
case SN_Max:
377+
case SN_MaxNative: {
378+
if (type_enum_is_float (arg_type)) {
379+
instc0 = OP_FMAX;
380+
} else {
431381
instc0 = type_enum_is_unsigned (arg_type) ? OP_IMAX_UN : OP_IMAX;
382+
432383
#ifdef TARGET_AMD64
433384
if (!COMPILE_LLVM (cfg) && instc0 == OP_IMAX_UN)
434385
return NULL;
435386
#endif
436-
break;
437-
case SN_Min:
438-
case SN_MinNative:
387+
}
388+
break;
389+
}
390+
case SN_Min:
391+
case SN_MinNative: {
392+
if (type_enum_is_float (arg_type)) {
393+
instc0 = OP_FMIN;
394+
} else {
439395
instc0 = type_enum_is_unsigned (arg_type) ? OP_IMIN_UN : OP_IMIN;
396+
440397
#ifdef TARGET_AMD64
441398
if (!COMPILE_LLVM (cfg) && instc0 == OP_IMIN_UN)
442399
return NULL;
443400
#endif
444-
break;
445-
case SN_Multiply:
446-
case SN_op_Multiply: {
401+
}
402+
break;
403+
}
404+
case SN_Multiply:
405+
case SN_op_Multiply: {
406+
if (type_enum_is_float (arg_type)) {
407+
instc0 = OP_FMUL;
408+
} else {
447409
#ifdef TARGET_ARM64
448410
if (!COMPILE_LLVM (cfg) && (arg_type == MONO_TYPE_I8 || arg_type == MONO_TYPE_U8 || arg_type == MONO_TYPE_I || arg_type == MONO_TYPE_U))
449411
return NULL;
@@ -452,22 +414,34 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
452414
if (!COMPILE_LLVM (cfg))
453415
return NULL;
454416
#endif
455-
if (fsig->params [1]->type != MONO_TYPE_GENERICINST)
456-
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_IMUL);
457-
else if (fsig->params [0]->type != MONO_TYPE_GENERICINST)
458-
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_IMUL);
459417
instc0 = OP_IMUL;
460-
break;
461418
}
462-
case SN_Subtract:
463-
case SN_op_Subtraction:
419+
if (MONO_TYPE_IS_VECTOR_PRIMITIVE(fsig->params [1])) { // vector * scalar
420+
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0);
421+
} else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) { // scalar * vector
422+
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, instc0);
423+
}
424+
break;
425+
}
426+
case SN_Subtract:
427+
case SN_op_Subtraction: {
428+
if (type_enum_is_float (arg_type)) {
429+
instc0 = OP_FSUB;
430+
} else {
464431
instc0 = OP_ISUB;
465-
break;
466-
default:
467-
g_assert_not_reached ();
468432
}
433+
break;
434+
}
435+
case SN_Xor:
436+
case SN_op_ExclusiveOr: {
437+
op = OP_XBINOP_FORCEINT;
438+
instc0 = XBINOP_FORCEINT_XOR;
439+
break;
469440
}
441+
default:
442+
g_assert_not_reached ();
470443
}
444+
471445
return emit_simd_ins_for_sig (cfg, klass, op, instc0, arg_type, fsig, args);
472446
}
473447

@@ -1992,7 +1966,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
19921966
return NULL;
19931967
#endif
19941968

1995-
MonoClass* klass = fsig->param_count > 0 ? args[0]->klass : cmethod->klass;
1969+
MonoClass *klass = fsig->param_count > 0 ? args [0]->klass : cmethod->klass;
19961970
MonoTypeEnum arg0_type = fsig->param_count > 0 ? get_underlying_type (fsig->params [0]) : MONO_TYPE_VOID;
19971971

19981972
if (cfg->verbose_level > 1) {
@@ -2057,21 +2031,56 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
20572031
case SN_Add:
20582032
case SN_BitwiseAnd:
20592033
case SN_BitwiseOr:
2060-
case SN_Divide:
20612034
case SN_Max:
20622035
case SN_MaxNative:
20632036
case SN_Min:
20642037
case SN_MinNative:
2065-
case SN_Multiply:
20662038
case SN_Subtract:
2067-
case SN_Xor:
2068-
if (!is_element_type_primitive (fsig->params [0]))
2039+
case SN_Xor: {
2040+
if (fsig->param_count != 2)
20692041
return NULL;
2042+
2043+
if (!is_element_type_primitive (fsig->params [0]) || !is_element_type_primitive (fsig->params [1]))
2044+
return NULL;
2045+
20702046
#ifndef TARGET_ARM64
20712047
if (((id == SN_Max) || (id == SN_Min)) && type_enum_is_float(arg0_type))
20722048
return NULL;
20732049
#endif
2050+
20742051
return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id);
2052+
}
2053+
case SN_Divide: {
2054+
if (fsig->param_count != 2)
2055+
return NULL;
2056+
2057+
if (!is_element_type_primitive (fsig->params [0]) ||
2058+
!(MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || is_element_type_primitive (fsig->params [1])))
2059+
return NULL;
2060+
2061+
return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id);
2062+
}
2063+
case SN_Multiply: {
2064+
if (fsig->param_count != 2)
2065+
return NULL;
2066+
2067+
MonoTypeEnum vector_inner_type = arg0_type;
2068+
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) {
2069+
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || !is_element_type_primitive (fsig->params [1]))
2070+
return NULL;
2071+
// By default, we expect the first argument to be the vector type
2072+
// however, for Multiply, the first argument can be scalar. In this case, we need to
2073+
// get the vector type from the second argument.
2074+
klass = args [1]->klass;
2075+
vector_inner_type = get_underlying_type (fsig->params [1]);
2076+
} else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) {
2077+
if (!is_element_type_primitive (fsig->params [0]))
2078+
return NULL;
2079+
} else if (!(is_element_type_primitive (fsig->params [0]) && is_element_type_primitive (fsig->params [1])))
2080+
return NULL;
2081+
2082+
return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, vector_inner_type, id);
2083+
}
20752084
case SN_AndNot: {
20762085
if (!is_element_type_primitive (fsig->params [0]))
20772086
return NULL;

0 commit comments

Comments
 (0)