diff --git a/NuGet.config b/NuGet.config index 80dd215a4bf01b..d8166846f78489 100644 --- a/NuGet.config +++ b/NuGet.config @@ -7,6 +7,11 @@ + + + + + + +## Key Features + + + +## How to Use + + + +## Main Types + + + +## Additional Documentation + +* [Conceptual documentation](...) +* [API documentation](...) + +## Related Packages + + + +## Feedback & Contributing + + + +ExamplePackage is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). +``` For a list of supported Markdown features, see [NuGet documentation](https://learn.microsoft.com/nuget/nuget-org/package-readme-on-nuget-org#supported-markdown-features). diff --git a/docs/project/list-of-diagnostics.md b/docs/project/list-of-diagnostics.md index 4f78e9e711653d..aed0f89c89ced4 100644 --- a/docs/project/list-of-diagnostics.md +++ b/docs/project/list-of-diagnostics.md @@ -142,10 +142,10 @@ The diagnostic id values reserved for .NET Libraries analyzer warnings are `SYSL | __`SYSLIB1023`__ | Generating more than 6 arguments is not supported | | __`SYSLIB1024`__ | Argument is using the unsupported out parameter modifier | | __`SYSLIB1025`__ | Multiple logging methods cannot use the same event name within a class | -| __`SYSLIB1026`__ | _`SYSLIB1026`-`SYSLIB1029` reserved for logging._ | -| __`SYSLIB1027`__ | _`SYSLIB1026`-`SYSLIB1029` reserved for logging._ | -| __`SYSLIB1028`__ | _`SYSLIB1026`-`SYSLIB1029` reserved for logging._ | -| __`SYSLIB1029`__ | _`SYSLIB1026`-`SYSLIB1029` reserved for logging._ | +| __`SYSLIB1026`__ | C# language version not supported by the logging source generator. | +| __`SYSLIB1027`__ | _`SYSLIB1001`-`SYSLIB1029` reserved for logging._ | +| __`SYSLIB1028`__ | _`SYSLIB1001`-`SYSLIB1029` reserved for logging._ | +| __`SYSLIB1029`__ | _`SYSLIB1001`-`SYSLIB1029` reserved for logging._ | | __`SYSLIB1030`__ | JsonSourceGenerator did not generate serialization metadata for type | | __`SYSLIB1031`__ | JsonSourceGenerator encountered a duplicate JsonTypeInfo property name | | __`SYSLIB1032`__ | JsonSourceGenerator encountered a context class that is not partial | @@ -208,7 +208,7 @@ The diagnostic id values reserved for .NET Libraries analyzer warnings are `SYSL | __`SYSLIB1089`__ | _`SYSLIB1070`-`SYSLIB1089` reserved for System.Runtime.InteropServices.JavaScript.JSImportGenerator._ | | __`SYSLIB1090`__ | Invalid 'GeneratedComInterfaceAttribute' usage | | __`SYSLIB1091`__ | Method is declared in different partial declaration than the 'GeneratedComInterface' attribute. | -| __`SYSLIB1092`__ | 'GenerateComInterfaceAttribute' usage not recommended. See aka.ms/GeneratedComInterfaceUsage for recommended usage. | +| __`SYSLIB1092`__ | Usage of '[LibraryImport|GeneratedComInterface]' does not follow recommendation. See aka.ms/[LibraryImport|GeneratedComInterface]Usage for best practices. | | __`SYSLIB1093`__ | Analysis for COM interface generation has failed | | __`SYSLIB1094`__ | The base COM interface failed to generate source. Code will not be generated for this interface. | | __`SYSLIB1095`__ | Invalid 'GeneratedComClassAttribute' usage | @@ -250,8 +250,8 @@ The diagnostic id values reserved for .NET Libraries analyzer warnings are `SYSL | __`SYSLIB1213`__ | Options validation generator: Member potentially missing enumerable validation. | | __`SYSLIB1214`__ | Options validation generator: Can't validate constants, static fields or properties. | | __`SYSLIB1215`__ | Options validation generator: Validation attribute on the member is inaccessible from the validator type. | -| __`SYSLIB1216`__ | *_`SYSLIB1201`-`SYSLIB1219` reserved for Microsoft.Extensions.Options.SourceGeneration.* | -| __`SYSLIB1217`__ | *_`SYSLIB1201`-`SYSLIB1219` reserved for Microsoft.Extensions.Options.SourceGeneration.* | +| __`SYSLIB1216`__ | C# language version not supported by the options validation source generator. | +| __`SYSLIB1217`__ | The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. | | __`SYSLIB1218`__ | *_`SYSLIB1201`-`SYSLIB1219` reserved for Microsoft.Extensions.Options.SourceGeneration.* | | __`SYSLIB1219`__ | *_`SYSLIB1201`-`SYSLIB1219` reserved for Microsoft.Extensions.Options.SourceGeneration.* | | __`SYSLIB1220`__ | JsonSourceGenerator encountered a [JsonConverterAttribute] with an invalid type argument. | @@ -270,3 +270,5 @@ The diagnostic id values reserved for .NET Libraries analyzer warnings are `SYSL | Suppression ID | Suppressed Diagnostic ID | Description | | :----------------------- | :----------------------- | :---------- | | __`SYSLIBSUPPRESS0001`__ | CA1822 | Do not offer to make methods static when the methods need to be instance methods for a custom marshaller shape. | +| __`SYSLIBSUPPRESS0002`__ | IL2026 | ConfigurationBindingGenerator: suppress RequiresUnreferencedCode diagnostic for binding call that has been intercepted by a generated static variant. | +| __`SYSLIBSUPPRESS0003`__ | IL3050 | ConfigurationBindingGenerator: suppress RequiresDynamicCode diagnostic for binding call that has been intercepted by a generated static variant. | diff --git a/docs/workflow/trimming/feature-switches.md b/docs/workflow/trimming/feature-switches.md index 87d8fa4c5ec425..635187684116b5 100644 --- a/docs/workflow/trimming/feature-switches.md +++ b/docs/workflow/trimming/feature-switches.md @@ -13,6 +13,7 @@ configurations but their defaults might vary as any SDK can set the defaults dif | EnableUnsafeBinaryFormatterSerialization | System.Runtime.Serialization.EnableUnsafeBinaryFormatterSerialization | BinaryFormatter serialization support is trimmed when set to false | | EventSourceSupport | System.Diagnostics.Tracing.EventSource.IsSupported | Any EventSource related code or logic is trimmed when set to false | | InvariantGlobalization | System.Globalization.Invariant | All globalization specific code and data is trimmed when set to true | +| MetricsSupport | System.Diagnostics.Metrics.Meter.IsSupported | Any Metrics related code or logic is trimmed when set to false | | PredefinedCulturesOnly | System.Globalization.PredefinedCulturesOnly | Don't allow creating a culture for which the platform does not have data | | HybridGlobalization | System.Globalization.Hybrid | Properties connected with the mixed: platform-specific + icu-based globalization will be trimmed | | UseSystemResourceKeys | System.Resources.UseSystemResourceKeys | Any localizable resources for system assemblies is trimmed when set to true | @@ -27,6 +28,7 @@ configurations but their defaults might vary as any SDK can set the defaults dif | MetadataUpdaterSupport | System.Reflection.Metadata.MetadataUpdater.IsSupported | Metadata update related code to be trimmed when set to false | | _EnableConsumingManagedCodeFromNativeHosting | System.Runtime.InteropServices.EnableConsumingManagedCodeFromNativeHosting | Getting a managed function from native hosting is disabled when set to false and related functionality can be trimmed. | | VerifyDependencyInjectionOpenGenericServiceTrimmability | Microsoft.Extensions.DependencyInjection.VerifyOpenGenericServiceTrimmability | When set to true, DependencyInjection will verify trimming annotations applied to open generic services are correct | +| DisableDependencyInjectionDynamicEngine | Microsoft.Extensions.DependencyInjection.DisableDynamicEngine | When set to true, DependencyInjection will avoid using System.Reflection.Emit when realizing services | | NullabilityInfoContextSupport | System.Reflection.NullabilityInfoContext.IsSupported | Nullable attributes can be trimmed when set to false | | DynamicCodeSupport | System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported | Changes RuntimeFeature.IsDynamicCodeSupported to false to allow testing AOT-safe fallback code without publishing for Native AOT. | | _AggressiveAttributeTrimming | System.AggressiveAttributeTrimming | When set to true, aggressively trims attributes to allow for the most size savings possible, even if it could result in runtime behavior changes | diff --git a/eng/Analyzers.targets b/eng/Analyzers.targets index 4ca3df7737280a..a7955d74e9bdc4 100644 --- a/eng/Analyzers.targets +++ b/eng/Analyzers.targets @@ -1,4 +1,11 @@ + + + false + false diff --git a/eng/SourceBuildPrebuiltBaseline.xml b/eng/SourceBuildPrebuiltBaseline.xml index 74f6be96543a5e..458b2d756cba9a 100644 --- a/eng/SourceBuildPrebuiltBaseline.xml +++ b/eng/SourceBuildPrebuiltBaseline.xml @@ -10,12 +10,18 @@ + - + + + + diff --git a/eng/Subsets.props b/eng/Subsets.props index 77268ffa7b5d09..eb4151f3a185d6 100644 --- a/eng/Subsets.props +++ b/eng/Subsets.props @@ -502,7 +502,7 @@ - + @@ -512,7 +512,7 @@ - + diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index 76bf5019cd8ed2..11e66be9b72e5e 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -1,80 +1,80 @@ - + https://github.com/dotnet/icu - 92124838d3f0efde3ac483a904691a611babb9a0 + feea7b8dcee39fd35ee6c415197e47d19102bb0b - + https://github.com/dotnet/msquic - a880e93af4e50d19110d228e698900c110e2b0e9 + bbb1252b31e3a194be3163982d972e4583c75476 https://github.com/dotnet/wcf 7f504aabb1988e9a093c1e74d8040bd52feb2f01 - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd https://github.com/dotnet/command-line-api @@ -85,209 +85,213 @@ 02fe27cd6a9b001c8feb7938e6ef4b3799745759b - + https://github.com/dotnet/cecil - 2f4ef297939628143389ddeea569874ded0b1c1b + 45dd3a73dd5b64b010c4251303b3664bb30df029 - + https://github.com/dotnet/emsdk - abfa03c97f4175d4d209435cd0e71f558e36c3fd + 2406616d0e3a31d80b326e27c156955bfa41c791 + + + https://github.com/dotnet/emsdk + 2406616d0e3a31d80b326e27c156955bfa41c791 - + https://github.com/dotnet/source-build-reference-packages - 5a1492557c8717b428b69fd4b7ca8c91d5d18cd3 + b4fa7f2e1e65ef49881be2ab2df27624280a8c55 - + https://github.com/dotnet/source-build-externals - de4dda48d0cf31e13182bc24107b2246c61ed483 + 3dc05150cf234f76f6936dcb2853d31a0da1f60e - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/xliff-tasks - 493329204079519072f0241ed26f692bdee0d60c + 73f0850939d96131c28cf6ea6ee5aacb4da0083a - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 9b77c16a6061fb1160ec12bd307badb4c58dff98 + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd https://github.com/dotnet/runtime @@ -330,67 +334,67 @@ https://github.com/dotnet/xharness 480b9159eb7e69b182a87581d5a336e97e0b6dae - + https://github.com/dotnet/arcade - 9b2af35a6702526dc8a7c5fcadcc44efd0dca170 + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - 068998a5d91f55a619d1d072ab3094dacd5d6a4f + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - 068998a5d91f55a619d1d072ab3094dacd5d6a4f + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - 068998a5d91f55a619d1d072ab3094dacd5d6a4f + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - 068998a5d91f55a619d1d072ab3094dacd5d6a4f + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://github.com/dotnet/hotreload-utils - 696312fd2a60671797b12311a4cf387d3cd14dd0 + 7e01dcd64329d25070ad66af5eddd02410e80111 - + https://github.com/dotnet/runtime-assets - 48270e734aa881c737b80c4fe0459e68aaf08ad6 + 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/roslyn - 1fd4ff9d594b227baa3fc0962e2251323311ec19 + 12b11685a551e0a6a203dcecb584f610f4df1157 - + https://github.com/dotnet/roslyn - 1fd4ff9d594b227baa3fc0962e2251323311ec19 + 12b11685a551e0a6a203dcecb584f610f4df1157 - + https://github.com/dotnet/roslyn - 1fd4ff9d594b227baa3fc0962e2251323311ec19 + 12b11685a551e0a6a203dcecb584f610f4df1157 - + https://github.com/dotnet/roslyn-analyzers - 755a4f888d64fc7c0f2802adca731f301a53283d + 4ff28092cdb2006c30869fb35b2fd6b7b11382b1 - + https://github.com/dotnet/roslyn-analyzers - 755a4f888d64fc7c0f2802adca731f301a53283d + 4ff28092cdb2006c30869fb35b2fd6b7b11382b1 - + https://github.com/dotnet/sdk - d10b02ae5cc670609d920a672985ed4456bdd6b6 + 7e33fd449381b337c290a801057fdcd68c4b7220 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - 068998a5d91f55a619d1d072ab3094dacd5d6a4f + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - 068998a5d91f55a619d1d072ab3094dacd5d6a4f + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 @@ -398,5 +402,9 @@ https://github.com/NuGet/NuGet.Client 8fef55f5a55a3b4f2c96cd1a9b5ddc51d4b927f8 + + https://github.com/dotnet/installer + 46a7370763921ded24dcb70c585ee97883c615d4 + diff --git a/eng/Versions.props b/eng/Versions.props index 3cccff5761371c..98e3f3b5b496e8 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -7,18 +7,22 @@ 0 0 8.0.100 - 7.0.8 + 7.0.14 6.0.$([MSBuild]::Add($([System.Version]::Parse('$(PackageVersionNet7)').Build),11)) - rc - 1 - -$(PreReleaseVersionLabel).$(PreReleaseVersionIteration) + rtm + + + + true + release + -$(PreReleaseVersionLabel) + -$(PreReleaseVersionLabel).$(PreReleaseVersionIteration) $(SdkBandVersion)$(WorkloadVersionSuffix) + + false $(MajorVersion).$(MinorVersion).0.0 - - false - release true false @@ -32,17 +36,17 @@ - 3.11.0-beta1.23412.1 - 8.0.0-preview.23412.1 + 3.11.0-beta1.23516.2 + 8.0.0-preview.23516.2 - 4.8.0-1.23408.8 - 4.8.0-1.23408.8 - 4.8.0-1.23408.8 + 4.8.0-3.23518.7 + 4.8.0-3.23518.7 + 4.8.0-3.23518.7 - 8.0.100-preview.7.23329.3 + 8.0.100-rtm.23520.8 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 2.5.1-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 - 8.0.0-beta.23411.1 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 2.5.1-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 6.0.0-preview.1.102 @@ -106,14 +110,14 @@ 8.0.0-rc.1.23406.6 8.0.0-preview.7.23325.2 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 6.0.0 1.1.1 @@ -139,29 +143,29 @@ 4.5.0 8.0.0-rc.1.23406.6 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 - 8.0.0-beta.23408.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 + 8.0.0-beta.23421.1 - 1.0.0-prerelease.23362.5 - 1.0.0-prerelease.23362.5 - 1.0.0-prerelease.23362.5 - 1.0.0-prerelease.23362.5 - 1.0.0-prerelease.23362.5 - 1.0.0-prerelease.23362.5 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 - 16.11.27-beta1.23180.1 + 16.11.29-beta1.23404.4 2.0.0-beta4.23307.1 3.0.3 2.1.0 @@ -182,14 +186,14 @@ 8.0.0-prerelease.23407.2 8.0.0-prerelease.23407.2 8.0.0-prerelease.23407.2 - 8.0.0-alpha.0.23407.2 + 8.0.0-alpha.0.23518.2 2.4.2 1.0.0 2.4.5 3.12.0 4.1.0 6.0.0 - 13.0.1 + 13.0.3 1.0.2 2.0.4 4.18.4 @@ -203,57 +207,58 @@ 2.46.3 2.45.0 2.45.0 - - 1.1.2-beta1.23323.1 - 7.0.0-preview-20221010.1 + 8.0.0-preview-20230918.1 8.0.0-rc.1.23406.6 - 0.11.4-alpha.23407.2 + 0.11.4-alpha.23509.2 8.0.0-rc.1.23406.6 - 8.0.0-rc.1.23407.2 + 8.0.0-rtm.23511.1 - 2.2.2 - 8.0.0-alpha.1.23180.2 + 2.2.3 + 8.0.0-alpha.1.23468.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 - 8.0.0-rc.1.23411.2 - $(MicrosoftNETWorkloadEmscriptenCurrentManifest80100TransportVersion) + 8.0.0 + $(MicrosoftNETWorkloadEmscriptenCurrentManifest80100Version) 1.1.87-gba258badda 1.0.0-v3.14.0.5722 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 - 16.0.5-alpha.1.23408.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 3.1.7 1.0.406601 + + 8.0.100-rtm.23506.1 + diff --git a/eng/common/SetupNugetSources.ps1 b/eng/common/SetupNugetSources.ps1 index 6e99723945183e..6c65e81925f2a3 100644 --- a/eng/common/SetupNugetSources.ps1 +++ b/eng/common/SetupNugetSources.ps1 @@ -153,7 +153,7 @@ if ($dotnet31Source -ne $null) { AddPackageSource -Sources $sources -SourceName "dotnet3.1-internal-transport" -SourceEndPoint "https://pkgs.dev.azure.com/dnceng/_packaging/dotnet3.1-internal-transport/nuget/v2" -Creds $creds -Username $userName -Password $Password } -$dotnetVersions = @('5','6','7') +$dotnetVersions = @('5','6','7','8') foreach ($dotnetVersion in $dotnetVersions) { $feedPrefix = "dotnet" + $dotnetVersion; diff --git a/eng/common/SetupNugetSources.sh b/eng/common/SetupNugetSources.sh index 8af7d899db1212..d387c7eac95e54 100644 --- a/eng/common/SetupNugetSources.sh +++ b/eng/common/SetupNugetSources.sh @@ -105,7 +105,7 @@ if [ "$?" == "0" ]; then PackageSources+=('dotnet3.1-internal-transport') fi -DotNetVersions=('5' '6' '7') +DotNetVersions=('5' '6' '7' '8') for DotNetVersion in ${DotNetVersions[@]} ; do FeedPrefix="dotnet${DotNetVersion}"; diff --git a/eng/common/cross/toolchain.cmake b/eng/common/cross/toolchain.cmake index a88d643c8a765e..0998e875e5f78d 100644 --- a/eng/common/cross/toolchain.cmake +++ b/eng/common/cross/toolchain.cmake @@ -207,6 +207,7 @@ elseif(ILLUMOS) set(CMAKE_CXX_STANDARD_LIBRARIES "${CMAKE_CXX_STANDARD_LIBRARIES} -lssp") elseif(HAIKU) set(CMAKE_SYSROOT "${CROSS_ROOTFS}") + set(CMAKE_PROGRAM_PATH "${CMAKE_PROGRAM_PATH};${CROSS_ROOTFS}/cross-tools-x86_64/bin") set(TOOLSET_PREFIX ${TOOLCHAIN}-) function(locate_toolchain_exec exec var) @@ -217,7 +218,6 @@ elseif(HAIKU) endif() find_program(EXEC_LOCATION_${exec} - PATHS "${CROSS_ROOTFS}/cross-tools-x86_64/bin" NAMES "${TOOLSET_PREFIX}${exec}${CLR_CMAKE_COMPILER_FILE_NAME_VERSION}" "${TOOLSET_PREFIX}${exec}") diff --git a/eng/common/loc/P22DotNetHtmlLocalization.lss b/eng/common/loc/P22DotNetHtmlLocalization.lss index 858a0b237c62ce..5d892d619398f9 100644 Binary files a/eng/common/loc/P22DotNetHtmlLocalization.lss and b/eng/common/loc/P22DotNetHtmlLocalization.lss differ diff --git a/eng/common/native/init-compiler.sh b/eng/common/native/init-compiler.sh index 517401b688bf76..f5c1ec7eafeb28 100644 --- a/eng/common/native/init-compiler.sh +++ b/eng/common/native/init-compiler.sh @@ -63,7 +63,7 @@ if [ -z "$CLR_CC" ]; then # Set default versions if [ -z "$majorVersion" ]; then # note: gcc (all versions) and clang versions higher than 6 do not have minor version in file name, if it is zero. - if [ "$compiler" = "clang" ]; then versions="16 15 14 13 12 11 10 9 8 7 6.0 5.0 4.0 3.9 3.8 3.7 3.6 3.5" + if [ "$compiler" = "clang" ]; then versions="17 16 15 14 13 12 11 10 9 8 7 6.0 5.0 4.0 3.9 3.8 3.7 3.6 3.5" elif [ "$compiler" = "gcc" ]; then versions="13 12 11 10 9 8 7 6 5 4.9"; fi for version in $versions; do diff --git a/eng/common/native/init-distro-rid.sh b/eng/common/native/init-distro-rid.sh index aba9fe24028b0f..de1687b2ccbe79 100644 --- a/eng/common/native/init-distro-rid.sh +++ b/eng/common/native/init-distro-rid.sh @@ -79,7 +79,6 @@ getNonPortableDistroRid() # Input: # os: (str) # arch: (str) -# isPortable: (int) # rootfsDir?: (nullable:string) # # Return: @@ -97,10 +96,9 @@ initDistroRidGlobal() { local targetOs="$1" local targetArch="$2" - local isPortable="$3" local rootfsDir="" - if [ "$#" -ge 4 ]; then - rootfsDir="$4" + if [ "$#" -ge 3 ]; then + rootfsDir="$3" fi if [ -n "${rootfsDir}" ]; then diff --git a/eng/common/sdk-task.ps1 b/eng/common/sdk-task.ps1 index 6c4ac6fec1a99a..73828dd30d3179 100644 --- a/eng/common/sdk-task.ps1 +++ b/eng/common/sdk-task.ps1 @@ -64,7 +64,7 @@ try { $GlobalJson.tools | Add-Member -Name "vs" -Value (ConvertFrom-Json "{ `"version`": `"16.5`" }") -MemberType NoteProperty } if( -not ($GlobalJson.tools.PSObject.Properties.Name -match "xcopy-msbuild" )) { - $GlobalJson.tools | Add-Member -Name "xcopy-msbuild" -Value "17.6.0-2" -MemberType NoteProperty + $GlobalJson.tools | Add-Member -Name "xcopy-msbuild" -Value "17.8.1-2" -MemberType NoteProperty } if ($GlobalJson.tools."xcopy-msbuild".Trim() -ine "none") { $xcopyMSBuildToolsFolder = InitializeXCopyMSBuild $GlobalJson.tools."xcopy-msbuild" -install $true diff --git a/eng/common/sdl/trim-assets-version.ps1 b/eng/common/sdl/trim-assets-version.ps1 new file mode 100644 index 00000000000000..a2e0048770452f --- /dev/null +++ b/eng/common/sdl/trim-assets-version.ps1 @@ -0,0 +1,75 @@ +<# +.SYNOPSIS +Install and run the 'Microsoft.DotNet.VersionTools.Cli' tool with the 'trim-artifacts-version' command to trim the version from the NuGet assets file name. + +.PARAMETER InputPath +Full path to directory where artifact packages are stored + +.PARAMETER Recursive +Search for NuGet packages recursively + +#> + +Param( + [string] $InputPath, + [bool] $Recursive = $true +) + +$CliToolName = "Microsoft.DotNet.VersionTools.Cli" + +function Install-VersionTools-Cli { + param( + [Parameter(Mandatory=$true)][string]$Version + ) + + Write-Host "Installing the package '$CliToolName' with a version of '$version' ..." + $feed = "https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-eng/nuget/v3/index.json" + + $argumentList = @("tool", "install", "--local", "$CliToolName", "--add-source $feed", "--no-cache", "--version $Version", "--create-manifest-if-needed") + Start-Process "$dotnet" -Verbose -ArgumentList $argumentList -NoNewWindow -Wait +} + +# ------------------------------------------------------------------- + +if (!(Test-Path $InputPath)) { + Write-Host "Input Path '$InputPath' does not exist" + ExitWithExitCode 1 +} + +$ErrorActionPreference = 'Stop' +Set-StrictMode -Version 2.0 + +$disableConfigureToolsetImport = $true +$global:LASTEXITCODE = 0 + +# `tools.ps1` checks $ci to perform some actions. Since the SDL +# scripts don't necessarily execute in the same agent that run the +# build.ps1/sh script this variable isn't automatically set. +$ci = $true +. $PSScriptRoot\..\tools.ps1 + +try { + $dotnetRoot = InitializeDotNetCli -install:$true + $dotnet = "$dotnetRoot\dotnet.exe" + + $toolsetVersion = Read-ArcadeSdkVersion + Install-VersionTools-Cli -Version $toolsetVersion + + $cliToolFound = (& "$dotnet" tool list --local | Where-Object {$_.Split(' ')[0] -eq $CliToolName}) + if ($null -eq $cliToolFound) { + Write-PipelineTelemetryError -Force -Category 'Sdl' -Message "The '$CliToolName' tool is not installed." + ExitWithExitCode 1 + } + + Exec-BlockVerbosely { + & "$dotnet" $CliToolName trim-assets-version ` + --assets-path $InputPath ` + --recursive $Recursive + Exit-IfNZEC "Sdl" + } +} +catch { + Write-Host $_ + Write-PipelineTelemetryError -Force -Category 'Sdl' -Message $_ + ExitWithExitCode 1 +} \ No newline at end of file diff --git a/eng/common/templates/job/execute-sdl.yml b/eng/common/templates/job/execute-sdl.yml index 7aabaa18017bf6..7870f93bc17652 100644 --- a/eng/common/templates/job/execute-sdl.yml +++ b/eng/common/templates/job/execute-sdl.yml @@ -105,6 +105,11 @@ jobs: downloadPath: $(Build.ArtifactStagingDirectory)\artifacts checkDownloadedFiles: true + - powershell: eng/common/sdl/trim-assets-version.ps1 + -InputPath $(Build.ArtifactStagingDirectory)\artifacts + displayName: Trim the version from the NuGet packages + continueOnError: ${{ parameters.sdlContinueOnError }} + - powershell: eng/common/sdl/extract-artifact-packages.ps1 -InputPath $(Build.ArtifactStagingDirectory)\artifacts\BlobArtifacts -ExtractPath $(Build.ArtifactStagingDirectory)\artifacts\BlobArtifacts diff --git a/eng/common/tools.ps1 b/eng/common/tools.ps1 index c9eced9f7df4c6..fdd0cbb91f8596 100644 --- a/eng/common/tools.ps1 +++ b/eng/common/tools.ps1 @@ -379,13 +379,13 @@ function InitializeVisualStudioMSBuild([bool]$install, [object]$vsRequirements = } # Minimum VS version to require. - $vsMinVersionReqdStr = '17.6' + $vsMinVersionReqdStr = '17.7' $vsMinVersionReqd = [Version]::new($vsMinVersionReqdStr) # If the version of msbuild is going to be xcopied, # use this version. Version matches a package here: - # https://dev.azure.com/dnceng/public/_artifacts/feed/dotnet-eng/NuGet/RoslynTools.MSBuild/versions/17.6.0-2 - $defaultXCopyMSBuildVersion = '17.6.0-2' + # https://dev.azure.com/dnceng/public/_artifacts/feed/dotnet-eng/NuGet/RoslynTools.MSBuild/versions/17.8.1-2 + $defaultXCopyMSBuildVersion = '17.8.1-2' if (!$vsRequirements) { if (Get-Member -InputObject $GlobalJson.tools -Name 'vs') { @@ -671,6 +671,10 @@ function InitializeNativeTools() { } } +function Read-ArcadeSdkVersion() { + return $GlobalJson.'msbuild-sdks'.'Microsoft.DotNet.Arcade.Sdk' +} + function InitializeToolset() { if (Test-Path variable:global:_ToolsetBuildProj) { return $global:_ToolsetBuildProj @@ -678,7 +682,7 @@ function InitializeToolset() { $nugetCache = GetNuGetPackageCachePath - $toolsetVersion = $GlobalJson.'msbuild-sdks'.'Microsoft.DotNet.Arcade.Sdk' + $toolsetVersion = Read-ArcadeSdkVersion $toolsetLocationFile = Join-Path $ToolsetDir "$toolsetVersion.txt" if (Test-Path $toolsetLocationFile) { diff --git a/eng/liveBuilds.targets b/eng/liveBuilds.targets index 118601229cf8b7..370e19805cc3a0 100644 --- a/eng/liveBuilds.targets +++ b/eng/liveBuilds.targets @@ -260,10 +260,4 @@ DependsOnTargets=" ResolveLibrariesRefAssembliesFromLocalBuild; ResolveLibrariesRuntimeFilesFromLocalBuild" /> - - - - $([MSBuild]::NormalizePath('$(ArtifactsBinDir)', 'Microsoft.NETCore.Platforms', 'runtime.json')) - $([MSBuild]::NormalizePath('$(LibrariesProjectRoot)', 'Microsoft.NETCore.Platforms', 'src', 'runtime.json')) - diff --git a/eng/native/configureplatform.cmake b/eng/native/configureplatform.cmake index 2f6ca03db863fb..e6e0273bc75c46 100644 --- a/eng/native/configureplatform.cmake +++ b/eng/native/configureplatform.cmake @@ -2,7 +2,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/functions.cmake) # If set, indicates that this is not an officially supported release. # Release branches should set this to false. -set(PRERELEASE 1) +set(PRERELEASE 0) #---------------------------------------- # Detect and set platform variable names diff --git a/eng/native/ijw/IJW.cmake b/eng/native/ijw/IJW.cmake index 9ef90525dda8ba..33f047d54fca74 100644 --- a/eng/native/ijw/IJW.cmake +++ b/eng/native/ijw/IJW.cmake @@ -51,7 +51,7 @@ if (CLR_CMAKE_HOST_WIN32) # 4365 - signed/unsigned mismatch # 4679 - Could not import member. This is an issue with IJW and static abstract methods in interfaces. - add_compile_options(/wd4365 /wd4679) + add_compile_options(/wd4365 /wd4679 /wd5271) # IJW add_compile_options(/clr:netcore) diff --git a/eng/pipelines/common/evaluate-default-paths.yml b/eng/pipelines/common/evaluate-default-paths.yml index 5fb74a3741f413..0e4279a9697b94 100644 --- a/eng/pipelines/common/evaluate-default-paths.yml +++ b/eng/pipelines/common/evaluate-default-paths.yml @@ -241,6 +241,7 @@ jobs: - src/mono/tools/* - src/mono/wasi/* - src/mono/wasm/debugger/* + - src/mono/wasm/host/* - src/mono/wasm/Wasm.Build.Tests/* - ${{ parameters._const_paths._wasm_pipelines }} - ${{ parameters._const_paths._always_exclude }} @@ -258,6 +259,7 @@ jobs: - eng/testing/workloads-testing.targets - src/mono/mono/component/mini-wasm-debugger.c - src/mono/wasm/debugger/* + - src/mono/wasm/host/* - src/mono/wasm/Wasm.Build.Tests/* - src/mono/nuget/Microsoft.NET.Runtime* src/mono/nuget/Microsoft.NET.Sdk.WebAssembly.Pack/* diff --git a/eng/pipelines/common/global-build-job.yml b/eng/pipelines/common/global-build-job.yml index 41cce9e1534f94..39e7d9b5c53ced 100644 --- a/eng/pipelines/common/global-build-job.yml +++ b/eng/pipelines/common/global-build-job.yml @@ -68,6 +68,7 @@ jobs: variables: - ${{ if eq(variables['System.TeamProject'], 'internal') }}: - group: DotNet-HelixApi-Access + - group: AzureDevOps-Artifact-Feeds-Pats - name: _osParameter value: -os ${{ parameters.osGroup }} @@ -144,13 +145,37 @@ jobs: - ${{ each variable in parameters.variables }}: - ${{ variable }} steps: + - ${{ if eq(parameters.osGroup, 'windows') }}: + - template: /eng/pipelines/common/templates/disable-vsupdate-or-failfast.yml + - checkout: self clean: true - fetchDepth: $(checkoutFetchDepth) + # If running in source build mode, a git stash will be used for the inner clone. Avoid setting a fetch depth, + # as a stash of a shallow cloned repo is not currently supported. + ${{ if ne(parameters.isSourceBuild, true) }}: + fetchDepth: $(checkoutFetchDepth) - ${{ if and(eq(parameters.isOfficialBuild, true), notin(parameters.osGroup, 'osx', 'maccatalyst', 'ios', 'iossimulator', 'tvos', 'tvossimulator')) }}: - template: /eng/pipelines/common/restore-internal-tools.yml + - ${{ if ne(variables['System.TeamProject'], 'public') }}: + - ${{ if and(ne(parameters.osGroup, 'windows'), ne(parameters.hostedOs, 'windows')) }}: + - task: Bash@3 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.sh + arguments: $(Build.SourcesDirectory)/NuGet.config $Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ else }}: + - task: PowerShell@2 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.ps1 + arguments: -ConfigFile $(Build.SourcesDirectory)/NuGet.config -Password $Env:Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ each monoCrossAOTTargetOS in parameters.monoCrossAOTTargetOS }}: - task: DownloadPipelineArtifact@2 displayName: Download ${{monoCrossAOTTargetOS}} AOT offset files diff --git a/eng/pipelines/common/templates/disable-vsupdate-or-failfast.yml b/eng/pipelines/common/templates/disable-vsupdate-or-failfast.yml new file mode 100644 index 00000000000000..7b9eab0bafdb50 --- /dev/null +++ b/eng/pipelines/common/templates/disable-vsupdate-or-failfast.yml @@ -0,0 +1,38 @@ +# This script tries to disable VSIXAutoUpdate. In case an update is seen as already running, +# it will exit with an error. +steps: + - powershell: | + schtasks /change /tn "\Microsoft\VisualStudio\VSIX Auto Update" /disable + + $vswhere = "C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe" + if (-not (Test-Path -Path "$vswhere" -PathType Leaf)) + { + Write-Error "Couldn't locate vswhere at $vswhere" + exit 1 + } + + $vsdir = &"$vswhere" -latest -prerelease -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath + $vsregedit = "$vsdir\Common7\IDE\VsRegEdit.exe" + + if (-not (Test-Path -Path "$vsregedit" )) + { + Write-Error "VSWhere returned path: $vsdir, but regedit $vsregedit doesn't exist." + exit 1 + } + + Write-Output "VSWhere returned path: $vsdir, using regedit $vsregedit" + Write-Output "Disabling updates through VS Registry:" + + &"$vsdir\Common7\IDE\VsRegEdit.exe" set local HKCU ExtensionManager AutomaticallyCheckForUpdates2Override dword 0 + &"$vsdir\Common7\IDE\VsRegEdit.exe" read local HKCU ExtensionManager AutomaticallyCheckForUpdates2Override dword + + $processes = Get-Process -Name VSIXAutoUpdate -ErrorAction SilentlyContinue + + if ($processes -ne $null -and $processes.Count -gt 0) + { + Write-Error "VSIXAutoUpdate has already spawned. Failfast to allow retry" + exit 1 + } + + displayName: Disable VSIX updates or fail-fast + condition: always() diff --git a/eng/pipelines/common/templates/pipeline-with-resources.yml b/eng/pipelines/common/templates/pipeline-with-resources.yml index c30ce8597808d8..02242394fba671 100644 --- a/eng/pipelines/common/templates/pipeline-with-resources.yml +++ b/eng/pipelines/common/templates/pipeline-with-resources.yml @@ -85,12 +85,12 @@ resources: image: mcr.microsoft.com/dotnet-buildtools/prereqs:centos-stream8 - container: browser_wasm - image: mcr.microsoft.com/dotnet-buildtools/prereqs:cbl-mariner-2.0-webassembly + image: mcr.microsoft.com/dotnet-buildtools/prereqs:cbl-mariner-2.0-webassembly-20230913040940-1edc1c6 env: ROOTFS_DIR: /crossrootfs/x64 - container: wasi_wasm - image: mcr.microsoft.com/dotnet-buildtools/prereqs:cbl-mariner-2.0-webassembly + image: mcr.microsoft.com/dotnet-buildtools/prereqs:cbl-mariner-2.0-webassembly-20230913040940-1edc1c6 env: ROOTFS_DIR: /crossrootfs/x64 diff --git a/eng/pipelines/common/templates/runtimes/xplat-job.yml b/eng/pipelines/common/templates/runtimes/xplat-job.yml index 7249125648cf32..f4ac7e82957123 100644 --- a/eng/pipelines/common/templates/runtimes/xplat-job.yml +++ b/eng/pipelines/common/templates/runtimes/xplat-job.yml @@ -106,6 +106,9 @@ jobs: - ${{insert}}: ${{ variable }} steps: + - ${{ if eq(parameters.osGroup, 'windows') }}: + - template: /eng/pipelines/common/templates/disable-vsupdate-or-failfast.yml + - checkout: self clean: true fetchDepth: $(checkoutFetchDepth) diff --git a/eng/pipelines/common/xplat-setup.yml b/eng/pipelines/common/xplat-setup.yml index 28257b05265ba0..eb19570aeecac2 100644 --- a/eng/pipelines/common/xplat-setup.yml +++ b/eng/pipelines/common/xplat-setup.yml @@ -108,7 +108,7 @@ jobs: - ${{ if eq(parameters.archType, 'wasm') }}: - name: wasmDarcDependenciesChanged value: $[ or( - eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_NET_Workload_Emscripten_Current_Manifest-8_0_100_Transport'], true), + eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_NET_Workload_Emscripten_Current_Manifest-8_0_100'], true), eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_DotNet_Build_Tasks_Workloads'], true), eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.System_Runtime_TimeZoneData'], true), eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_Net_Compilers_Toolset'], true), diff --git a/eng/pipelines/coreclr/perf.yml b/eng/pipelines/coreclr/perf.yml index edaadde3e511ee..65d29662504364 100644 --- a/eng/pipelines/coreclr/perf.yml +++ b/eng/pipelines/coreclr/perf.yml @@ -3,7 +3,7 @@ trigger: branches: include: - main - - release/8.0-rc1 + - release/8.0 paths: include: - '*' diff --git a/eng/pipelines/coreclr/templates/build-job.yml b/eng/pipelines/coreclr/templates/build-job.yml index 99379f80a5d9ce..365c1432aa41a1 100644 --- a/eng/pipelines/coreclr/templates/build-job.yml +++ b/eng/pipelines/coreclr/templates/build-job.yml @@ -79,6 +79,8 @@ jobs: # Variables used by arcade to gather asset manifests - name: _DotNetPublishToBlobFeed value: true + - ${{ if eq(variables['System.TeamProject'], 'internal') }}: + - group: AzureDevOps-Artifact-Feeds-Pats - name: officialBuildIdArg value: '' - ${{ if eq(parameters.isOfficialBuild, true) }}: @@ -162,6 +164,24 @@ jobs: continueOnError: false condition: and(succeeded(), in(variables['SignType'], 'real', 'test')) + - ${{ if ne(variables['System.TeamProject'], 'public') }}: + - ${{ if ne(parameters.osGroup, 'windows') }}: + - task: Bash@3 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.sh + arguments: $(Build.SourcesDirectory)/NuGet.config $Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ if eq(parameters.osGroup, 'windows') }}: + - task: PowerShell@2 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.ps1 + arguments: -ConfigFile $(Build.SourcesDirectory)/NuGet.config -Password $Env:Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ if in(parameters.osGroup, 'osx', 'ios', 'tvos') }}: - script: | du -sh $(Build.SourcesDirectory)/* diff --git a/eng/pipelines/extra-platforms/runtime-extra-platforms-other.yml b/eng/pipelines/extra-platforms/runtime-extra-platforms-other.yml index 13d3352393fa7f..c279c318e34d59 100644 --- a/eng/pipelines/extra-platforms/runtime-extra-platforms-other.yml +++ b/eng/pipelines/extra-platforms/runtime-extra-platforms-other.yml @@ -193,6 +193,44 @@ jobs: eq(dependencies.evaluate_paths.outputs['SetPathVars_coreclr.containsChange'], true), eq(variables['isRollingBuild'], true)) +# +# CoreCLR NativeAOT checked build and Pri0 tests +# Only when CoreCLR is changed +# +- template: /eng/pipelines/common/platform-matrix.yml + parameters: + jobTemplate: /eng/pipelines/common/global-build-job.yml + helixQueuesTemplate: /eng/pipelines/coreclr/templates/helix-queues-setup.yml + buildConfig: Checked + platforms: + - windows_x64 + - linux_x64 + variables: + - name: timeoutPerTestInMinutes + value: 60 + - name: timeoutPerTestCollectionInMinutes + value: 180 + jobParameters: + timeoutInMinutes: 240 + nameSuffix: NativeAOT_Pri0 + buildArgs: -s clr.aot+host.native+libs -rc $(_BuildConfig) -lc Release -hc Release + extraStepsTemplate: /eng/pipelines/coreclr/nativeaot-post-build-steps.yml + extraStepsParameters: + creator: dotnet-bot + testBuildArgs: 'nativeaot /p:IlcUseServerGc=false' + liveLibrariesBuildConfig: Release + testRunNamePrefixSuffix: NativeAOT_Pri0_$(_BuildConfig) + extraVariablesTemplates: + - template: /eng/pipelines/common/templates/runtimes/test-variables.yml + parameters: + testGroup: innerloop + liveLibrariesBuildConfig: Release + condition: >- + or( + eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), + eq(dependencies.evaluate_paths.outputs['SetPathVars_coreclr.containsChange'], true), + eq(variables['isRollingBuild'], true)) + # Run net48 tests on win-x64 - template: /eng/pipelines/common/platform-matrix.yml parameters: diff --git a/eng/pipelines/installer/jobs/build-job.yml b/eng/pipelines/installer/jobs/build-job.yml index 5a0e37157e45c5..1d89cfad70eb20 100644 --- a/eng/pipelines/installer/jobs/build-job.yml +++ b/eng/pipelines/installer/jobs/build-job.yml @@ -293,9 +293,30 @@ jobs: parameters.archType, parameters.liveLibrariesBuildConfig) }} steps: + - ${{ if eq(parameters.osGroup, 'windows') }}: + - template: /eng/pipelines/common/templates/disable-vsupdate-or-failfast.yml - checkout: self clean: true fetchDepth: $(checkoutFetchDepth) + + - ${{ if ne(variables['System.TeamProject'], 'public') }}: + - ${{ if ne(parameters.osGroup, 'windows') }}: + - task: Bash@3 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.sh + arguments: $(Build.SourcesDirectory)/NuGet.config $Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ else }}: + - task: PowerShell@2 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.ps1 + arguments: -ConfigFile $(Build.SourcesDirectory)/NuGet.config -Password $Env:Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ if ne(parameters.liveRuntimeBuildConfig, '') }}: - template: /eng/pipelines/common/download-artifact-step.yml parameters: diff --git a/eng/pipelines/libraries/base-job.yml b/eng/pipelines/libraries/base-job.yml index 9dea30f61c455d..2448124a7bc62d 100644 --- a/eng/pipelines/libraries/base-job.yml +++ b/eng/pipelines/libraries/base-job.yml @@ -48,6 +48,7 @@ jobs: variables: - ${{ if eq(variables['System.TeamProject'], 'internal') }}: - group: DotNet-HelixApi-Access + - group: AzureDevOps-Artifact-Feeds-Pats - _buildScriptFileName: build @@ -136,4 +137,22 @@ jobs: artifactName: '$(_runtimeArtifactName)' displayName: '$(runtimeFlavorName) build drop' + - ${{ if ne(variables['System.TeamProject'], 'public') }}: + - ${{ if ne(parameters.osGroup, 'windows') }}: + - task: Bash@3 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.sh + arguments: $(Build.SourcesDirectory)/NuGet.config $Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ if eq(parameters.osGroup, 'windows') }}: + - task: PowerShell@2 + displayName: Setup Private Feeds Credentials + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.ps1 + arguments: -ConfigFile $(Build.SourcesDirectory)/NuGet.config -Password $Env:Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + - ${{ parameters.steps }} diff --git a/eng/pipelines/libraries/helix-queues-setup.yml b/eng/pipelines/libraries/helix-queues-setup.yml index 72d8d53cd94ddd..987d7f99c41f4a 100644 --- a/eng/pipelines/libraries/helix-queues-setup.yml +++ b/eng/pipelines/libraries/helix-queues-setup.yml @@ -62,13 +62,13 @@ jobs: - ${{ if and(eq(parameters.jobParameters.testScope, 'outerloop'), eq(parameters.jobParameters.runtimeFlavor, 'mono')) }}: - SLES.15.Amd64.Open - (Centos.8.Amd64.Open)Ubuntu.1804.Amd64.Open@mcr.microsoft.com/dotnet-buildtools/prereqs:centos-stream8-helix - - (Fedora.36.Amd64.Open)Ubuntu.1804.Amd64.Open@mcr.microsoft.com/dotnet-buildtools/prereqs:fedora-36-helix + - (Fedora.38.Amd64.Open)Ubuntu.1804.Amd64.Open@mcr.microsoft.com/dotnet-buildtools/prereqs:fedora-38-helix - (Ubuntu.2204.Amd64.Open)Ubuntu.1804.Amd64.Open@mcr.microsoft.com/dotnet-buildtools/prereqs:ubuntu-22.04-helix-amd64 - (Debian.11.Amd64.Open)Ubuntu.2204.Amd64.Open@mcr.microsoft.com/dotnet-buildtools/prereqs:debian-11-helix-amd64 - ${{ if or(ne(parameters.jobParameters.testScope, 'outerloop'), ne(parameters.jobParameters.runtimeFlavor, 'mono')) }}: - ${{ if or(eq(parameters.jobParameters.isExtraPlatforms, true), eq(parameters.jobParameters.includeAllPlatforms, true)) }}: - SLES.15.Amd64.Open - - (Fedora.36.Amd64.Open)ubuntu.1804.amd64.open@mcr.microsoft.com/dotnet-buildtools/prereqs:fedora-36-helix + - (Fedora.38.Amd64.Open)ubuntu.1804.amd64.open@mcr.microsoft.com/dotnet-buildtools/prereqs:fedora-38-helix - Ubuntu.2204.Amd64.Open - (Debian.11.Amd64.Open)Ubuntu.1804.Amd64.Open@mcr.microsoft.com/dotnet-buildtools/prereqs:debian-11-helix-amd64 - (Mariner.2.0.Amd64.Open)ubuntu.1804.amd64.open@mcr.microsoft.com/dotnet-buildtools/prereqs:cbl-mariner-2.0-helix-amd64 diff --git a/eng/pipelines/libraries/stress/http.yml b/eng/pipelines/libraries/stress/http.yml index 6c740e49d04d47..f4f9c45de36e48 100644 --- a/eng/pipelines/libraries/stress/http.yml +++ b/eng/pipelines/libraries/stress/http.yml @@ -13,6 +13,7 @@ schedules: - main - release/6.0 - release/7.0 + - release/8.0 variables: - template: ../variables.yml diff --git a/eng/pipelines/libraries/stress/ssl.yml b/eng/pipelines/libraries/stress/ssl.yml index 791251030f5753..ab93994400d346 100644 --- a/eng/pipelines/libraries/stress/ssl.yml +++ b/eng/pipelines/libraries/stress/ssl.yml @@ -13,6 +13,7 @@ schedules: - main - release/6.0 - release/7.0 + - release/8.0 variables: - template: ../variables.yml diff --git a/eng/pipelines/runtime-llvm.yml b/eng/pipelines/runtime-llvm.yml index e31e623a0353c8..9d358e5f793086 100644 --- a/eng/pipelines/runtime-llvm.yml +++ b/eng/pipelines/runtime-llvm.yml @@ -119,7 +119,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), @@ -138,7 +138,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), diff --git a/eng/pipelines/runtime-official.yml b/eng/pipelines/runtime-official.yml index 172a40e24d169f..3a9fd8d89ac4b0 100644 --- a/eng/pipelines/runtime-official.yml +++ b/eng/pipelines/runtime-official.yml @@ -41,13 +41,13 @@ extends: # Localization build # - # disabled due to https://github.com/dotnet/runtime/issues/90466 - #- ${{ if eq(variables['Build.SourceBranch'], 'refs/heads/main') }}: - # - template: /eng/common/templates/job/onelocbuild.yml - # parameters: - # MirrorRepo: runtime - # LclSource: lclFilesfromPackage - # LclPackageId: 'LCL-JUNO-PROD-RUNTIME' + - ${{ if eq(variables['Build.SourceBranch'], 'refs/heads/release/8.0') }}: + - template: /eng/common/templates/job/onelocbuild.yml + parameters: + MirrorRepo: runtime + MirrorBranch: release/8.0 + LclSource: lclFilesfromPackage + LclPackageId: 'LCL-JUNO-PROD-RUNTIME' # # Source Index Build @@ -334,7 +334,7 @@ extends: runtimeFlavor: mono jobParameters: buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true nameSuffix: AllSubsets_Mono_LLVMAOT runtimeVariant: LLVMAOT isOfficialBuild: ${{ variables.isOfficialBuild }} diff --git a/eng/pipelines/runtime-wasm-perf.yml b/eng/pipelines/runtime-wasm-perf.yml index bd6a6d979e3e40..69039fb3e2a473 100644 --- a/eng/pipelines/runtime-wasm-perf.yml +++ b/eng/pipelines/runtime-wasm-perf.yml @@ -3,6 +3,7 @@ # UI to this, and thus avoid any scheduled triggers trigger: none +pr: none variables: - template: /eng/pipelines/common/variables.yml diff --git a/eng/pipelines/runtime.yml b/eng/pipelines/runtime.yml index 0f1f9610c60349..3aa0b6504819a7 100644 --- a/eng/pipelines/runtime.yml +++ b/eng/pipelines/runtime.yml @@ -556,6 +556,47 @@ extends: extraBuildArgs: /p:AotHostArchitecture=x64 /p:AotHostOS=$(_hostedOS) alwaysRun: ${{ variables.isRollingBuild }} + # + # Android devices + # Build the whole product using Mono and run libraries tests + # + - template: /eng/pipelines/common/platform-matrix.yml + parameters: + jobTemplate: /eng/pipelines/common/global-build-job.yml + helixQueuesTemplate: /eng/pipelines/libraries/helix-queues-setup.yml + buildConfig: Release + runtimeFlavor: mono + platforms: + - android_arm + - android_arm64 + variables: + # map dependencies variables to local variables + - name: librariesContainsChange + value: $[ dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'] ] + - name: monoContainsChange + value: $[ dependencies.evaluate_paths.outputs['SetPathVars_mono_excluding_wasm.containsChange'] ] + jobParameters: + testGroup: innerloop + nameSuffix: AllSubsets_Mono + buildArgs: -s mono+libs+libs.tests+host+packs -c $(_BuildConfig) /p:ArchiveTests=true /p:RunSmokeTestsOnly=true /p:EnableAdditionalTimezoneChecks=true + timeoutInMinutes: 480 + condition: >- + or( + eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), + eq(dependencies.evaluate_paths.outputs['SetPathVars_mono_excluding_wasm.containsChange'], true), + eq(dependencies.evaluate_paths.outputs['SetPathVars_installer.containsChange'], true), + eq(variables['isRollingBuild'], true)) + # extra steps, run tests + extraStepsTemplate: /eng/pipelines/libraries/helix.yml + extraStepsParameters: + creator: dotnet-bot + testRunNamePrefixSuffix: Mono_$(_BuildConfig) + condition: >- + or( + eq(variables['librariesContainsChange'], true), + eq(variables['monoContainsChange'], true), + eq(variables['isRollingBuild'], true)) + # # iOS/tvOS devices - Full AOT + AggressiveTrimming to reduce size # Build the whole product using Mono and run libraries tests @@ -739,7 +780,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), @@ -758,7 +799,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), @@ -1277,7 +1318,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAot_RuntimeTests runtimeVariant: llvmaot - buildArgs: -s mono+libs+clr.hosts+clr.iltools -c Release /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + buildArgs: -s mono+libs+clr.hosts+clr.iltools -c Release /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true timeoutInMinutes: 180 condition: >- diff --git a/eng/resolveContract.targets b/eng/resolveContract.targets index 6d414f46f93e6b..3454d7064739a8 100644 --- a/eng/resolveContract.targets +++ b/eng/resolveContract.targets @@ -73,8 +73,9 @@ That is necessary as APICompat is invoked twice, once for the ref <-> src comparision and then again for the package validation (which doesn't include reference assemblies). As both operations don't have all the inputs available, some suppressions might only apply to one or the other and hence unnecessary - suppressions can't be determined. --> - + suppressions can't be determined. + Disable the validation under source build as that might use an out-of-date SDK and not the ApiCompat.Task package. --> + true true diff --git a/eng/testing/ProvisioningVersions.props b/eng/testing/ProvisioningVersions.props index 251fa2d85d3165..3105078bdc3b59 100644 --- a/eng/testing/ProvisioningVersions.props +++ b/eng/testing/ProvisioningVersions.props @@ -44,20 +44,20 @@ - + false true - 113.0.5672.63 - 1121455 - <_ChromeBaseSnapshotUrl>https://storage.googleapis.com/chromium-browser-snapshots/Linux_x64/1121461 + 115.0.5790.170 + 1148114 + <_ChromeBaseSnapshotUrl>https://storage.googleapis.com/chromium-browser-snapshots/Linux_x64/1148123 - 113.0.5672.64 - 1121455 - <_ChromeBaseSnapshotUrl>https://storage.googleapis.com/chromium-browser-snapshots/Win_x64/1121477 + 115.0.5790.171 + 1148114 + <_ChromeBaseSnapshotUrl>https://storage.googleapis.com/chromium-browser-snapshots/Win_x64/1148119 diff --git a/eng/testing/performance/performance-setup.ps1 b/eng/testing/performance/performance-setup.ps1 index 8caea345a893dc..8a8cd269dbe454 100644 --- a/eng/testing/performance/performance-setup.ps1 +++ b/eng/testing/performance/performance-setup.ps1 @@ -101,7 +101,7 @@ if ($iOSNativeAOT) { } # FIX ME: This is a workaround until we get this from the actual pipeline -$CleanedBranchName = "main" +$CleanedBranchName = "release/8.0" if($Branch.Contains("refs/heads/release")) { $CleanedBranchName = $Branch.replace('refs/heads/', '') diff --git a/eng/testing/performance/performance-setup.sh b/eng/testing/performance/performance-setup.sh index 9a1c95ec730820..c53ca6924b97b4 100755 --- a/eng/testing/performance/performance-setup.sh +++ b/eng/testing/performance/performance-setup.sh @@ -358,9 +358,7 @@ if [[ "$physicalpromotion" == "true" ]]; then configurations="$configurations PhysicalPromotionType=physicalpromotion" fi - - -cleaned_branch_name="main" +cleaned_branch_name="release/8.0" if [[ $branch == *"refs/heads/release"* ]]; then cleaned_branch_name=${branch/refs\/heads\//} fi @@ -404,15 +402,14 @@ if [[ -n "$wasm_bundle_directory" ]]; then using_wasm=true wasm_bundle_directory_path=$payload_directory mv $wasm_bundle_directory/* $wasm_bundle_directory_path - find $wasm_bundle_directory_path -type d - wasm_args="--experimental-wasm-eh --expose_wasm" + wasm_args="--expose_wasm" if [ "$javascript_engine" == "v8" ]; then # for es6 module support wasm_args="$wasm_args --module" fi # Workaround: escaping the quotes around `--wasmArgs=..` so they get retained for the actual command line - extra_benchmark_dotnet_arguments="$extra_benchmark_dotnet_arguments --wasmEngine /home/helixbot/.jsvu/bin/$javascript_engine --wasmArgs \\\"$wasm_args\\\" --cli \$HELIX_CORRELATION_PAYLOAD/dotnet/dotnet --wasmDataDir \$HELIX_CORRELATION_PAYLOAD/wasm-data" + extra_benchmark_dotnet_arguments="$extra_benchmark_dotnet_arguments --wasmEngine /home/helixbot/.jsvu/bin/$javascript_engine \\\"--wasmArgs=$wasm_args\\\" --cli \$HELIX_CORRELATION_PAYLOAD/dotnet/dotnet --wasmDataDir \$HELIX_CORRELATION_PAYLOAD/wasm-data" if [[ "$wasmaot" == "true" ]]; then extra_benchmark_dotnet_arguments="$extra_benchmark_dotnet_arguments --aotcompilermode wasm --buildTimeout 3600" fi diff --git a/eng/testing/tests.ioslike.targets b/eng/testing/tests.ioslike.targets index 9151c7c7db05e5..f93afcf1dfb31b 100644 --- a/eng/testing/tests.ioslike.targets +++ b/eng/testing/tests.ioslike.targets @@ -15,8 +15,8 @@ <_AOTBuildCommand Condition="'$(ContinuousIntegrationBuild)' != 'true'">$(_AOTBuildCommand) /p:RuntimeSrcDir=$(RepoRoot) /p:RuntimeConfig=$(Configuration) - - <_AOTBuildCommand>$(_AOTBuildCommand) /p:XHARNESS_EXECUTION_DIR="$XHARNESS_EXECUTION_DIR" /p:RunAOTCompilation=$(RunAOTCompilation) /p:UseNativeAOTRuntime=$(UseNativeAOTRuntime) /p:TargetOS=$(TargetOS) /p:TargetArchitecture=$(TargetArchitecture) /p:MonoForceInterpreter=$(MonoForceInterpreter) /p:DevTeamProvisioning=$(DevTeamProvisioning) /p:UsePortableRuntimePack=true /p:Configuration=$(Configuration) + + <_AOTBuildCommand>$(_AOTBuildCommand) /p:XHARNESS_EXECUTION_DIR="$XHARNESS_EXECUTION_DIR" /p:RunAOTCompilation=$(RunAOTCompilation) /p:UseNativeAOTRuntime=$(UseNativeAOTRuntime) /p:TargetOS=$(TargetOS) /p:TargetArchitecture=$(TargetArchitecture) /p:MonoForceInterpreter=$(MonoForceInterpreter) /p:MonoEnableLLVM=true /p:DevTeamProvisioning=$(DevTeamProvisioning) /p:UsePortableRuntimePack=true /p:Configuration=$(Configuration) <_AOTBuildCommand>$(_AOTBuildCommand) <_ResetSimulatorSwitch Condition="'$(TargetOS)' == 'iossimulator' or '$(TargetOS)' == 'tvossimulator'">--reset-simulator diff --git a/eng/testing/workloads-testing.targets b/eng/testing/workloads-testing.targets index 2961313c84973d..df5526e3f11583 100644 --- a/eng/testing/workloads-testing.targets +++ b/eng/testing/workloads-testing.targets @@ -76,6 +76,7 @@ Command="chmod +x $(_DotNetInstallScriptPath); $(_DotNetInstallCommand)" /> diff --git a/global.json b/global.json index b4ca83d356c283..38ec23b6193a2f 100644 --- a/global.json +++ b/global.json @@ -1,16 +1,16 @@ { "sdk": { - "version": "8.0.100-preview.7.23376.3", + "version": "8.0.100-rtm.23506.1", "allowPrerelease": true, "rollForward": "major" }, "tools": { - "dotnet": "8.0.100-preview.7.23376.3" + "dotnet": "8.0.100-rtm.23506.1" }, "msbuild-sdks": { - "Microsoft.DotNet.Arcade.Sdk": "8.0.0-beta.23411.1", - "Microsoft.DotNet.Helix.Sdk": "8.0.0-beta.23411.1", - "Microsoft.DotNet.SharedFramework.Sdk": "8.0.0-beta.23411.1", + "Microsoft.DotNet.Arcade.Sdk": "8.0.0-beta.23516.4", + "Microsoft.DotNet.Helix.Sdk": "8.0.0-beta.23516.4", + "Microsoft.DotNet.SharedFramework.Sdk": "8.0.0-beta.23516.4", "Microsoft.Build.NoTargets": "3.7.0", "Microsoft.Build.Traversal": "3.4.0", "Microsoft.NET.Sdk.IL": "8.0.0-rc.1.23406.6" diff --git a/src/coreclr/System.Private.CoreLib/src/System/GC.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/GC.CoreCLR.cs index dbbb6758593b0c..e3e091bb872a2b 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/GC.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/GC.CoreCLR.cs @@ -899,13 +899,10 @@ internal enum RefreshMemoryStatus /// /// This API will only handle configs that could be handled when the runtime is loaded, for example, for configs that don't have any effects on 32-bit systems (like the GCHeapHardLimit* ones), this API will not handle it. /// - /// As of now, this API is feature preview only and subject to changes as necessary. - /// /// If the hard limit is too low. This can happen if the heap hard limit that the refresh will set, either because of new AppData settings or implied by the container memory limit changes, is lower than what is already committed. /// If the hard limit is invalid. This can happen, for example, with negative heap hard limit percentages. /// /// - [RequiresPreviewFeatures("RefreshMemoryLimit is in preview.")] public static void RefreshMemoryLimit() { ulong heapHardLimit = (AppContext.GetData("GCHeapHardLimit") as ulong?) ?? ulong.MaxValue; diff --git a/src/coreclr/debug/createdump/crashinfo.cpp b/src/coreclr/debug/createdump/crashinfo.cpp index ef903767ba0279..8af6ec4a54f5bd 100644 --- a/src/coreclr/debug/createdump/crashinfo.cpp +++ b/src/coreclr/debug/createdump/crashinfo.cpp @@ -195,7 +195,7 @@ CrashInfo::GatherCrashInfo(DumpType dumpType) return false; } // Add the special (fake) memory region for the special diagnostics info - MemoryRegion special(PF_R, SpecialDiagInfoAddress, SpecialDiagInfoAddress + PAGE_SIZE); + MemoryRegion special(PF_R, SpecialDiagInfoAddress, SpecialDiagInfoAddress + SpecialDiagInfoSize); m_memoryRegions.insert(special); #ifdef __APPLE__ InitializeOtherMappings(); diff --git a/src/coreclr/debug/createdump/specialdiaginfo.h b/src/coreclr/debug/createdump/specialdiaginfo.h index 3a04a9f551e6d7..a857129c9c91ff 100644 --- a/src/coreclr/debug/createdump/specialdiaginfo.h +++ b/src/coreclr/debug/createdump/specialdiaginfo.h @@ -24,6 +24,8 @@ const uint64_t SpecialDiagInfoAddress = 0x7fff1000; #endif #endif +const uint64_t SpecialDiagInfoSize = 0x1000; + struct SpecialDiagInfoHeader { char Signature[16]; diff --git a/src/coreclr/debug/daccess/dacdbiimpl.cpp b/src/coreclr/debug/daccess/dacdbiimpl.cpp index 67d5b1e60d948c..07208001b0c3b8 100644 --- a/src/coreclr/debug/daccess/dacdbiimpl.cpp +++ b/src/coreclr/debug/daccess/dacdbiimpl.cpp @@ -7788,8 +7788,9 @@ HRESULT DacStackReferenceWalker::Next(ULONG count, DacGcReference stackRefs[], U stackRefs[i].i64ExtraData = 0; const SOSStackRefData &sosStackRef = mList.Get(i); - if (sosStackRef.Flags & GC_CALL_INTERIOR) + if (sosStackRef.Flags & GC_CALL_INTERIOR || sosStackRef.Address == 0) { + // Direct pointer case - interior pointer, Frame ref, or enregistered var. stackRefs[i].pObject = CLRDATA_ADDRESS_TO_TADDR(sosStackRef.Object) | 1; } else diff --git a/src/coreclr/debug/daccess/request.cpp b/src/coreclr/debug/daccess/request.cpp index 868593fae4651e..e5cc22d8c708a5 100644 --- a/src/coreclr/debug/daccess/request.cpp +++ b/src/coreclr/debug/daccess/request.cpp @@ -135,11 +135,17 @@ BOOL DacValidateEEClass(PTR_EEClass pEEClass) BOOL DacValidateMethodTable(PTR_MethodTable pMT, BOOL &bIsFree) { + bIsFree = FALSE; + + if ((pMT == NULL) || dac_cast(pMT) == (TADDR)-1) + { + return FALSE; + } + // Verify things are right. BOOL retval = FALSE; EX_TRY { - bIsFree = FALSE; if (HOST_CDADDR(pMT) == HOST_CDADDR(g_pFreeObjectMethodTable)) { bIsFree = TRUE; @@ -182,7 +188,7 @@ BadMethodTable: ; BOOL DacValidateMD(PTR_MethodDesc pMD) { - if (pMD == NULL) + if ((pMD == NULL) || dac_cast(pMD) == (TADDR)-1) { return FALSE; } @@ -2642,8 +2648,7 @@ ClrDataAccess::GetAssemblyLocation(CLRDATA_ADDRESS assembly, int count, _Inout_u // Turn from bytes to wide characters if (!pAssembly->GetPEAssembly()->GetPath().IsEmpty()) { - if (!pAssembly->GetPEAssembly()->GetPath(). - DacGetUnicode(count, location, pNeeded)) + if (!pAssembly->GetPEAssembly()->GetPath().DacGetUnicode(count, location, pNeeded)) { hr = E_FAIL; } diff --git a/src/coreclr/debug/daccess/stack.cpp b/src/coreclr/debug/daccess/stack.cpp index 9402d529eb8ea3..6b9f1a491c291c 100644 --- a/src/coreclr/debug/daccess/stack.cpp +++ b/src/coreclr/debug/daccess/stack.cpp @@ -1253,14 +1253,19 @@ ClrDataFrame::GetLocalSig(MetaSig** sig, { // It turns out we cannot really get rid of this check. Dynamic methods // (including IL stubs) do not have their local sig's available after JIT time. - if (!m_methodDesc->IsIL()) + // IL methods with dynamically generated IL (for example, UnsafeAccessors) may + // not have an IL header. + COR_ILMETHOD* ilHeader = m_methodDesc->IsIL() + ? m_methodDesc->GetILHeader() + : NULL; + if (ilHeader == NULL) { *sig = NULL; *count = 0; return E_FAIL; } - COR_ILMETHOD_DECODER methodDecoder(m_methodDesc->GetILHeader()); + COR_ILMETHOD_DECODER methodDecoder(ilHeader); mdSignature localSig = methodDecoder.GetLocalVarSigTok() ? methodDecoder.GetLocalVarSigTok() : mdSignatureNil; if (localSig == mdSignatureNil) diff --git a/src/coreclr/debug/di/process.cpp b/src/coreclr/debug/di/process.cpp index eb0f4ad5f1c262..db8f2a4badd67f 100644 --- a/src/coreclr/debug/di/process.cpp +++ b/src/coreclr/debug/di/process.cpp @@ -180,7 +180,11 @@ STDAPI DLLEXPORT OpenVirtualProcessImpl2( IUnknown ** ppInstance, CLR_DEBUGGING_PROCESS_FLAGS* pFlagsOut) { +#ifdef TARGET_WINDOWS + HMODULE hDac = LoadLibraryExW(pDacModulePath, NULL, LOAD_WITH_ALTERED_SEARCH_PATH); +#else HMODULE hDac = LoadLibraryW(pDacModulePath); +#endif // !TARGET_WINDOWS if (hDac == NULL) { return HRESULT_FROM_WIN32(GetLastError()); diff --git a/src/coreclr/debug/di/rsclass.cpp b/src/coreclr/debug/di/rsclass.cpp index ec52823c07af5f..55f83b48a6d211 100644 --- a/src/coreclr/debug/di/rsclass.cpp +++ b/src/coreclr/debug/di/rsclass.cpp @@ -132,6 +132,7 @@ HRESULT CordbClass::GetStaticFieldValue(mdFieldDef fieldDef, IMetaDataImport * pImport = NULL; EX_TRY { + RSLockHolder lockHolder(GetProcess()->GetProcessLock()); pImport = GetModule()->GetMetaDataImporter(); // throws // Validate the token. @@ -1191,4 +1192,3 @@ HRESULT CordbClass::SearchFieldInfo( // Well, the field doesn't even belong to this class... ThrowHR(E_INVALIDARG); } - diff --git a/src/coreclr/debug/ee/functioninfo.cpp b/src/coreclr/debug/ee/functioninfo.cpp index 19910c6429a9c6..6eaa02d2c6de6f 100644 --- a/src/coreclr/debug/ee/functioninfo.cpp +++ b/src/coreclr/debug/ee/functioninfo.cpp @@ -1575,7 +1575,11 @@ DebuggerJitInfo *DebuggerMethodInfo::FindOrCreateInitAndAddJitInfo(MethodDesc* f if (startAddr == NULL) { startAddr = g_pEEInterface->GetFunctionAddress(fd); - _ASSERTE(startAddr != NULL); + if (startAddr == NULL) + { + //The only case this should happen is if we are trying to get the DJI for a method that has not been jitted yet. + return NULL; + } } else { diff --git a/src/coreclr/gc/gc.cpp b/src/coreclr/gc/gc.cpp index 02a9b8f26c2f56..7351954070725e 100644 --- a/src/coreclr/gc/gc.cpp +++ b/src/coreclr/gc/gc.cpp @@ -823,6 +823,11 @@ class t_join join_struct.r_join_lock = n_th; } + int get_num_threads() + { + return join_struct.n_threads; + } + void destroy () { dprintf (JOIN_LOG, ("Destroying join structure")); @@ -887,6 +892,8 @@ class t_join // avoid race due to the thread about to reset the event (occasionally) being preempted before ResetEvent() if (color == join_struct.lock_color.LoadWithoutBarrier()) { + dprintf (9999, ("---h%d %d j%d %d - respin!!! (c:%d-%d)", + gch->heap_number, join_id, join_struct.n_threads, color, join_struct.lock_color.LoadWithoutBarrier())); goto respin; } @@ -1117,6 +1124,25 @@ t_join bgc_t_join; } \ } +#define spin_and_wait(count_to_spin, expr) \ +{ \ + while (!expr) \ + { \ + for (int j = 0; j < count_to_spin; j++) \ + { \ + if (expr) \ + { \ + break; \ + } \ + YieldProcessor (); \ + } \ + if (!(expr)) \ + { \ + GCToOSInterface::YieldThread (0); \ + } \ + } \ +} + #ifdef BACKGROUND_GC #define max_pending_allocs 64 @@ -1429,8 +1455,6 @@ enter_msl_status gc_heap::enter_spin_lock_msl_helper (GCSpinLock* msl) { #ifdef DYNAMIC_HEAP_COUNT uint64_t start = GetHighPrecisionTimeStamp(); - - msl->msl_wait_count++; #endif //DYNAMIC_HEAP_COUNT unsigned int i = 0; @@ -1485,7 +1509,7 @@ enter_msl_status gc_heap::enter_spin_lock_msl_helper (GCSpinLock* msl) #ifdef DYNAMIC_HEAP_COUNT uint64_t end = GetHighPrecisionTimeStamp(); Interlocked::ExchangeAdd64 (&msl->msl_wait_time, end - start); - dprintf (6666, ("wait for msl lock total time: %zd, total count: %zd, this time: %zd, this count: %u", msl->msl_wait_time, msl->msl_wait_count, end - start, i)); + dprintf (3, ("h%d wait for msl lock wait time %zd, total wait time: %zd", heap_number, (end - start), msl->msl_wait_time)); #endif //DYNAMIC_HEAP_COUNT } while (Interlocked::CompareExchange (&msl->lock, lock_taken, lock_free) != lock_free); @@ -2318,9 +2342,6 @@ sorted_table* gc_heap::seg_table; #ifdef MULTIPLE_HEAPS GCEvent gc_heap::ee_suspend_event; -#ifdef DYNAMIC_HEAP_COUNT -GCEvent gc_heap::gc_idle_thread_event; -#endif //DYNAMIC_HEAP_COUNT size_t gc_heap::min_gen0_balance_delta = 0; size_t gc_heap::min_balance_threshold = 0; #endif //MULTIPLE_HEAPS @@ -2919,6 +2940,12 @@ BOOL gc_heap::should_expand_in_full_gc = FALSE; #ifdef DYNAMIC_HEAP_COUNT int gc_heap::dynamic_adaptation_mode = dynamic_adaptation_default; gc_heap::dynamic_heap_count_data_t SVR::gc_heap::dynamic_heap_count_data; +uint64_t gc_heap::last_suspended_end_time = 0; +size_t gc_heap::gc_index_full_gc_end = 0; + +#ifdef STRESS_DYNAMIC_HEAP_COUNT +int gc_heap::heaps_in_this_gc = 0; +#endif //STRESS_DYNAMIC_HEAP_COUNT #endif // DYNAMIC_HEAP_COUNT // Provisional mode related stuff. @@ -6967,12 +6994,6 @@ BOOL gc_heap::create_thread_support (int number_of_heaps) { goto cleanup; } -#ifdef DYNAMIC_HEAP_COUNT - if (!gc_idle_thread_event.CreateOSManualEventNoThrow (FALSE)) - { - goto cleanup; - } -#endif //DYNAMIC_HEAP_COUNT if (!ee_suspend_event.CreateOSAutoEventNoThrow (FALSE)) { goto cleanup; @@ -7020,10 +7041,6 @@ bool gc_heap::create_gc_thread () return GCToEEInterface::CreateThread(gc_thread_stub, this, false, ".NET Server GC"); } -#ifdef DYNAMIC_HEAP_COUNT -static size_t prev_change_heap_count_gc_index; -#endif //DYNAMIC_HEAP_COUNT - #ifdef _MSC_VER #pragma warning(disable:4715) //IA64 xcompiler recognizes that without the 'break;' the while(1) will never end and therefore not return a value for that code path #endif //_MSC_VER @@ -7042,18 +7059,87 @@ void gc_heap::gc_thread_function () if (heap_number == 0) { - uint32_t wait_result = gc_heap::ee_suspend_event.Wait(gradual_decommit_in_progress_p ? DECOMMIT_TIME_STEP_MILLISECONDS : INFINITE, FALSE); + bool wait_on_time_out_p = gradual_decommit_in_progress_p; + uint32_t wait_time = DECOMMIT_TIME_STEP_MILLISECONDS; +#ifdef DYNAMIC_HEAP_COUNT + // background_running_p can only change from false to true during suspension. + if (!gc_heap::background_running_p () && dynamic_heap_count_data.should_change_heap_count) + { + assert (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes); + + dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[dynamic_heap_count_data.sample_index]; + wait_time = min (wait_time, (uint32_t)(sample.elapsed_between_gcs / 1000 / 3)); + wait_time = max (wait_time, 1); + + dprintf (6666, ("gc#0 thread waiting for %d ms (betwen GCs %I64d)", wait_time, sample.elapsed_between_gcs)); + } +#endif //DYNAMIC_HEAP_COUNT + uint32_t wait_result = gc_heap::ee_suspend_event.Wait(wait_on_time_out_p ? wait_time : INFINITE, FALSE); + dprintf (9999, ("waiting for ee done res %d (timeout %d, %I64d ms since last suspend end)(should_change_heap_count is %d) (gradual_decommit_in_progress_p %d)", + wait_result, wait_time, ((GetHighPrecisionTimeStamp() - last_suspended_end_time) / 1000), + dynamic_heap_count_data.should_change_heap_count, gradual_decommit_in_progress_p)); if (wait_result == WAIT_TIMEOUT) { - decommit_lock.Enter(); - gradual_decommit_in_progress_p = decommit_step (DECOMMIT_TIME_STEP_MILLISECONDS); - decommit_lock.Leave(); +#ifdef DYNAMIC_HEAP_COUNT + if (dynamic_heap_count_data.should_change_heap_count) + { +#ifdef BACKGROUND_GC + if (!gc_heap::background_running_p ()) +#endif //BACKGROUND_GC + { + dprintf (6666, ("changing heap count due to timeout")); + check_heap_count(); + } + } +#endif //DYNAMIC_HEAP_COUNT + + if (gradual_decommit_in_progress_p) + { + decommit_lock.Enter (); + gradual_decommit_in_progress_p = decommit_step (DECOMMIT_TIME_STEP_MILLISECONDS); + decommit_lock.Leave (); + } continue; } +#ifdef DYNAMIC_HEAP_COUNT + // We might want to consider also doing this when a BGC finishes. + if (dynamic_heap_count_data.should_change_heap_count) + { +#ifdef BACKGROUND_GC + if (!gc_heap::background_running_p ()) +#endif //BACKGROUND_GC + { + // this was a request to do a GC so make sure we follow through with one. + dprintf (6666, ("changing heap count at a GC start")); + check_heap_count (); + } + } + + // wait till the threads that should have gone idle at least reached the place where they are about to wait on the idle event. + if ((gc_heap::dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) && + (n_heaps != dynamic_heap_count_data.last_n_heaps)) + { + int spin_count = 1024; + int idle_thread_count = n_max_heaps - n_heaps; + dprintf (9999, ("heap count changed %d->%d, idle should be %d and is %d", dynamic_heap_count_data.last_n_heaps, n_heaps, + idle_thread_count, VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + if (idle_thread_count != dynamic_heap_count_data.idle_thread_count) + { + spin_and_wait (spin_count, (idle_thread_count == dynamic_heap_count_data.idle_thread_count)); + dprintf (9999, ("heap count changed %d->%d, now idle is %d", dynamic_heap_count_data.last_n_heaps, n_heaps, + VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + } + + dynamic_heap_count_data.last_n_heaps = n_heaps; + } +#endif //DYNAMIC_HEAP_COUNT + suspended_start_time = GetHighPrecisionTimeStamp(); BEGIN_TIMING(suspend_ee_during_log); + dprintf (9999, ("h0 suspending EE in GC!")); GCToEEInterface::SuspendEE(SUSPEND_FOR_GC); + dprintf (9999, ("h0 suspended EE in GC!")); END_TIMING(suspend_ee_during_log); proceed_with_gc_p = TRUE; @@ -7067,46 +7153,74 @@ void gc_heap::gc_thread_function () { settings.init_mechanisms(); #ifdef DYNAMIC_HEAP_COUNT - // make sure the other gc threads cannot see this as a request to change heap count - // see explanation below about the cases when we return from gc_start_event.Wait - assert (dynamic_heap_count_data.new_n_heaps == n_heaps); + if (gc_heap::dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) + { + // make sure the other gc threads cannot see this as a request to change heap count + // see explanation below about the cases when we return from gc_start_event.Wait + assert (dynamic_heap_count_data.new_n_heaps == n_heaps); + } #endif //DYNAMIC_HEAP_COUNT + dprintf (9999, ("GC thread %d setting_gc_start_in_gc(h%d)", heap_number, n_heaps)); gc_start_event.Set(); } dprintf (3, (ThreadStressLog::gcServerThread0StartMsg(), heap_number)); } else { + dprintf (9999, ("GC thread %d waiting_for_gc_start(%d)(gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier(&settings.gc_index))); gc_start_event.Wait(INFINITE, FALSE); #ifdef DYNAMIC_HEAP_COUNT - // we have a couple different cases to handle here when we come back from the wait: - // 1. We are starting a GC. Signaled by dynamic_heap_count_data.new_n_heaps == n_heaps - // a) We are starting a GC, but this thread is idle. Signaled by n_heaps <= heap_number - // b) We are starting a GC, and this thread is participating. Signaled by heap_number < n_heaps - // 2. We are changing heap count. Signaled by dynamic_heap_count_data.new_n_heaps != n_heaps - // a) We are changing heap count, but this thread is idle. Signaled by n_heaps <= heap_number. - // b) We are changing heap count, and this thread is participating. Signaled by heap_number < n_heaps. - - // check for 1.a) and 2.a) cases above - if (n_heaps <= heap_number) - { - dprintf (2, ("GC thread %d idle", heap_number)); - - // make sure GC is complete so we know the gc_idle_thread_event has been reset - g_theGCHeap->WaitUntilGCComplete(); + dprintf (9999, ("GC thread %d waiting_done_gc_start(%d-%d)(i: %d)(gc%Id)", + heap_number, n_heaps, dynamic_heap_count_data.new_n_heaps, dynamic_heap_count_data.init_only_p, VolatileLoadWithoutBarrier (&settings.gc_index))); + + if ((gc_heap::dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) && + (dynamic_heap_count_data.new_n_heaps != n_heaps)) + { + // The reason why we need to do this is - + // + for threads that were participating, we need them to do work for change_heap_count + // + for threads that were not participating but will need to participate, we need to make sure they are woken now instead of + // randomly sometime later. + int old_n_heaps = n_heaps; + int new_n_heaps = dynamic_heap_count_data.new_n_heaps; + int num_threads_to_wake = max (new_n_heaps, old_n_heaps); + if (heap_number < num_threads_to_wake) + { + dprintf (9999, ("h%d < %d, calling change", heap_number, num_threads_to_wake)); + change_heap_count (dynamic_heap_count_data.new_n_heaps); + if (new_n_heaps < old_n_heaps) + { + dprintf (9999, ("h%d after change", heap_number)); + // at the end of change_heap_count we've changed join's heap count to the new one if it's smaller. So we need to make sure + // only that many threads will participate in the following GCs. + if (heap_number < new_n_heaps) + { + dprintf (9999, ("h%d < %d participating (dec)", heap_number, new_n_heaps)); + } + else + { + Interlocked::Increment (&dynamic_heap_count_data.idle_thread_count); + dprintf (9999, ("GC thread %d wait_on_idle(%d < %d)(gc%Id), total idle %d", heap_number, old_n_heaps, new_n_heaps, + VolatileLoadWithoutBarrier (&settings.gc_index), VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + gc_idle_thread_event.Wait (INFINITE, FALSE); + dprintf (9999, ("GC thread %d waking_from_idle(%d)(gc%Id) after doing change", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); + } + } + else + { + dprintf (9999, ("h%d < %d participating (inc)", heap_number, new_n_heaps)); + } + } + else + { + Interlocked::Increment (&dynamic_heap_count_data.idle_thread_count); + dprintf (9999, ("GC thread %d wait_on_idle(< max %d)(gc%Id), total idle %d", heap_number, num_threads_to_wake, + VolatileLoadWithoutBarrier (&settings.gc_index), VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + gc_idle_thread_event.Wait (INFINITE, FALSE); + dprintf (9999, ("GC thread %d waking_from_idle(%d)(gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); + } - // now wait on the gc_idle_thread_event - gc_idle_thread_event.Wait(INFINITE, FALSE); - dprintf (2, ("GC thread %d waking from idle", heap_number)); - continue; - } - // case 2.b) above: is this a request to change heap count? - if (dynamic_heap_count_data.new_n_heaps != n_heaps) - { - change_heap_count (dynamic_heap_count_data.new_n_heaps); continue; } - // case 1.b) above: we're starting a GC. #endif //DYNAMIC_HEAP_COUNT dprintf (3, (ThreadStressLog::gcServerThreadNStartMsg(), heap_number)); } @@ -7191,10 +7305,6 @@ void gc_heap::gc_thread_function () { gradual_decommit_in_progress_p = decommit_step (DECOMMIT_TIME_STEP_MILLISECONDS); } -#ifdef DYNAMIC_HEAP_COUNT - // check if we should adjust the number of heaps - check_heap_count(); -#endif //DYNAMIC_HEAP_COUNT } else { @@ -9955,6 +10065,20 @@ BOOL gc_heap::insert_ro_segment (heap_segment* seg) return TRUE; } +void gc_heap::update_ro_segment (heap_segment* seg, uint8_t* allocated, uint8_t* committed) +{ + enter_spin_lock (&gc_heap::gc_lock); + + assert (use_frozen_segments_p); + assert (heap_segment_read_only_p (seg)); + assert (allocated <= committed); + assert (committed <= heap_segment_reserved (seg)); + heap_segment_allocated (seg) = allocated; + heap_segment_committed (seg) = committed; + + leave_spin_lock (&gc_heap::gc_lock); +} + // No one is calling this function right now. If this is getting called we need // to take care of decommitting the mark array for it - we will need to remember // which portion of the mark array was committed and only decommit that. @@ -12513,6 +12637,16 @@ void gc_heap::rearrange_uoh_segments() freeable_uoh_segment = 0; } +void gc_heap::delay_free_segments() +{ + rearrange_uoh_segments(); +#ifdef BACKGROUND_GC + background_delay_delete_uoh_segments(); + if (!gc_heap::background_running_p()) + rearrange_small_heap_segments(); +#endif //BACKGROUND_GC +} + #ifndef USE_REGIONS void gc_heap::rearrange_heap_segments(BOOL compacting) { @@ -14846,6 +14980,25 @@ gc_heap::init_gc_heap (int h_number) gc_done_event_lock = -1; gc_done_event_set = false; +#ifdef DYNAMIC_HEAP_COUNT + if (h_number != 0) + { + if (!gc_idle_thread_event.CreateAutoEventNoThrow (FALSE)) + { + return 0; + } + +#ifdef BACKGROUND_GC + if (!bgc_idle_thread_event.CreateAutoEventNoThrow (FALSE)) + { + return 0; + } +#endif //BACKGROUND_GC + + dprintf (9999, ("creating idle events for h%d", h_number)); + } +#endif //DYNAMIC_HEAP_COUNT + if (!init_dynamic_data()) { return 0; @@ -16024,7 +16177,6 @@ void min_fl_list_info::thread_item_no_prev (uint8_t* item) tail = item; } -// This is only implemented for gen2 right now!!!! // the min_fl_list array is arranged as chunks of n_heaps min_fl_list_info, the 1st chunk corresponds to the 1st bucket, // and so on. void allocator::rethread_items (size_t* num_total_fl_items, size_t* num_total_fl_items_rethreaded, gc_heap* current_heap, @@ -17392,6 +17544,7 @@ BOOL gc_heap::a_fit_free_list_uoh_p (size_t size, gen_number, align_const); dd_new_allocation (dynamic_data_of (gen_number)) -= limit; + size_t saved_free_list_size = free_list_size; #ifdef FEATURE_LOH_COMPACTION if (loh_pad) { @@ -17420,7 +17573,7 @@ BOOL gc_heap::a_fit_free_list_uoh_p (size_t size, { generation_free_obj_space (gen) += remain_size; } - generation_free_list_space (gen) -= free_list_size; + generation_free_list_space (gen) -= saved_free_list_size; assert ((ptrdiff_t)generation_free_list_space (gen) >= 0); generation_free_list_allocated (gen) += limit; @@ -21986,11 +22139,70 @@ BOOL gc_heap::should_proceed_with_gc() void gc_heap::update_end_gc_time_per_heap() { +#ifdef DYNAMIC_HEAP_COUNT + size_t prev_gen2_end_time = 0; + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) && (settings.condemned_generation == max_generation)) + { + dynamic_data* dd = dynamic_data_of (max_generation); + prev_gen2_end_time = dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd);; + } +#endif //DYNAMIC_HEAP_COUNT + for (int gen_number = 0; gen_number <= settings.condemned_generation; gen_number++) { dynamic_data* dd = dynamic_data_of (gen_number); + + if (heap_number == 0) + { + dprintf (6666, ("prev gen%d GC end time: prev start %I64d + prev gc elapsed %Id = %I64d", + gen_number, dd_previous_time_clock (dd), dd_gc_elapsed_time (dd), (dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd)))); + } + dd_gc_elapsed_time (dd) = (size_t)(end_gc_time - dd_time_clock (dd)); + + if (heap_number == 0) + { + dprintf (6666, ("updated NGC%d %Id elapsed time to %I64d - %I64d = %I64d", gen_number, dd_gc_clock (dd), end_gc_time, dd_time_clock (dd), dd_gc_elapsed_time (dd))); + } } + +#ifdef DYNAMIC_HEAP_COUNT + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes)) + { + dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[dynamic_heap_count_data.sample_index]; + sample.elapsed_between_gcs = end_gc_time - last_suspended_end_time; + sample.gc_pause_time = dd_gc_elapsed_time (dynamic_data_of (0)); + sample.msl_wait_time = get_msl_wait_time(); + + dprintf (6666, ("sample#%d: this GC end %I64d - last sus end %I64d = %I64d, this GC pause %I64d, msl wait %I64d", + dynamic_heap_count_data.sample_index, end_gc_time, last_suspended_end_time, sample.elapsed_between_gcs, sample.gc_pause_time, sample.msl_wait_time)); + + last_suspended_end_time = end_gc_time; + + GCEventFireHeapCountSample_V1 ( + (uint64_t)VolatileLoadWithoutBarrier (&settings.gc_index), + sample.elapsed_between_gcs, + sample.gc_pause_time, + sample.msl_wait_time); + + dynamic_heap_count_data.sample_index = (dynamic_heap_count_data.sample_index + 1) % dynamic_heap_count_data_t::sample_size; + + if (settings.condemned_generation == max_generation) + { + gc_index_full_gc_end = dd_gc_clock (dynamic_data_of (0)); + size_t elapsed_between_gen2_gcs = end_gc_time - prev_gen2_end_time; + size_t gen2_elapsed_time = sample.gc_pause_time; + dynamic_heap_count_data.gen2_gc_percents[dynamic_heap_count_data.gen2_sample_index] = (float)gen2_elapsed_time * 100.0f / elapsed_between_gen2_gcs; + + dprintf (6666, ("gen2 sample#%d: this GC end %I64d - last gen2 end %I64d = %I64d, GC elapsed %I64d, percent %.3f", + dynamic_heap_count_data.gen2_sample_index, end_gc_time, prev_gen2_end_time, elapsed_between_gen2_gcs, + gen2_elapsed_time, dynamic_heap_count_data.gen2_gc_percents[dynamic_heap_count_data.gen2_sample_index])); + dynamic_heap_count_data.gen2_sample_index = (dynamic_heap_count_data.gen2_sample_index + 1) % dynamic_heap_count_data_t::sample_size; + } + + calculate_new_heap_count (); + } +#endif //DYNAMIC_HEAP_COUNT } void gc_heap::update_end_ngc_time() @@ -22137,7 +22349,31 @@ void gc_heap::gc1() { dynamic_data* dd = dynamic_data_of (n); end_gc_time = GetHighPrecisionTimeStamp(); + size_t time_since_last_gen2 = 0; + +#ifdef DYNAMIC_HEAP_COUNT + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes)) + { + time_since_last_gen2 = (size_t)(end_gc_time - (dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd))); + dprintf (6666, ("BGC %Id end %I64d - (prev gen2 start %I64d + elapsed %Id = %I64d) = time inbewteen gen2 %Id", + dd_gc_clock (dd), end_gc_time, dd_previous_time_clock (dd), dd_gc_elapsed_time (dd), (dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd)), time_since_last_gen2)); + } +#endif //DYNAMIC_HEAP_COUNT + dd_gc_elapsed_time (dd) = (size_t)(end_gc_time - dd_time_clock (dd)); +#ifdef DYNAMIC_HEAP_COUNT + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes)) + { + dprintf (6666, ("updating BGC %Id elapsed time to %I64d - %I64d = %I64d", dd_gc_clock (dd), end_gc_time, dd_time_clock (dd), dd_gc_elapsed_time (dd))); + + float bgc_percent = (float)dd_gc_elapsed_time (dd) * 100.0f / (float)time_since_last_gen2; + dynamic_heap_count_data.gen2_gc_percents[dynamic_heap_count_data.gen2_sample_index] = bgc_percent; + dprintf (6666, ("gen2 sample %d elapsed %Id * 100 / time inbetween gen2 %Id = %.3f", + dynamic_heap_count_data.gen2_sample_index, dd_gc_elapsed_time (dd), time_since_last_gen2, bgc_percent)); + dynamic_heap_count_data.gen2_sample_index = (dynamic_heap_count_data.gen2_sample_index + 1) % dynamic_heap_count_data_t::sample_size; + gc_index_full_gc_end = dd_gc_clock (dynamic_data_of (0)); + } +#endif //DYNAMIC_HEAP_COUNT #ifdef HEAP_BALANCE_INSTRUMENTATION if (heap_number == 0) @@ -22744,7 +22980,12 @@ void gc_heap::merge_fl_from_other_heaps (int gen_idx, int to_n_heaps, int from_n assert (free_list_space_decrease <= generation_free_list_space (gen)); generation_free_list_space (gen) -= free_list_space_decrease; - assert (free_list_space_decrease <= dd_fragmentation (dd)); + // TODO - I'm seeing for gen2 this is free_list_space_decrease can be a bit larger than frag. + // Need to fix this later. + if (gen_idx != max_generation) + { + assert (free_list_space_decrease <= dd_fragmentation (dd)); + } size_t free_list_space_increase = 0; for (int from_hn = 0; from_hn < from_n_heaps; from_hn++) @@ -23719,9 +23960,6 @@ void gc_heap::garbage_collect (int n) #ifdef MULTIPLE_HEAPS gc_start_event.Reset(); -#ifdef DYNAMIC_HEAP_COUNT - gc_idle_thread_event.Reset(); -#endif //DYNAMIC_HEAP_COUNT gc_t_join.restart(); #endif //MULTIPLE_HEAPS } @@ -23743,6 +23981,9 @@ void gc_heap::garbage_collect (int n) #endif // STRESS_HEAP #ifdef MULTIPLE_HEAPS +#ifdef STRESS_DYNAMIC_HEAP_COUNT + Interlocked::Increment (&heaps_in_this_gc); +#endif //STRESS_DYNAMIC_HEAP_COUNT //align all heaps on the max generation to condemn dprintf (3, ("Joining for max generation to condemn")); condemned_generation_num = generation_to_condemn (n, @@ -23758,30 +23999,31 @@ void gc_heap::garbage_collect (int n) #endif //FEATURE_BASICFREEZE #ifdef MULTIPLE_HEAPS +#ifdef STRESS_DYNAMIC_HEAP_COUNT + dprintf (9999, ("%d heaps, join sees %d, actually joined %d, %d idle threads (%d)", + n_heaps, gc_t_join.get_num_threads (), heaps_in_this_gc, + VolatileLoadWithoutBarrier(&dynamic_heap_count_data.idle_thread_count), (n_max_heaps - n_heaps))); + if (heaps_in_this_gc != n_heaps) + { + dprintf (9999, ("should have %d heaps but actually have %d!!", n_heaps, heaps_in_this_gc)); + GCToOSInterface::DebugBreak (); + } + + heaps_in_this_gc = 0; +#endif //STRESS_DYNAMIC_HEAP_COUNT + for (int i = 0; i < n_heaps; i++) { gc_heap* hp = g_heaps[i]; // check for card table growth if (g_gc_card_table != hp->card_table) hp->copy_brick_card_table(); - - hp->rearrange_uoh_segments(); -#ifdef BACKGROUND_GC - hp->background_delay_delete_uoh_segments(); - if (!gc_heap::background_running_p()) - hp->rearrange_small_heap_segments(); -#endif //BACKGROUND_GC + hp->delay_free_segments(); } #else //MULTIPLE_HEAPS if (g_gc_card_table != card_table) copy_brick_card_table(); - - rearrange_uoh_segments(); -#ifdef BACKGROUND_GC - background_delay_delete_uoh_segments(); - if (!gc_heap::background_running_p()) - rearrange_small_heap_segments(); -#endif //BACKGROUND_GC + delay_free_segments(); #endif //MULTIPLE_HEAPS BOOL should_evaluate_elevation = TRUE; @@ -23868,10 +24110,8 @@ void gc_heap::garbage_collect (int n) do_pre_gc(); #ifdef MULTIPLE_HEAPS + dprintf (9999, ("in GC, resetting gc_start")); gc_start_event.Reset(); -#ifdef DYNAMIC_HEAP_COUNT - gc_idle_thread_event.Reset(); -#endif //DYNAMIC_HEAP_COUNT dprintf(3, ("Starting all gc threads for gc")); gc_t_join.restart(); #endif //MULTIPLE_HEAPS @@ -24327,7 +24567,7 @@ void gc_heap::equalize_promoted_bytes(int condemned_gen_number) // hope is to achieve better work balancing in relocate and compact phases // this is also used when the heap count changes to balance regions between heaps int highest_gen_number = ((condemned_gen_number == max_generation) ? - (total_generation_count - 1) : condemned_gen_number); + (total_generation_count - 1) : condemned_gen_number); int stop_gen_idx = get_stop_generation_index (condemned_gen_number); for (int gen_idx = highest_gen_number; gen_idx >= stop_gen_idx; gen_idx--) @@ -25036,285 +25276,332 @@ void gc_heap::recommission_heap() #endif //RECORD_LOH_STATE } -void gc_heap::check_heap_count () +float median_of_3 (float a, float b, float c) +{ +#define compare_and_swap(i, j) \ + { \ + if (i < j) \ + { \ + float t = i; \ + i = j; \ + j = t; \ + } \ + } + compare_and_swap (b, a); + compare_and_swap (c, a); + compare_and_swap (c, b); +#undef compare_and_swap + return b; +} + +size_t gc_heap::get_num_completed_gcs () { - dynamic_heap_count_data.new_n_heaps = n_heaps; + size_t num_completed_gcs = settings.gc_index; +#ifdef BACKGROUND_GC + if (g_heaps[0]->is_bgc_in_progress ()) + { + num_completed_gcs--; + dprintf (6666, ("BGC in prog, completed GCs -> %Id", num_completed_gcs)); + } +#endif //BACKGROUND_GC + + return num_completed_gcs; +} + +int gc_heap::calculate_new_heap_count () +{ + assert (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes); + + size_t num_completed_gcs = get_num_completed_gcs (); - if (dynamic_adaptation_mode != dynamic_adaptation_to_application_sizes) + dprintf (6666, ("current GC %Id(completed: %Id), prev completed GCs %Id, last full GC happened at index %Id", + VolatileLoadWithoutBarrier (&settings.gc_index), num_completed_gcs, dynamic_heap_count_data.prev_num_completed_gcs, gc_index_full_gc_end)); + + if (num_completed_gcs < (dynamic_heap_count_data.prev_num_completed_gcs + dynamic_heap_count_data_t::sample_size)) { - return; + dprintf (6666, ("not enough GCs, skipping")); + return n_heaps; } - // we should be calling this only on the main GC thread - assert (heap_number == 0); + float median_gen2_tcp_percent = 0.0f; + if (gc_index_full_gc_end >= (settings.gc_index - dynamic_heap_count_data_t::sample_size)) + { + median_gen2_tcp_percent = dynamic_heap_count_data.get_median_gen2_gc_percent (); + } - // acquire data for the current sample - uint64_t soh_msl_wait_time = 0; - uint64_t uoh_msl_wait_time = 0; - size_t allocating_thread_count = 0; - size_t heap_size = 0; - for (int i = 0; i < n_heaps; i++) + // If there was a blocking gen2 GC, the overhead would be very large and most likely we would not pick it. So we + // rely on the gen2 sample's overhead calculated above. + float throughput_cost_percents[dynamic_heap_count_data_t::sample_size]; + for (int i = 0; i < dynamic_heap_count_data_t::sample_size; i++) { - gc_heap* hp = g_heaps[i]; + dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[i]; + throughput_cost_percents[i] = (sample.elapsed_between_gcs ? (((float)sample.msl_wait_time / n_heaps + sample.gc_pause_time) * 100.0f / (float)sample.elapsed_between_gcs) : 0.0f); + assert (throughput_cost_percents[i] >= 0.0); + if (throughput_cost_percents[i] > 100.0) + throughput_cost_percents[i] = 100.0; + dprintf (6666, ("sample %d: msl %I64d / %d + pause %I64d / elapsed %I64d = throughput_cost_percent: %.3f", i, + sample.msl_wait_time, n_heaps, sample.gc_pause_time, sample.elapsed_between_gcs, throughput_cost_percents[i])); + } - allocating_thread_count += hp->alloc_contexts_used; + float median_throughput_cost_percent = median_of_3 (throughput_cost_percents[0], throughput_cost_percents[1], throughput_cost_percents[2]); - soh_msl_wait_time += hp->more_space_lock_soh.msl_wait_time; - hp->more_space_lock_soh.msl_wait_time = 0; - hp->more_space_lock_soh.msl_wait_count = 0; + // apply exponential smoothing and use 1/3 for the smoothing factor + const float smoothing = 3; + float smoothed_median_throughput_cost_percent = dynamic_heap_count_data.smoothed_median_throughput_cost_percent; + if (smoothed_median_throughput_cost_percent != 0.0f) + { + // average it with the previous value + smoothed_median_throughput_cost_percent = median_throughput_cost_percent / smoothing + (smoothed_median_throughput_cost_percent / smoothing) * (smoothing - 1); + } + else + { + smoothed_median_throughput_cost_percent = median_throughput_cost_percent; + } - uoh_msl_wait_time += hp->more_space_lock_uoh.msl_wait_time; - hp->more_space_lock_uoh.msl_wait_time = 0; - hp->more_space_lock_uoh.msl_wait_count = 0; + dprintf (6666, ("median tcp: %.3f, smoothed tcp: %.3f, gen2 tcp %.3f(%.3f, %.3f, %.3f)", + median_throughput_cost_percent, smoothed_median_throughput_cost_percent, median_gen2_tcp_percent, + dynamic_heap_count_data.gen2_gc_percents[0], dynamic_heap_count_data.gen2_gc_percents[1], dynamic_heap_count_data.gen2_gc_percents[2])); + + size_t heap_size = 0; + for (int i = 0; i < n_heaps; i++) + { + gc_heap* hp = g_heaps[i]; for (int gen_idx = 0; gen_idx < total_generation_count; gen_idx++) { dynamic_data* dd = hp->dynamic_data_of (gen_idx); // estimate the size of each generation as the live data size plus the budget - heap_size += dd_promoted_size (dd) + dd_desired_allocation (dd); - dprintf (6666, ("h%d g%d promoted: %zd desired allocation: %zd", i, gen_idx, dd_promoted_size (dd), dd_desired_allocation (dd))); + heap_size += dd_current_size (dd) + dd_desired_allocation (dd); + dprintf (3, ("h%d g%d current: %zd desired allocation: %zd", i, gen_idx, dd_promoted_size (dd), dd_desired_allocation (dd))); } } - dynamic_data* hp0_dd0 = g_heaps[0]->dynamic_data_of (0); + // estimate the space cost of adding a heap as the min gen0 budget + size_t heap_space_cost_per_heap = dd_min_size (g_heaps[0]->dynamic_data_of (0)); - // persist data for the current sample - dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[dynamic_heap_count_data.sample_index]; + // compute the % space cost of adding a heap + float percent_heap_space_cost_per_heap = heap_space_cost_per_heap * 100.0f / heap_size; - sample.soh_msl_wait_time = soh_msl_wait_time / n_heaps; - sample.uoh_msl_wait_time = uoh_msl_wait_time / n_heaps; - sample.elapsed_between_gcs = dd_time_clock (hp0_dd0) - dd_previous_time_clock (hp0_dd0); - sample.gc_elapsed_time = dd_gc_elapsed_time (hp0_dd0); - sample.allocating_thread_count = allocating_thread_count; - sample.heap_size = heap_size; + // compute reasonable step sizes for the heap count + // + // on the way up, we essentially multiply the heap count by 1.5, so we go 1, 2, 3, 5, 8 ... + // we don't go all the way to the number of CPUs, but stay 1 or 2 short + int step_up = (n_heaps + 1) / 2; + int extra_heaps = 1 + (n_max_heaps >= 32); + step_up = min (step_up, n_max_heaps - extra_heaps - n_heaps); - dprintf (6666, ("sample %d: soh_msl_wait_time: %zd, uoh_msl_wait_time: %zd, elapsed_between_gcs: %zd, gc_elapsed_time: %d, heap_size: %zd MB", - dynamic_heap_count_data.sample_index, - sample.soh_msl_wait_time, - sample.uoh_msl_wait_time, - sample.elapsed_between_gcs, - sample.gc_elapsed_time, - sample.heap_size/(1024*1024))); + // on the way down, we essentially divide the heap count by 1.5 + int step_down = (n_heaps + 1) / 3; - dynamic_heap_count_data.sample_index = (dynamic_heap_count_data.sample_index + 1) % dynamic_heap_count_data_t::sample_size; + // estimate the potential time benefit of going up a step + float tcp_reduction_per_step_up = smoothed_median_throughput_cost_percent * step_up / (n_heaps + step_up); - GCEventFireHeapCountSample_V1( - sample.gc_elapsed_time, - sample.soh_msl_wait_time, - sample.uoh_msl_wait_time, - sample.elapsed_between_gcs - ); + // estimate the potential time cost of going down a step + float tcp_increase_per_step_down = smoothed_median_throughput_cost_percent * step_down / (n_heaps - step_down); + + // estimate the potential space cost of going up a step + float scp_increase_per_step_up = percent_heap_space_cost_per_heap * step_up; + + // estimate the potential space saving of going down a step + float scp_decrease_per_step_down = percent_heap_space_cost_per_heap * step_down; - if (settings.gc_index < prev_change_heap_count_gc_index + 3) + dprintf (6666, ("[CHP] u %d, d %d | space cost %Id / heap %Id(%.2fmb) = scp %.3f (u: %.3f, d: %.3f) | stcp %.3f, u * %.1f = %.3f, d * %.1f = %.3f", + step_up, step_down, + heap_space_cost_per_heap, heap_size, ((float)heap_size / (float)1000 / (float)1000), percent_heap_space_cost_per_heap, + scp_increase_per_step_up, scp_decrease_per_step_down, + smoothed_median_throughput_cost_percent, + ((float)step_up / (float)(n_heaps + step_up)), tcp_reduction_per_step_up, + ((float)step_down / (float)(n_heaps - step_down)), tcp_increase_per_step_down)); + +#ifdef STRESS_DYNAMIC_HEAP_COUNT + // quick hack for initial testing + int new_n_heaps = (int)gc_rand::get_rand (n_max_heaps - 1) + 1; + + // if we are adjusting down, make sure we adjust lower than the lowest uoh msl heap + if ((new_n_heaps < n_heaps) && (dynamic_heap_count_data.lowest_heap_with_msl_uoh != -1)) { - // reconsider the decision every few gcs - return; + new_n_heaps = min (dynamic_heap_count_data.lowest_heap_with_msl_uoh, new_n_heaps); + new_n_heaps = max (new_n_heaps, 1); } - - if (gc_heap::background_running_p()) + dprintf (6666, ("stress %d -> %d", n_heaps, new_n_heaps)); +#else //STRESS_DYNAMIC_HEAP_COUNT + int new_n_heaps = n_heaps; + if (median_throughput_cost_percent > 10.0f) { - // can't have background gc running while we change the number of heaps - // so it's useless to compute a new number of heaps here + // ramp up more agressively - use as many heaps as it would take to bring + // the tcp down to 5% + new_n_heaps = (int)(n_heaps * (median_throughput_cost_percent / 5.0)); + dprintf (6666, ("[CHP0] tcp %.3f -> %d * %.3f = %d", median_throughput_cost_percent, n_heaps, (median_throughput_cost_percent / 5.0), new_n_heaps)); + new_n_heaps = min (new_n_heaps, n_max_heaps - extra_heaps); } - else + // if the median tcp is 10% or less, react slower + else if ((smoothed_median_throughput_cost_percent > 5.0f) || (median_gen2_tcp_percent > 10.0f)) { - // compute the % overhead from msl waiting time and gc time for each of the samples - float percent_overhead[dynamic_heap_count_data_t::sample_size]; - for (int i = 0; i < dynamic_heap_count_data_t::sample_size; i++) - { - dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[i]; - uint64_t overhead_time = sample.soh_msl_wait_time + sample.uoh_msl_wait_time + sample.gc_elapsed_time; - percent_overhead[i] = overhead_time * 100.0f / sample.elapsed_between_gcs; - if (percent_overhead[i] < 0) - percent_overhead[i] = 0; - else if (percent_overhead[i] > 100) - percent_overhead[i] = 100; - dprintf (6666, ("sample %d: percent_overhead: %d%%", i, (int)percent_overhead[i])); - } - // compute the median of the percent overhead samples - #define compare_and_swap(i, j) \ - { \ - if (percent_overhead[i] < percent_overhead[j]) \ - { \ - float t = percent_overhead[i]; \ - percent_overhead[i] = percent_overhead[j]; \ - percent_overhead[j] = t; \ - } \ - } - compare_and_swap (1, 0); - compare_and_swap (2, 0); - compare_and_swap (2, 1); - #undef compare_and_swap - - // the middle element is the median overhead percentage - float median_percent_overhead = percent_overhead[1]; - - // apply exponential smoothing and use 1/3 for the smoothing factor - const float smoothing = 3; - float smoothed_median_percent_overhead = dynamic_heap_count_data.smoothed_median_percent_overhead; - if (smoothed_median_percent_overhead != 0.0f) - { - // average it with the previous value - smoothed_median_percent_overhead = median_percent_overhead / smoothing + (smoothed_median_percent_overhead / smoothing) * (smoothing - 1); + if (smoothed_median_throughput_cost_percent > 5.0f) + { + dprintf (6666, ("[CHP1] stcp %.3f > 5, %d + %d = %d", smoothed_median_throughput_cost_percent, n_heaps, step_up, (n_heaps + step_up))); } else { - // first time? initialize to the median - smoothed_median_percent_overhead = median_percent_overhead; + dprintf (6666, ("[CHP2] tcp %.3f > 10, %d + %d = %d", median_gen2_tcp_percent, n_heaps, step_up, (n_heaps + step_up))); } + new_n_heaps += step_up; + } + // if we can save at least 1% more in time than we spend in space, increase number of heaps + else if ((tcp_reduction_per_step_up - scp_increase_per_step_up) >= 1.0f) + { + dprintf (6666, ("[CHP3] % .3f - % .3f = % .3f, % d + % d = % d", + tcp_reduction_per_step_up, scp_increase_per_step_up, (tcp_reduction_per_step_up - scp_increase_per_step_up), + n_heaps, step_up, (n_heaps + step_up))); + new_n_heaps += step_up; + } + // if we can save at least 1% more in space than we spend in time, decrease number of heaps + else if ((smoothed_median_throughput_cost_percent < 1.0f) && + (median_gen2_tcp_percent < 5.0f) && + ((scp_decrease_per_step_down - tcp_increase_per_step_down) >= 1.0f)) + { + dprintf (6666, ("[CHP4] stcp %.3f tcp %.3f, %.3f - %.3f = %.3f, %d + %d = %d", + smoothed_median_throughput_cost_percent, median_gen2_tcp_percent, + scp_decrease_per_step_down, tcp_increase_per_step_down, (scp_decrease_per_step_down - tcp_increase_per_step_down), + n_heaps, step_up, (n_heaps + step_up))); + new_n_heaps -= step_down; + } - dprintf (6666, ("median overhead: %d%% smoothed median overhead: %d%%", (int)(median_percent_overhead*1000), (int)(smoothed_median_percent_overhead*1000))); - - // estimate the space cost of adding a heap as the min gen0 size - size_t heap_space_cost_per_heap = dd_min_size (hp0_dd0); - - // compute the % space cost of adding a heap - float percent_heap_space_cost_per_heap = heap_space_cost_per_heap * 100.0f / heap_size; - - // compute reasonable step sizes for the heap count + assert (new_n_heaps >= 1); + assert (new_n_heaps <= n_max_heaps); +#endif //STRESS_DYNAMIC_HEAP_COUNT - // on the way up, we essentially multiply the heap count by 1.5, so we go 1, 2, 3, 5, 8 ... - // we don't go all the way to the number of CPUs, but stay 1 or 2 short - int step_up = (n_heaps + 1) / 2; - int extra_heaps = 1 + (n_max_heaps >= 32); - step_up = min (step_up, n_max_heaps - extra_heaps - n_heaps); + // store data used for decision to emit in ETW event + dynamic_heap_count_data.median_throughput_cost_percent = median_throughput_cost_percent; + dynamic_heap_count_data.smoothed_median_throughput_cost_percent = smoothed_median_throughput_cost_percent; + dynamic_heap_count_data.percent_heap_space_cost_per_heap = percent_heap_space_cost_per_heap; + dynamic_heap_count_data.tcp_reduction_per_step_up = tcp_reduction_per_step_up; + dynamic_heap_count_data.tcp_increase_per_step_down = tcp_increase_per_step_down; + dynamic_heap_count_data.scp_increase_per_step_up = scp_increase_per_step_up; + dynamic_heap_count_data.scp_decrease_per_step_down = scp_decrease_per_step_down; + + GCEventFireHeapCountTuning_V1 ( + (uint16_t)dynamic_heap_count_data.new_n_heaps, + (uint64_t)VolatileLoadWithoutBarrier (&settings.gc_index), + dynamic_heap_count_data.median_throughput_cost_percent, + dynamic_heap_count_data.smoothed_median_throughput_cost_percent, + dynamic_heap_count_data.tcp_reduction_per_step_up, + dynamic_heap_count_data.tcp_increase_per_step_down, + dynamic_heap_count_data.scp_increase_per_step_up, + dynamic_heap_count_data.scp_decrease_per_step_down + ); - // on the way down, we essentially divide the heap count by 1.5 - int step_down = (n_heaps + 1) / 3; + dynamic_heap_count_data.prev_num_completed_gcs = num_completed_gcs; - // estimate the potential time benefit of going up a step - float overhead_reduction_per_step_up = smoothed_median_percent_overhead * step_up / (n_heaps + step_up); + if (new_n_heaps != n_heaps) + { + dprintf (6666, ("should change! %d->%d", n_heaps, new_n_heaps)); + dynamic_heap_count_data.heap_count_to_change_to = new_n_heaps; + dynamic_heap_count_data.should_change_heap_count = true; + } - // estimate the potential time cost of going down a step - float overhead_increase_per_step_down = smoothed_median_percent_overhead * step_down / (n_heaps - step_down); + return new_n_heaps; +} - // estimate the potential space cost of going up a step - float space_cost_increase_per_step_up = percent_heap_space_cost_per_heap * step_up; +void gc_heap::check_heap_count () +{ + dynamic_heap_count_data.new_n_heaps = dynamic_heap_count_data.heap_count_to_change_to; - // estimate the potential space saving of going down a step - float space_cost_decrease_per_step_down = percent_heap_space_cost_per_heap * step_down; + assert (dynamic_heap_count_data.new_n_heaps != n_heaps); -#ifdef STRESS_DYNAMIC_HEAP_COUNT - // quick hack for initial testing - int new_n_heaps = (int)gc_rand::get_rand (n_max_heaps - 1) + 1; + if (dynamic_heap_count_data.new_n_heaps != n_heaps) + { + dprintf (9999, ("h0 suspending EE in check")); + // can't have threads allocating while we change the number of heaps + GCToEEInterface::SuspendEE(SUSPEND_FOR_GC_PREP); + dprintf (9999, ("h0 suspended EE in check")); - // if we are adjusting down, make sure we adjust lower than the lowest uoh msl heap - if ((new_n_heaps < n_heaps) && (dynamic_heap_count_data.lowest_heap_with_msl_uoh != -1)) +#ifdef BACKGROUND_GC + if (gc_heap::background_running_p()) { - new_n_heaps = min (dynamic_heap_count_data.lowest_heap_with_msl_uoh, new_n_heaps); + // background GC is running - reset the new heap count + dynamic_heap_count_data.new_n_heaps = n_heaps; + dprintf (6666, ("can't change heap count! BGC in progress")); - // but not down to zero, obviously... - new_n_heaps = max (new_n_heaps, 1); - } -#else //STRESS_DYNAMIC_HEAP_COUNT - int new_n_heaps = n_heaps; - if (median_percent_overhead > 10.0f) - { - // ramp up more agressively - use as many heaps as it would take to bring - // the overhead down to 5% - new_n_heaps = (int)(n_heaps * (median_percent_overhead / 5.0)); - new_n_heaps = min (new_n_heaps, n_max_heaps - extra_heaps); - } - // if the median overhead is 10% or less, react slower - else if (smoothed_median_percent_overhead > 5.0f) - { - new_n_heaps += step_up; - } - // if we can save at least 1% more in time than we spend in space, increase number of heaps - else if (overhead_reduction_per_step_up - space_cost_increase_per_step_up >= 1.0f) - { - new_n_heaps += step_up; - } - // if we can save at least 1% more in space than we spend in time, decrease number of heaps - else if (smoothed_median_percent_overhead < 1.0f && space_cost_decrease_per_step_down - overhead_increase_per_step_down >= 1.0f) - { - new_n_heaps -= step_down; + GCToEEInterface::RestartEE(TRUE); } +#endif //BACKGROUND_GC + } - dprintf (6666, ("or: %d, si: %d, sd: %d, oi: %d => %d -> %d", - (int)overhead_reduction_per_step_up, - (int)space_cost_increase_per_step_up, - (int)space_cost_decrease_per_step_down, - (int)overhead_increase_per_step_down, - n_heaps, - new_n_heaps)); - - assert (1 <= new_n_heaps); - assert (new_n_heaps <= n_max_heaps); -#endif //STRESS_DYNAMIC_HEAP_COUNT - - dynamic_heap_count_data.new_n_heaps = new_n_heaps; - - // store data used for decision to emit in ETW event - dynamic_heap_count_data.median_percent_overhead = median_percent_overhead; - dynamic_heap_count_data.smoothed_median_percent_overhead = smoothed_median_percent_overhead; - dynamic_heap_count_data.percent_heap_space_cost_per_heap = percent_heap_space_cost_per_heap; - dynamic_heap_count_data.overhead_reduction_per_step_up = overhead_reduction_per_step_up; - dynamic_heap_count_data.overhead_increase_per_step_down = overhead_increase_per_step_down; - dynamic_heap_count_data.space_cost_increase_per_step_up = space_cost_increase_per_step_up; - dynamic_heap_count_data.space_cost_decrease_per_step_down = space_cost_decrease_per_step_down; - - GCEventFireHeapCountTuning_V1( - (uint16_t)dynamic_heap_count_data.new_n_heaps, - (uint64_t)VolatileLoad(&settings.gc_index), - dynamic_heap_count_data.median_percent_overhead, - dynamic_heap_count_data.smoothed_median_percent_overhead, - dynamic_heap_count_data.overhead_reduction_per_step_up, - dynamic_heap_count_data.overhead_increase_per_step_down, - dynamic_heap_count_data.space_cost_increase_per_step_up, - dynamic_heap_count_data.space_cost_decrease_per_step_down - ); - - if (new_n_heaps != n_heaps) + if (dynamic_heap_count_data.new_n_heaps != n_heaps) + { + dprintf (6666, ("prep to change from %d to %d", n_heaps, dynamic_heap_count_data.new_n_heaps)); + if (!prepare_to_change_heap_count (dynamic_heap_count_data.new_n_heaps)) { - // can't have threads allocating while we change the number of heaps - GCToEEInterface::SuspendEE(SUSPEND_FOR_GC_PREP); - - if (gc_heap::background_running_p()) - { - // background GC is running - reset the new heap count - dynamic_heap_count_data.new_n_heaps = n_heaps; - - GCToEEInterface::RestartEE(TRUE); - } + // we don't have sufficient resources - reset the new heap count + dynamic_heap_count_data.new_n_heaps = n_heaps; } } if (dynamic_heap_count_data.new_n_heaps == n_heaps) { // heap count stays the same, no work to do - dprintf (6666, ("heap count stays the same, no work to do %d == %d", dynamic_heap_count_data.new_n_heaps, n_heaps)); + dynamic_heap_count_data.prev_num_completed_gcs = get_num_completed_gcs (); + dynamic_heap_count_data.should_change_heap_count = false; - // come back after 3 GCs to reconsider - prev_change_heap_count_gc_index = settings.gc_index; + dprintf (6666, ("heap count stays the same %d, no work to do, set prev completed to %Id", dynamic_heap_count_data.new_n_heaps, dynamic_heap_count_data.prev_num_completed_gcs)); return; } - if (GCScan::GetGcRuntimeStructuresValid()) + int new_n_heaps = dynamic_heap_count_data.new_n_heaps; + + assert (!(dynamic_heap_count_data.init_only_p)); + { + // At this point we are guaranteed to be able to change the heap count to the new one. + // Change the heap count for joins here because we will need to join new_n_heaps threads together. + dprintf (9999, ("changing join hp %d->%d", n_heaps, new_n_heaps)); + int max_threads_to_wake = max (n_heaps, new_n_heaps); + gc_t_join.update_n_threads (max_threads_to_wake); + // make sure the other gc threads cannot see this as a request to GC assert (dynamic_heap_count_data.new_n_heaps != n_heaps); + + if (n_heaps < new_n_heaps) + { + int saved_idle_thread_count = dynamic_heap_count_data.idle_thread_count; + Interlocked::ExchangeAdd (&dynamic_heap_count_data.idle_thread_count, (n_heaps - new_n_heaps)); + dprintf (9999, ("GC thread %d setting idle events for h%d-h%d, total idle %d -> %d", heap_number, n_heaps, (new_n_heaps - 1), + saved_idle_thread_count, VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + + for (int heap_idx = n_heaps; heap_idx < new_n_heaps; heap_idx++) + { + g_heaps[heap_idx]->gc_idle_thread_event.Set(); +#ifdef BACKGROUND_GC + g_heaps[heap_idx]->bgc_idle_thread_event.Set(); +#endif //BACKGROUND_GC + } + } + gc_start_event.Set(); } int old_n_heaps = n_heaps; + (dynamic_heap_count_data.heap_count_change_count)++; change_heap_count (dynamic_heap_count_data.new_n_heaps); GCToEEInterface::RestartEE(TRUE); - prev_change_heap_count_gc_index = settings.gc_index; + dprintf (9999, ("h0 restarted EE")); // we made changes to the heap count that will change the overhead, // so change the smoothed overhead to reflect that - int new_n_heaps = n_heaps; - dynamic_heap_count_data.smoothed_median_percent_overhead = dynamic_heap_count_data.smoothed_median_percent_overhead/new_n_heaps*old_n_heaps; + dynamic_heap_count_data.smoothed_median_throughput_cost_percent = dynamic_heap_count_data.smoothed_median_throughput_cost_percent / n_heaps * old_n_heaps; + + dprintf (6666, ("h0 finished changing, set should change to false!")); + dynamic_heap_count_data.should_change_heap_count = false; } bool gc_heap::prepare_to_change_heap_count (int new_n_heaps) { - dprintf (6666, ("trying to change heap count %d -> %d", n_heaps, new_n_heaps)); + dprintf (9999, ("trying to change heap count %d -> %d", n_heaps, new_n_heaps)); // use this variable for clarity - n_heaps will change during the transition int old_n_heaps = n_heaps; @@ -25357,6 +25644,17 @@ bool gc_heap::prepare_to_change_heap_count (int new_n_heaps) } } + // Before we look at whether we have sufficient regions we should return regions that should be deleted to free + // so we don't lose them when we decommission heaps. We could do this for only heaps that we are about + // to decomission. But it's better to do this for all heaps because we don't need to worry about adding them to the + // heaps remain (freeable uoh/soh regions) and we get rid of regions with the heap_segment_flags_uoh_delete flag + // because background_delay_delete_uoh_segments makes the assumption it can't be the start region. + for (int i = 0; i < old_n_heaps; i++) + { + gc_heap* hp = g_heaps[i]; + hp->delay_free_segments (); + } + // if we want to increase the number of heaps, we have to make sure we can give // each heap a region for each generation. If we cannot do that, we have to give up ptrdiff_t region_count_in_gen[total_generation_count]; @@ -25437,39 +25735,34 @@ bool gc_heap::prepare_to_change_heap_count (int new_n_heaps) bool gc_heap::change_heap_count (int new_n_heaps) { + dprintf (9999, ("BEG heap%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); + // use this variable for clarity - n_heaps will change during the transition int old_n_heaps = n_heaps; + bool init_only_p = dynamic_heap_count_data.init_only_p; - if (heap_number == 0) - { - if (!prepare_to_change_heap_count (new_n_heaps)) - { - // we don't have sufficient resources - reset the new heap count - dynamic_heap_count_data.new_n_heaps = n_heaps; - } - } - - if (GCScan::GetGcRuntimeStructuresValid()) { - // join for sufficient resources decision gc_t_join.join (this, gc_join_merge_temp_fl); if (gc_t_join.joined ()) { + // BGC is not running, we can safely change its join's heap count. +#ifdef BACKGROUND_GC + bgc_t_join.update_n_threads (new_n_heaps); +#endif //BACKGROUND_GC + + dynamic_heap_count_data.init_only_p = false; + dprintf (9999, ("in change h%d resetting gc_start, update bgc join to %d heaps", heap_number, new_n_heaps)); gc_start_event.Reset(); gc_t_join.restart (); } } - // gc_heap::n_heaps may have changed by now, compare to the snapshot *before* the join - if (dynamic_heap_count_data.new_n_heaps == old_n_heaps) - { - dprintf (6666, ("failed to change heap count, no work to do %d == %d", dynamic_heap_count_data.new_n_heaps, old_n_heaps)); - return false; - } + assert (dynamic_heap_count_data.new_n_heaps != old_n_heaps); + + dprintf (9999, ("Waiting h0 heap%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); if (heap_number == 0) { - // after having checked for sufficient resources, we are now committed to actually change the heap count dprintf (3, ("switching heap count from %d to %d heaps", old_n_heaps, new_n_heaps)); // spread finalization data out to heaps coming into service @@ -25490,17 +25783,23 @@ bool gc_heap::change_heap_count (int new_n_heaps) from_heap_number = (from_heap_number + 1) % old_n_heaps; } - // prepare for the switch by fixing the allocation contexts on the old heaps, + // prepare for the switch by fixing the allocation contexts on the old heaps, unify the gen0_bricks_cleared flag, // and setting the survived size for the existing regions to their allocated size + BOOL unified_gen0_bricks_cleared = TRUE; for (int i = 0; i < old_n_heaps; i++) { gc_heap* hp = g_heaps[i]; - if (GCScan::GetGcRuntimeStructuresValid()) + if (!init_only_p) { hp->fix_allocation_contexts (TRUE); } + if (unified_gen0_bricks_cleared && (hp->gen0_bricks_cleared == FALSE)) + { + unified_gen0_bricks_cleared = FALSE; + } + for (int gen_idx = 0; gen_idx < total_generation_count; gen_idx++) { generation* gen = hp->generation_of (gen_idx); @@ -25600,7 +25899,7 @@ bool gc_heap::change_heap_count (int new_n_heaps) hpd->free_regions[kind].transfer_regions(&hp->free_regions[kind]); } } - // update number of heaps + dprintf (9999, ("h%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); n_heaps = new_n_heaps; // even out the regions over the current number of heaps @@ -25611,6 +25910,8 @@ bool gc_heap::change_heap_count (int new_n_heaps) { gc_heap* hp = g_heaps[i]; + hp->gen0_bricks_cleared = unified_gen0_bricks_cleared; + // establish invariants regarding the ephemeral segment generation* gen0 = hp->generation_of (0); if ((hp->ephemeral_heap_segment == nullptr) || @@ -25639,7 +25940,9 @@ bool gc_heap::change_heap_count (int new_n_heaps) } } - if (GCScan::GetGcRuntimeStructuresValid()) + dprintf (3, ("individual heap%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); + + if (!init_only_p) { // join for rethreading the free lists gc_t_join.join (this, gc_join_merge_temp_fl); @@ -25651,7 +25954,11 @@ bool gc_heap::change_heap_count (int new_n_heaps) // rethread the free lists for (int gen_idx = 0; gen_idx < total_generation_count; gen_idx++) { - rethread_fl_items (gen_idx); + if (heap_number < old_n_heaps) + { + dprintf (3, ("h%d calling per heap work!", heap_number)); + rethread_fl_items (gen_idx); + } // join for merging the free lists gc_t_join.join (this, gc_join_merge_temp_fl); @@ -25662,18 +25969,14 @@ bool gc_heap::change_heap_count (int new_n_heaps) gc_t_join.restart (); } } +#ifdef BACKGROUND_GC // there should be no items in the bgc_alloc_lock bgc_alloc_lock->check(); +#endif //BACKGROUND_GC } if (heap_number == 0) { - // udate the number of heaps in the joins - gc_t_join.update_n_threads(new_n_heaps); - #ifdef BACKGROUND_GC - bgc_t_join.update_n_threads(new_n_heaps); - #endif //BACKGROUND_GC - // compute the total budget per generation over the old heaps // and figure out what the new budget per heap is ptrdiff_t budget_per_heap[total_generation_count]; @@ -25733,21 +26036,50 @@ bool gc_heap::change_heap_count (int new_n_heaps) hp->decommission_heap(); } - if (GCScan::GetGcRuntimeStructuresValid()) + if (!init_only_p) { // make sure no allocation contexts point to idle heaps fix_allocation_contexts_heaps(); } - if (old_n_heaps < new_n_heaps) + dynamic_heap_count_data.last_n_heaps = old_n_heaps; + } + + // join the last time to change the heap count again if needed. + if (new_n_heaps < old_n_heaps) + { + gc_t_join.join (this, gc_join_merge_temp_fl); + if (gc_t_join.joined ()) { - // wake up threads for the new heaps - gc_idle_thread_event.Set(); + dprintf (9999, ("now changing the join heap count to the smaller one %d", new_n_heaps)); + gc_t_join.update_n_threads (new_n_heaps); + + gc_t_join.restart (); } } return true; } + +size_t gc_heap::get_msl_wait_time() +{ + assert (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes); + + size_t msl_wait_since_pause = 0; + + for (int i = 0; i < n_heaps; i++) + { + gc_heap* hp = g_heaps[i]; + + msl_wait_since_pause += hp->more_space_lock_soh.msl_wait_time; + hp->more_space_lock_soh.msl_wait_time = 0; + + msl_wait_since_pause += hp->more_space_lock_uoh.msl_wait_time; + hp->more_space_lock_uoh.msl_wait_time = 0; + } + + return msl_wait_since_pause; +} #endif //DYNAMIC_HEAP_COUNT #endif //USE_REGIONS @@ -32791,17 +33123,17 @@ void gc_heap::plan_phase (int condemned_gen_number) } else { - dprintf (2, ("gen2 didn't grow (end seg alloc: %zd, , condemned alloc: %zd, gen1 c alloc: %zd", + dprintf (1, ("gen2 didn't grow (end seg alloc: %zd, , condemned alloc: %zd, gen1 c alloc: %zd", end_seg_allocated, condemned_allocated, generation_condemned_allocated (generation_of (max_generation - 1)))); } - dprintf (1, ("older gen's free alloc: %zd->%zd, seg alloc: %zd->%zd, condemned alloc: %zd->%zd", + dprintf (2, ("older gen's free alloc: %zd->%zd, seg alloc: %zd->%zd, condemned alloc: %zd->%zd", r_older_gen_free_list_allocated, generation_free_list_allocated (older_gen), r_older_gen_end_seg_allocated, generation_end_seg_allocated (older_gen), r_older_gen_condemned_allocated, generation_condemned_allocated (older_gen))); - dprintf (1, ("this GC did %zd free list alloc(%zd bytes free space rejected)", + dprintf (2, ("this GC did %zd free list alloc(%zd bytes free space rejected)", free_list_allocated, rejected_free_space)); maxgen_size_increase* maxgen_size_info = &(get_gc_data_per_heap()->maxgen_size_info); @@ -38894,9 +39226,9 @@ void gc_heap::bgc_thread_function() { // this is the case where we have more background GC threads than heaps // - wait until we're told to continue... - dprintf (3, ("BGC thread %d idle", heap_number)); - gc_idle_thread_event.Wait(INFINITE, FALSE); - dprintf (3, ("BGC thread %d waking from idle", heap_number)); + dprintf (9999, ("BGC thread %d idle (%d heaps) (gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); + bgc_idle_thread_event.Wait(INFINITE, FALSE); + dprintf (9999, ("BGC thread %d waking from idle (%d heaps) (gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); continue; } #endif //DYNAMIC_HEAP_COUNT @@ -38968,7 +39300,7 @@ void gc_heap::bgc_thread_function() dprintf (SPINLOCK_LOG, ("bgc Lgc")); leave_spin_lock (&gc_lock); #ifdef MULTIPLE_HEAPS - dprintf(1, ("End of BGC - starting all BGC threads")); + dprintf(1, ("End of BGC")); bgc_t_join.restart(); #endif //MULTIPLE_HEAPS } @@ -42845,6 +43177,9 @@ bool gc_heap::init_dynamic_data() { process_start_time = now; smoothed_desired_total[0] = dynamic_data_of (0)->min_size * n_heaps; +#ifdef DYNAMIC_HEAP_COUNT + last_suspended_end_time = now; +#endif //DYNAMIC_HEAP_COUNT #ifdef HEAP_BALANCE_INSTRUMENTATION last_gc_end_time_us = now; dprintf (HEAP_BALANCE_LOG, ("qpf=%zd, start: %zd(%d)", qpf, start_raw_ts, now)); @@ -47943,6 +48278,7 @@ HRESULT GCHeap::Initialize() uint32_t nhp = 1; uint32_t nhp_from_config = 0; + uint32_t max_nhp_from_config = (uint32_t)GCConfig::GetMaxHeapCount(); #ifndef MULTIPLE_HEAPS GCConfig::SetServerGC(false); @@ -48137,6 +48473,10 @@ HRESULT GCHeap::Initialize() #ifdef MULTIPLE_HEAPS assert (nhp <= g_num_processors); + if (max_nhp_from_config) + { + nhp = min (nhp, max_nhp_from_config); + } gc_heap::n_max_heaps = nhp; gc_heap::n_heaps = nhp; hr = gc_heap::initialize_gc (seg_size, large_seg_size, pin_seg_size, nhp); @@ -48287,9 +48627,32 @@ HRESULT GCHeap::Initialize() { // start with only 1 heap gc_heap::smoothed_desired_total[0] /= gc_heap::n_heaps; - gc_heap::g_heaps[0]->change_heap_count (1); + int initial_n_heaps = 1; + dprintf (9999, ("gc_heap::n_heaps is %d, initial %d", gc_heap::n_heaps, initial_n_heaps)); + + { + if (!gc_heap::prepare_to_change_heap_count (initial_n_heaps)) + { + // we don't have sufficient resources. + return E_FAIL; + } + + gc_heap::dynamic_heap_count_data.new_n_heaps = initial_n_heaps; + gc_heap::dynamic_heap_count_data.idle_thread_count = 0; + gc_heap::dynamic_heap_count_data.init_only_p = true; + + int max_threads_to_wake = max (gc_heap::n_heaps, initial_n_heaps); + gc_t_join.update_n_threads (max_threads_to_wake); + gc_heap::gc_start_event.Set (); + } + + gc_heap::g_heaps[0]->change_heap_count (initial_n_heaps); + gc_heap::gc_start_event.Reset (); + + // This needs to be different from our initial heap count so we can make sure we wait for + // the idle threads correctly in gc_thread_function. + gc_heap::dynamic_heap_count_data.last_n_heaps = 0; } - gc_heap::dynamic_heap_count_data.new_n_heaps = gc_heap::n_heaps; #endif //DYNAMIC_HEAP_COUNT GCScan::GcRuntimeStructuresValid (TRUE); @@ -49861,10 +50224,16 @@ void gc_heap::do_post_gc() } #endif //BGC_SERVO_TUNING +#ifdef BACKGROUND_GC + const char* str_gc_type = (settings.concurrent ? "BGC" : (gc_heap::background_running_p () ? "FGC" : "NGC")); +#else + const char* str_gc_type = "NGC"; +#endif //BACKGROUND_GC + dprintf (1, (ThreadStressLog::gcDetailedEndMsg(), - VolatileLoad(&settings.gc_index), - dd_collection_count(hp->dynamic_data_of(0)), - (size_t)(GetHighPrecisionTimeStamp() / 1000), + VolatileLoad (&settings.gc_index), + dd_collection_count (hp->dynamic_data_of (0)), + (size_t)(GetHighPrecisionTimeStamp () / 1000), settings.condemned_generation, (settings.concurrent ? "BGC" : (gc_heap::background_running_p() ? "FGC" : "NGC")), (settings.compaction ? "C" : "S"), diff --git a/src/coreclr/gc/gcconfig.h b/src/coreclr/gc/gcconfig.h index 72786778d5a978..aeded6bc97f17f 100644 --- a/src/coreclr/gc/gcconfig.h +++ b/src/coreclr/gc/gcconfig.h @@ -83,6 +83,7 @@ class GCConfigStringHolder INT_CONFIG (BGCSpinCount, "BGCSpinCount", NULL, 140, "Specifies the bgc spin count") \ INT_CONFIG (BGCSpin, "BGCSpin", NULL, 2, "Specifies the bgc spin time") \ INT_CONFIG (HeapCount, "GCHeapCount", "System.GC.HeapCount", 0, "Specifies the number of server GC heaps") \ + INT_CONFIG (MaxHeapCount, "GCMaxHeapCount", "System.GC.MaxHeapCount", 0, "Specifies the max number of server GC heaps to adjust to") \ INT_CONFIG (Gen0Size, "GCgen0size", NULL, 0, "Specifies the smallest gen0 budget") \ INT_CONFIG (SegmentSize, "GCSegmentSize", NULL, 0, "Specifies the managed heap segment size") \ INT_CONFIG (LatencyMode, "GCLatencyMode", NULL, -1, "Specifies the GC latency mode - batch, interactive or low latency (note that the same " \ diff --git a/src/coreclr/gc/gcee.cpp b/src/coreclr/gc/gcee.cpp index 6dbbfd64a7a514..32738da9b603ab 100644 --- a/src/coreclr/gc/gcee.cpp +++ b/src/coreclr/gc/gcee.cpp @@ -510,9 +510,12 @@ bool GCHeap::IsInFrozenSegment(Object *object) void GCHeap::UpdateFrozenSegment(segment_handle seg, uint8_t* allocated, uint8_t* committed) { #ifdef FEATURE_BASICFREEZE - heap_segment* heap_seg = reinterpret_cast(seg); - heap_segment_committed(heap_seg) = committed; - heap_segment_allocated(heap_seg) = allocated; +#ifdef MULTIPLE_HEAPS + gc_heap* heap = gc_heap::g_heaps[0]; +#else + gc_heap* heap = pGenGCHeap; +#endif //MULTIPLE_HEAPS + heap->update_ro_segment (reinterpret_cast(seg), allocated, committed); #endif // FEATURE_BASICFREEZE } diff --git a/src/coreclr/gc/gcpriv.h b/src/coreclr/gc/gcpriv.h index da0085ce19610d..cce6c5ee28adf0 100644 --- a/src/coreclr/gc/gcpriv.h +++ b/src/coreclr/gc/gcpriv.h @@ -402,8 +402,6 @@ struct GCDebugSpinLock { #if defined(DYNAMIC_HEAP_COUNT) // time in microseconds we wait for the more space lock uint64_t msl_wait_time; - // number of times we wait for the more space lock - uint64_t msl_wait_count; #endif //DYNAMIC_HEAP_COUNT GCDebugSpinLock() @@ -415,7 +413,7 @@ struct GCDebugSpinLock { , num_switch_thread(0), num_wait_longer(0), num_switch_thread_w(0), num_disable_preemptive_w(0) #endif #if defined(DYNAMIC_HEAP_COUNT) - , msl_wait_time(0), msl_wait_count(0) + , msl_wait_time(0) #endif //DYNAMIC_HEAP_COUNT { } @@ -1148,15 +1146,12 @@ class dynamic_data // // The following 3 fields are updated at the beginning of each GC, if that GC condemns this generation. // - // The number of GC that condemned this generation. The only difference between this - // and collection_count is just that collection_count is maintained for all physical generations - // (currently there are 5) whereas this is only updated for logical generations (there are 3). - size_t gc_clock; - uint64_t time_clock; //time when this gc started + size_t gc_clock; // the gc index + uint64_t time_clock; // time when this gc started uint64_t previous_time_clock; // time when previous gc started // Updated at the end of a GC, if that GC condemns this generation. - size_t gc_elapsed_time; // Time it took for the gc to complete + size_t gc_elapsed_time; // time it took for the gc to complete // // The following fields (and fields in sdata) are initialized during GC init time and do not change. @@ -1495,6 +1490,8 @@ class mark_queue_t void verify_empty(); }; +float median_of_3 (float a, float b, float c); + //class definition of the internal class class gc_heap { @@ -2380,6 +2377,7 @@ class gc_heap #ifdef FEATURE_BASICFREEZE PER_HEAP_METHOD BOOL insert_ro_segment (heap_segment* seg); PER_HEAP_METHOD void remove_ro_segment (heap_segment* seg); + PER_HEAP_METHOD void update_ro_segment (heap_segment* seg, uint8_t* allocated, uint8_t* committed); #endif //FEATURE_BASICFREEZE PER_HEAP_METHOD BOOL set_ro_segment_in_range (heap_segment* seg); #ifndef USE_REGIONS @@ -2421,6 +2419,7 @@ class gc_heap #ifndef USE_REGIONS PER_HEAP_METHOD void rearrange_heap_segments(BOOL compacting); #endif //!USE_REGIONS + PER_HEAP_METHOD void delay_free_segments(); PER_HEAP_ISOLATED_METHOD void distribute_free_regions(); #ifdef BACKGROUND_GC PER_HEAP_ISOLATED_METHOD void reset_write_watch_for_gc_heap(void* base_address, size_t region_size); @@ -2596,11 +2595,17 @@ class gc_heap // re-initialize a heap in preparation to putting it back into service PER_HEAP_METHOD void recommission_heap(); + PER_HEAP_ISOLATED_METHOD size_t get_num_completed_gcs(); + + PER_HEAP_ISOLATED_METHOD int calculate_new_heap_count(); + // check if we should change the heap count PER_HEAP_METHOD void check_heap_count(); - PER_HEAP_METHOD bool prepare_to_change_heap_count (int new_n_heaps); + PER_HEAP_ISOLATED_METHOD bool prepare_to_change_heap_count (int new_n_heaps); PER_HEAP_METHOD bool change_heap_count (int new_n_heaps); + + PER_HEAP_ISOLATED_METHOD size_t get_msl_wait_time(); #endif //DYNAMIC_HEAP_COUNT #endif //USE_REGIONS @@ -3777,6 +3782,13 @@ class gc_heap PER_HEAP_FIELD_MAINTAINED mark* loh_pinned_queue; #endif //FEATURE_LOH_COMPACTION +#ifdef DYNAMIC_HEAP_COUNT + PER_HEAP_FIELD_MAINTAINED GCEvent gc_idle_thread_event; +#ifdef BACKGROUND_GC + PER_HEAP_FIELD_MAINTAINED GCEvent bgc_idle_thread_event; +#endif //BACKGROUND_GC +#endif //DYNAMIC_HEAP_COUNT + /******************************************/ // PER_HEAP_FIELD_MAINTAINED_ALLOC fields // /******************************************/ @@ -4083,7 +4095,6 @@ class gc_heap // These 2 fields' values do not change but are set/unset per GC PER_HEAP_ISOLATED_FIELD_SINGLE_GC GCEvent gc_start_event; PER_HEAP_ISOLATED_FIELD_SINGLE_GC GCEvent ee_suspend_event; - PER_HEAP_ISOLATED_FIELD_SINGLE_GC GCEvent gc_idle_thread_event; // Also updated on the heap#0 GC thread because that's where we are actually doing the decommit. PER_HEAP_ISOLATED_FIELD_SINGLE_GC BOOL gradual_decommit_in_progress_p; @@ -4162,6 +4173,10 @@ class gc_heap PER_HEAP_ISOLATED_FIELD_SINGLE_GC uint8_t* gc_high; // high end of the highest region being condemned #endif //USE_REGIONS +#ifdef STRESS_DYNAMIC_HEAP_COUNT + PER_HEAP_ISOLATED_FIELD_SINGLE_GC int heaps_in_this_gc; +#endif //STRESS_DYNAMIC_HEAP_COUNT + /**************************************************/ // PER_HEAP_ISOLATED_FIELD_SINGLE_GC_ALLOC fields // /**************************************************/ @@ -4260,37 +4275,65 @@ class gc_heap #endif //USE_REGIONS #ifdef DYNAMIC_HEAP_COUNT + // Sample collection - + // + // For every GC, we collect the msl wait time + GC pause duration info and use both to calculate the + // throughput cost percentage. We will also be using the wait time and the GC pause duration separately + // for other purposes in the future. + // + // For all gen2 GCs we also keep a separate array currently just for the GC cost. This serves as a backstop + // to smooth out the situation when we rarely pick the gen2 GCs in the first array. struct dynamic_heap_count_data_t { static const int sample_size = 3; struct sample { - uint64_t elapsed_between_gcs; // time between gcs in microseconds - uint64_t gc_elapsed_time; // time the gc took - uint64_t soh_msl_wait_time; // time the allocator spent waiting for the soh msl lock - uint64_t uoh_msl_wait_time; // time the allocator spent waiting for the uoh msl lock - size_t allocating_thread_count;// number of allocating threads - size_t heap_size; + uint64_t elapsed_between_gcs; // time between gcs in microseconds (this should really be between_pauses) + uint64_t gc_pause_time; // pause time for this GC + uint64_t msl_wait_time; }; - unsigned sample_index; + uint32_t sample_index; sample samples[sample_size]; + size_t prev_num_completed_gcs; + + uint32_t gen2_sample_index; + // This is (gc_elapsed_time / time inbetween this and the last gen2 GC) + float gen2_gc_percents[sample_size]; - float median_percent_overhead; // estimated overhead of allocator + gc - float smoothed_median_percent_overhead; // exponentially smoothed version - float percent_heap_space_cost_per_heap; // percent space cost of adding a heap - float overhead_reduction_per_step_up; // percentage effect on overhead of increasing heap count - float overhead_increase_per_step_down; // percentage effect on overhead of decreasing heap count - float space_cost_increase_per_step_up; // percentage effect on space of increasing heap count - float space_cost_decrease_per_step_down;// percentage effect on space of decreasing heap count + float median_throughput_cost_percent; // estimated overhead of allocator + gc + float smoothed_median_throughput_cost_percent; // exponentially smoothed version + float percent_heap_space_cost_per_heap; // percent space cost of adding a heap + float tcp_reduction_per_step_up; // throughput cost percent effect of increasing heap count + float tcp_increase_per_step_down; // throughput cost percent effect of decreasing heap count + float scp_increase_per_step_up; // space cost percent effect of increasing heap count + float scp_decrease_per_step_down; // space cost percent effect of decreasing heap count int new_n_heaps; + // the heap count we changed from + int last_n_heaps; + // don't start a GC till we see (n_max_heaps - new_n_heaps) number of threads idling + VOLATILE(int32_t) idle_thread_count; + bool init_only_p; + + bool should_change_heap_count; + int heap_count_to_change_to; + int heap_count_change_count; #ifdef STRESS_DYNAMIC_HEAP_COUNT int lowest_heap_with_msl_uoh; #endif //STRESS_DYNAMIC_HEAP_COUNT + + float get_median_gen2_gc_percent() + { + return median_of_3 (gen2_gc_percents[0], gen2_gc_percents[1], gen2_gc_percents[2]); + } }; PER_HEAP_ISOLATED_FIELD_MAINTAINED dynamic_heap_count_data_t dynamic_heap_count_data; + PER_HEAP_ISOLATED_FIELD_MAINTAINED uint64_t last_suspended_end_time; + // If the last full GC is blocking, this is that GC's index; for BGC, this is the settings.gc_index + // when the BGC ended. + PER_HEAP_ISOLATED_FIELD_MAINTAINED size_t gc_index_full_gc_end; #endif //DYNAMIC_HEAP_COUNT /****************************************************/ @@ -4866,7 +4909,6 @@ uint64_t& dd_previous_time_clock (dynamic_data* inst) return inst->previous_time_clock; } - inline size_t& dd_gc_clock_interval (dynamic_data* inst) { diff --git a/src/coreclr/gc/unix/gcenv.unix.cpp b/src/coreclr/gc/unix/gcenv.unix.cpp index 285b783485802a..b45cd40d8073fe 100644 --- a/src/coreclr/gc/unix/gcenv.unix.cpp +++ b/src/coreclr/gc/unix/gcenv.unix.cpp @@ -168,6 +168,17 @@ enum membarrier_cmd bool CanFlushUsingMembarrier() { + +#ifdef TARGET_ANDROID + // Avoid calling membarrier on older Android versions where membarrier + // may be barred by seccomp causing the process to be killed. + int apiLevel = android_get_device_api_level(); + if (apiLevel < __ANDROID_API_Q__) + { + return false; + } +#endif + // Starting with Linux kernel 4.14, process memory barriers can be generated // using MEMBARRIER_CMD_PRIVATE_EXPEDITED. diff --git a/src/coreclr/ilasm/assem.cpp b/src/coreclr/ilasm/assem.cpp index dd2c91ac093acb..2bd90fadb8f916 100644 --- a/src/coreclr/ilasm/assem.cpp +++ b/src/coreclr/ilasm/assem.cpp @@ -1324,7 +1324,12 @@ OPCODE Assembler::DecodeOpcode(const BYTE *pCode, DWORD *pdwLen) char* Assembler::ReflectionNotation(mdToken tk) { + // We break the global static `wzUniBuf` into 2 equal parts: the first part is used for a Unicode + // string, the second part is used for a converted-into-multibyte (MB) string. Note that the MB string + // length is in bytes. char *sz = (char*)&wzUniBuf[dwUniBuf>>1], *pc; + const size_t szSizeBytes = (dwUniBuf * sizeof(WCHAR)) / 2; // sizeof(WCHAR) is 2, so this is just `dwUniBuf` + const size_t cchUniBuf = dwUniBuf / 2; // only use the first 1/2 of wzUniBuf *sz=0; switch(TypeFromToken(tk)) { @@ -1333,7 +1338,7 @@ char* Assembler::ReflectionNotation(mdToken tk) Class *pClass = m_lstClass.PEEK(RidFromToken(tk)-1); if(pClass) { - strcpy_s(sz,dwUniBuf>>1,pClass->m_szFQN); + strcpy_s(sz,szSizeBytes,pClass->m_szFQN); pc = sz; while((pc = strchr(pc,NESTING_SEP)) != NULL) { @@ -1348,31 +1353,80 @@ char* Assembler::ReflectionNotation(mdToken tk) { ULONG N; mdToken tkResScope; - if(SUCCEEDED(m_pImporter->GetTypeRefProps(tk,&tkResScope,wzUniBuf,dwUniBuf>>1,&N))) + if(SUCCEEDED(m_pImporter->GetTypeRefProps(tk,&tkResScope,wzUniBuf,cchUniBuf,&N))) { - WszWideCharToMultiByte(CP_UTF8,0,wzUniBuf,-1,sz,dwUniBuf>>1,NULL,NULL); + int ret = WszWideCharToMultiByte(CP_UTF8,0,wzUniBuf,-1,sz,szSizeBytes,NULL,NULL); if(TypeFromToken(tkResScope)==mdtAssemblyRef) { AsmManAssembly *pAsmRef = m_pManifest->m_AsmRefLst.PEEK(RidFromToken(tkResScope)-1); if(pAsmRef) { - pc = &sz[strlen(sz)]; - pc+=sprintf_s(pc,(dwUniBuf >> 1),", %s, Version=%d.%d.%d.%d, Culture=",pAsmRef->szName, + // We assume below that if sprintf_s fails due to buffer overrun, + // execution fails fast and sprintf_s doesn't return. + int sprintf_ret; + const size_t szLen = strlen(sz); + pc = &sz[szLen]; + size_t szRemainingSizeBytes = szSizeBytes - szLen; + + sprintf_ret = sprintf_s(pc,szRemainingSizeBytes,", %s, Version=%d.%d.%d.%d, Culture=",pAsmRef->szName, pAsmRef->usVerMajor,pAsmRef->usVerMinor,pAsmRef->usBuild,pAsmRef->usRevision); - ULONG L=0; - if(pAsmRef->pLocale && (L=pAsmRef->pLocale->length())) + pc += sprintf_ret; + szRemainingSizeBytes -= (size_t)sprintf_ret; + + unsigned L=0; + if(pAsmRef->pLocale && ((L=pAsmRef->pLocale->length()) > 0)) + { + // L is in bytes and doesn't include the terminating null. + if (L > (cchUniBuf - 1) * sizeof(WCHAR)) + { + report->error("Locale too long (%d characters, %d allowed).\n",L / sizeof(WCHAR), cchUniBuf - 1); + *sz=0; + break; + } + else if (szRemainingSizeBytes == 0) + { + report->error("TypeRef too long.\n"); + *sz=0; + break; + } + + if (szRemainingSizeBytes > 0) + { + memcpy(wzUniBuf,pAsmRef->pLocale->ptr(),L); + wzUniBuf[L>>1] = 0; + ret = WszWideCharToMultiByte(CP_UTF8,0,wzUniBuf,-1,pc,(int)szRemainingSizeBytes,NULL,NULL); + if (ret <= 0) + { + report->error("Locale too long.\n"); + *sz=0; + break; + } + else + { + pc += ret; + szRemainingSizeBytes -= (size_t)ret; + } + } + } + else { - memcpy(wzUniBuf,pAsmRef->pLocale->ptr(),L); - wzUniBuf[L>>1] = 0; - WszWideCharToMultiByte(CP_UTF8,0,wzUniBuf,-1,pc,dwUniBuf>>1,NULL,NULL); + sprintf_ret = sprintf_s(pc,szRemainingSizeBytes,"neutral"); + pc += sprintf_ret; + szRemainingSizeBytes -= (size_t)sprintf_ret; } - else pc+=sprintf_s(pc,(dwUniBuf >> 1),"neutral"); - pc = &sz[strlen(sz)]; - if(pAsmRef->pPublicKeyToken && (L=pAsmRef->pPublicKeyToken->length())) + if(pAsmRef->pPublicKeyToken && ((L=pAsmRef->pPublicKeyToken->length()) > 0)) { - pc+=sprintf_s(pc,(dwUniBuf >> 1),", Publickeytoken="); + sprintf_ret = sprintf_s(pc,szRemainingSizeBytes,", Publickeytoken="); + pc += sprintf_ret; + szRemainingSizeBytes -= (size_t)sprintf_ret; + BYTE* pb = (BYTE*)(pAsmRef->pPublicKeyToken->ptr()); - for(N=0; N> 1),"%2.2x",*pb); + for(unsigned i=0; i class ClrSafeInt INDEBUG( mutable bool m_checkedOverflow; ) }; +#if defined(_MSC_VER) && defined(HOST_ARM64) // Workaround for https://github.com/dotnet/runtime/issues/93442 +#pragma optimize("", off) +#endif + template <> inline bool ClrSafeInt::multiply(int64_t lhs, int64_t rhs, int64_t &result) { @@ -874,6 +878,10 @@ inline bool ClrSafeInt::multiply(uint8_t lhs, uint8_t rhs, uint8_t &res return true; } +#if defined(_MSC_VER) && defined(HOST_ARM64) // Workaround for https://github.com/dotnet/runtime/issues/93442 +#pragma optimize("", on) +#endif + // Allows creation of a ClrSafeInt corresponding to the type of the argument. template ClrSafeInt AsClrSafeInt(T t) diff --git a/src/coreclr/inc/sstring.h b/src/coreclr/inc/sstring.h index 00b826b23c3c54..6557ca0d43edb1 100644 --- a/src/coreclr/inc/sstring.h +++ b/src/coreclr/inc/sstring.h @@ -683,7 +683,9 @@ class EMPTY_BASES_DECL SString : private SBuffer BOOL IsASCIIScanned() const; void SetASCIIScanned() const; void SetNormalized() const; +public: BOOL IsNormalized() const; +private: void ClearNormalized() const; void EnsureWritable() const; diff --git a/src/coreclr/jit/codegenarm64.cpp b/src/coreclr/jit/codegenarm64.cpp index f99602d21d953b..d937cd67747e25 100644 --- a/src/coreclr/jit/codegenarm64.cpp +++ b/src/coreclr/jit/codegenarm64.cpp @@ -327,19 +327,38 @@ bool CodeGen::genInstrWithConstant(instruction ins, break; case INS_strb: + assert(size == EA_1BYTE); + immFitsInIns = emitter::emitIns_valid_imm_for_ldst_offset(imm, EA_1BYTE); + break; + case INS_strh: + assert(size == EA_2BYTE); + immFitsInIns = emitter::emitIns_valid_imm_for_ldst_offset(imm, EA_2BYTE); + break; + case INS_str: // reg1 is a source register for store instructions assert(tmpReg != reg1); // regTmp can not match any source register immFitsInIns = emitter::emitIns_valid_imm_for_ldst_offset(imm, size); break; + case INS_ldrb: case INS_ldrsb: + immFitsInIns = emitter::emitIns_valid_imm_for_ldst_offset(imm, EA_1BYTE); + break; + + case INS_ldrh: case INS_ldrsh: + immFitsInIns = emitter::emitIns_valid_imm_for_ldst_offset(imm, EA_2BYTE); + break; + case INS_ldrsw: - case INS_ldrb: - case INS_ldrh: + assert(size == EA_4BYTE); + immFitsInIns = emitter::emitIns_valid_imm_for_ldst_offset(imm, EA_4BYTE); + break; + case INS_ldr: + assert((size == EA_4BYTE) || (size == EA_8BYTE) || (size == EA_16BYTE)); immFitsInIns = emitter::emitIns_valid_imm_for_ldst_offset(imm, size); break; diff --git a/src/coreclr/jit/codegencommon.cpp b/src/coreclr/jit/codegencommon.cpp index 07e03d3e0c2334..72e31c388154ab 100644 --- a/src/coreclr/jit/codegencommon.cpp +++ b/src/coreclr/jit/codegencommon.cpp @@ -2011,7 +2011,7 @@ void CodeGen::genEmitMachineCode() #if TRACK_LSRA_STATS if (JitConfig.DisplayLsraStats() == 3) { - compiler->m_pLinearScan->dumpLsraStatsSummary(jitstdout); + compiler->m_pLinearScan->dumpLsraStatsSummary(jitstdout()); } #endif // TRACK_LSRA_STATS @@ -2104,7 +2104,7 @@ void CodeGen::genEmitUnwindDebugGCandEH() genCreateAndStoreGCInfo(codeSize, prologSize, epilogSize DEBUGARG(codePtr)); #ifdef DEBUG - FILE* dmpf = jitstdout; + FILE* dmpf = jitstdout(); compiler->opts.dmpHex = false; if (!strcmp(compiler->info.compMethodName, "= 0.5) { - fprintf(fout, " GT_%-17s %7u (%4.1lf%%) %3u bytes each\n", GenTree::OpName(oper.Oper), count, - percentage, size); + jitprintf(" GT_%-17s %7u (%4.1lf%%) %3u bytes each\n", GenTree::OpName(oper.Oper), count, + percentage, size); remainingCount -= count; } else @@ -1484,14 +1482,14 @@ void Compiler::compShutdown() if (remainingCount > 0) { - fprintf(fout, " All other GT_xxx ... %7u (%4.1lf%%) ... %4.1lf%% small + %4.1lf%% large\n", - remainingCount, 100.0 * remainingCount / totalCount, 100.0 * remainingCountSmall / totalCount, - 100.0 * remainingCountLarge / totalCount); + jitprintf(" All other GT_xxx ... %7u (%4.1lf%%) ... %4.1lf%% small + %4.1lf%% large\n", remainingCount, + 100.0 * remainingCount / totalCount, 100.0 * remainingCountSmall / totalCount, + 100.0 * remainingCountLarge / totalCount); } - fprintf(fout, " -----------------------------------------------------\n"); - fprintf(fout, " Total ....... %11u --ALL-- ... %4.1lf%% small + %4.1lf%% large\n", totalCount, - 100.0 * countSmall / totalCount, 100.0 * countLarge / totalCount); - fprintf(fout, "\n"); + jitprintf(" -----------------------------------------------------\n"); + jitprintf(" Total ....... %11u --ALL-- ... %4.1lf%% small + %4.1lf%% large\n", totalCount, + 100.0 * countSmall / totalCount, 100.0 * countLarge / totalCount); + jitprintf("\n"); } #endif // COUNT_AST_OPERS @@ -1500,49 +1498,49 @@ void Compiler::compShutdown() if (grossVMsize && grossNCsize) { - fprintf(fout, "\n"); - fprintf(fout, "--------------------------------------\n"); - fprintf(fout, "Function and GC info size stats\n"); - fprintf(fout, "--------------------------------------\n"); + jitprintf("\n"); + jitprintf("--------------------------------------\n"); + jitprintf("Function and GC info size stats\n"); + jitprintf("--------------------------------------\n"); - fprintf(fout, "[%7u VM, %8u %6s %4u%%] %s\n", grossVMsize, grossNCsize, Target::g_tgtCPUName, - 100 * grossNCsize / grossVMsize, "Total (excluding GC info)"); + jitprintf("[%7u VM, %8u %6s %4u%%] %s\n", grossVMsize, grossNCsize, Target::g_tgtCPUName, + 100 * grossNCsize / grossVMsize, "Total (excluding GC info)"); - fprintf(fout, "[%7u VM, %8u %6s %4u%%] %s\n", grossVMsize, totalNCsize, Target::g_tgtCPUName, - 100 * totalNCsize / grossVMsize, "Total (including GC info)"); + jitprintf("[%7u VM, %8u %6s %4u%%] %s\n", grossVMsize, totalNCsize, Target::g_tgtCPUName, + 100 * totalNCsize / grossVMsize, "Total (including GC info)"); if (gcHeaderISize || gcHeaderNSize) { - fprintf(fout, "\n"); + jitprintf("\n"); - fprintf(fout, "GC tables : [%7uI,%7uN] %7u byt (%u%% of IL, %u%% of %s).\n", - gcHeaderISize + gcPtrMapISize, gcHeaderNSize + gcPtrMapNSize, totalNCsize - grossNCsize, - 100 * (totalNCsize - grossNCsize) / grossVMsize, 100 * (totalNCsize - grossNCsize) / grossNCsize, - Target::g_tgtCPUName); + jitprintf("GC tables : [%7uI,%7uN] %7u byt (%u%% of IL, %u%% of %s).\n", gcHeaderISize + gcPtrMapISize, + gcHeaderNSize + gcPtrMapNSize, totalNCsize - grossNCsize, + 100 * (totalNCsize - grossNCsize) / grossVMsize, 100 * (totalNCsize - grossNCsize) / grossNCsize, + Target::g_tgtCPUName); - fprintf(fout, "GC headers : [%7uI,%7uN] %7u byt, [%4.1fI,%4.1fN] %4.1f byt/meth\n", gcHeaderISize, - gcHeaderNSize, gcHeaderISize + gcHeaderNSize, (float)gcHeaderISize / (genMethodICnt + 0.001), - (float)gcHeaderNSize / (genMethodNCnt + 0.001), - (float)(gcHeaderISize + gcHeaderNSize) / genMethodCnt); + jitprintf("GC headers : [%7uI,%7uN] %7u byt, [%4.1fI,%4.1fN] %4.1f byt/meth\n", gcHeaderISize, + gcHeaderNSize, gcHeaderISize + gcHeaderNSize, (float)gcHeaderISize / (genMethodICnt + 0.001), + (float)gcHeaderNSize / (genMethodNCnt + 0.001), + (float)(gcHeaderISize + gcHeaderNSize) / genMethodCnt); - fprintf(fout, "GC ptr maps : [%7uI,%7uN] %7u byt, [%4.1fI,%4.1fN] %4.1f byt/meth\n", gcPtrMapISize, - gcPtrMapNSize, gcPtrMapISize + gcPtrMapNSize, (float)gcPtrMapISize / (genMethodICnt + 0.001), - (float)gcPtrMapNSize / (genMethodNCnt + 0.001), - (float)(gcPtrMapISize + gcPtrMapNSize) / genMethodCnt); + jitprintf("GC ptr maps : [%7uI,%7uN] %7u byt, [%4.1fI,%4.1fN] %4.1f byt/meth\n", gcPtrMapISize, + gcPtrMapNSize, gcPtrMapISize + gcPtrMapNSize, (float)gcPtrMapISize / (genMethodICnt + 0.001), + (float)gcPtrMapNSize / (genMethodNCnt + 0.001), + (float)(gcPtrMapISize + gcPtrMapNSize) / genMethodCnt); } else { - fprintf(fout, "\n"); + jitprintf("\n"); - fprintf(fout, "GC tables take up %u bytes (%u%% of instr, %u%% of %6s code).\n", - totalNCsize - grossNCsize, 100 * (totalNCsize - grossNCsize) / grossVMsize, - 100 * (totalNCsize - grossNCsize) / grossNCsize, Target::g_tgtCPUName); + jitprintf("GC tables take up %u bytes (%u%% of instr, %u%% of %6s code).\n", totalNCsize - grossNCsize, + 100 * (totalNCsize - grossNCsize) / grossVMsize, 100 * (totalNCsize - grossNCsize) / grossNCsize, + Target::g_tgtCPUName); } #ifdef DEBUG #if DOUBLE_ALIGN - fprintf(fout, "%u out of %u methods generated with double-aligned stack\n", - Compiler::s_lvaDoubleAlignedProcsCount, genMethodCnt); + jitprintf("%u out of %u methods generated with double-aligned stack\n", Compiler::s_lvaDoubleAlignedProcsCount, + genMethodCnt); #endif #endif } @@ -1550,110 +1548,110 @@ void Compiler::compShutdown() #endif // DISPLAY_SIZES #if CALL_ARG_STATS - compDispCallArgStats(fout); + compDispCallArgStats(jitstdout()); #endif #if COUNT_BASIC_BLOCKS - fprintf(fout, "--------------------------------------------------\n"); - fprintf(fout, "Basic block count frequency table:\n"); - fprintf(fout, "--------------------------------------------------\n"); - bbCntTable.dump(fout); - fprintf(fout, "--------------------------------------------------\n"); - - fprintf(fout, "\n"); - - fprintf(fout, "--------------------------------------------------\n"); - fprintf(fout, "IL method size frequency table for methods with a single basic block:\n"); - fprintf(fout, "--------------------------------------------------\n"); - bbOneBBSizeTable.dump(fout); - fprintf(fout, "--------------------------------------------------\n"); - - fprintf(fout, "--------------------------------------------------\n"); - fprintf(fout, "fgComputeDoms `while (change)` iterations:\n"); - fprintf(fout, "--------------------------------------------------\n"); - domsChangedIterationTable.dump(fout); - fprintf(fout, "--------------------------------------------------\n"); - - fprintf(fout, "--------------------------------------------------\n"); - fprintf(fout, "fgComputeReachabilitySets `while (change)` iterations:\n"); - fprintf(fout, "--------------------------------------------------\n"); - computeReachabilitySetsIterationTable.dump(fout); - fprintf(fout, "--------------------------------------------------\n"); - - fprintf(fout, "--------------------------------------------------\n"); - fprintf(fout, "fgComputeReachability `while (change)` iterations:\n"); - fprintf(fout, "--------------------------------------------------\n"); - computeReachabilityIterationTable.dump(fout); - fprintf(fout, "--------------------------------------------------\n"); + jitprintf("--------------------------------------------------\n"); + jitprintf("Basic block count frequency table:\n"); + jitprintf("--------------------------------------------------\n"); + bbCntTable.dump(jitstdout()); + jitprintf("--------------------------------------------------\n"); + + jitprintf("\n"); + + jitprintf("--------------------------------------------------\n"); + jitprintf("IL method size frequency table for methods with a single basic block:\n"); + jitprintf("--------------------------------------------------\n"); + bbOneBBSizeTable.dump(jitstdout()); + jitprintf("--------------------------------------------------\n"); + + jitprintf("--------------------------------------------------\n"); + jitprintf("fgComputeDoms `while (change)` iterations:\n"); + jitprintf("--------------------------------------------------\n"); + domsChangedIterationTable.dump(jitstdout()); + jitprintf("--------------------------------------------------\n"); + + jitprintf("--------------------------------------------------\n"); + jitprintf("fgComputeReachabilitySets `while (change)` iterations:\n"); + jitprintf("--------------------------------------------------\n"); + computeReachabilitySetsIterationTable.dump(jitstdout()); + jitprintf("--------------------------------------------------\n"); + + jitprintf("--------------------------------------------------\n"); + jitprintf("fgComputeReachability `while (change)` iterations:\n"); + jitprintf("--------------------------------------------------\n"); + computeReachabilityIterationTable.dump(jitstdout()); + jitprintf("--------------------------------------------------\n"); #endif // COUNT_BASIC_BLOCKS #if COUNT_LOOPS - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "Loop stats\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "Total number of methods with loops is %5u\n", totalLoopMethods); - fprintf(fout, "Total number of loops is %5u\n", totalLoopCount); - fprintf(fout, "Maximum number of loops per method is %5u\n", maxLoopsPerMethod); - fprintf(fout, "# of methods overflowing nat loop table is %5u\n", totalLoopOverflows); - fprintf(fout, "Total number of 'unnatural' loops is %5u\n", totalUnnatLoopCount); - fprintf(fout, "# of methods overflowing unnat loop limit is %5u\n", totalUnnatLoopOverflows); - fprintf(fout, "Total number of loops with an iterator is %5u\n", iterLoopCount); - fprintf(fout, "Total number of loops with a constant iterator is %5u\n", constIterLoopCount); - - fprintf(fout, "--------------------------------------------------\n"); - fprintf(fout, "Loop count frequency table:\n"); - fprintf(fout, "--------------------------------------------------\n"); - loopCountTable.dump(fout); - fprintf(fout, "--------------------------------------------------\n"); - fprintf(fout, "Loop exit count frequency table:\n"); - fprintf(fout, "--------------------------------------------------\n"); - loopExitCountTable.dump(fout); - fprintf(fout, "--------------------------------------------------\n"); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("Loop stats\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("Total number of methods with loops is %5u\n", totalLoopMethods); + jitprintf("Total number of loops is %5u\n", totalLoopCount); + jitprintf("Maximum number of loops per method is %5u\n", maxLoopsPerMethod); + jitprintf("# of methods overflowing nat loop table is %5u\n", totalLoopOverflows); + jitprintf("Total number of 'unnatural' loops is %5u\n", totalUnnatLoopCount); + jitprintf("# of methods overflowing unnat loop limit is %5u\n", totalUnnatLoopOverflows); + jitprintf("Total number of loops with an iterator is %5u\n", iterLoopCount); + jitprintf("Total number of loops with a constant iterator is %5u\n", constIterLoopCount); + + jitprintf("--------------------------------------------------\n"); + jitprintf("Loop count frequency table:\n"); + jitprintf("--------------------------------------------------\n"); + loopCountTable.dump(jitstdout()); + jitprintf("--------------------------------------------------\n"); + jitprintf("Loop exit count frequency table:\n"); + jitprintf("--------------------------------------------------\n"); + loopExitCountTable.dump(jitstdout()); + jitprintf("--------------------------------------------------\n"); #endif // COUNT_LOOPS #if MEASURE_NODE_SIZE - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "GenTree node allocation stats\n"); - fprintf(fout, "---------------------------------------------------\n"); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("GenTree node allocation stats\n"); + jitprintf("---------------------------------------------------\n"); - fprintf(fout, "Allocated %6I64u tree nodes (%7I64u bytes total, avg %4I64u bytes per method)\n", - genNodeSizeStats.genTreeNodeCnt, genNodeSizeStats.genTreeNodeSize, - genNodeSizeStats.genTreeNodeSize / genMethodCnt); + jitprintf("Allocated %6I64u tree nodes (%7I64u bytes total, avg %4I64u bytes per method)\n", + genNodeSizeStats.genTreeNodeCnt, genNodeSizeStats.genTreeNodeSize, + genNodeSizeStats.genTreeNodeSize / genMethodCnt); - fprintf(fout, "Allocated %7I64u bytes of unused tree node space (%3.2f%%)\n", - genNodeSizeStats.genTreeNodeSize - genNodeSizeStats.genTreeNodeActualSize, - (float)(100 * (genNodeSizeStats.genTreeNodeSize - genNodeSizeStats.genTreeNodeActualSize)) / - genNodeSizeStats.genTreeNodeSize); + jitprintf("Allocated %7I64u bytes of unused tree node space (%3.2f%%)\n", + genNodeSizeStats.genTreeNodeSize - genNodeSizeStats.genTreeNodeActualSize, + (float)(100 * (genNodeSizeStats.genTreeNodeSize - genNodeSizeStats.genTreeNodeActualSize)) / + genNodeSizeStats.genTreeNodeSize); - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "Distribution of per-method GenTree node counts:\n"); - genTreeNcntHist.dump(fout); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("Distribution of per-method GenTree node counts:\n"); + genTreeNcntHist.dump(jitstdout()); - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "Distribution of per-method GenTree node allocations (in bytes):\n"); - genTreeNsizHist.dump(fout); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("Distribution of per-method GenTree node allocations (in bytes):\n"); + genTreeNsizHist.dump(jitstdout()); #endif // MEASURE_NODE_SIZE #if MEASURE_BLOCK_SIZE - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "BasicBlock and FlowEdge/BasicBlockList allocation stats\n"); - fprintf(fout, "---------------------------------------------------\n"); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("BasicBlock and FlowEdge/BasicBlockList allocation stats\n"); + jitprintf("---------------------------------------------------\n"); - fprintf(fout, "Allocated %6u basic blocks (%7u bytes total, avg %4u bytes per method)\n", BasicBlock::s_Count, - BasicBlock::s_Size, BasicBlock::s_Size / genMethodCnt); - fprintf(fout, "Allocated %6u flow nodes (%7u bytes total, avg %4u bytes per method)\n", genFlowNodeCnt, - genFlowNodeSize, genFlowNodeSize / genMethodCnt); + jitprintf("Allocated %6u basic blocks (%7u bytes total, avg %4u bytes per method)\n", BasicBlock::s_Count, + BasicBlock::s_Size, BasicBlock::s_Size / genMethodCnt); + jitprintf("Allocated %6u flow nodes (%7u bytes total, avg %4u bytes per method)\n", genFlowNodeCnt, genFlowNodeSize, + genFlowNodeSize / genMethodCnt); #endif // MEASURE_BLOCK_SIZE @@ -1661,21 +1659,21 @@ void Compiler::compShutdown() if (s_dspMemStats) { - fprintf(fout, "\nAll allocations:\n"); - ArenaAllocator::dumpAggregateMemStats(jitstdout); + jitprintf("\nAll allocations:\n"); + ArenaAllocator::dumpAggregateMemStats(jitstdout()); - fprintf(fout, "\nLargest method:\n"); - ArenaAllocator::dumpMaxMemStats(jitstdout); + jitprintf("\nLargest method:\n"); + ArenaAllocator::dumpMaxMemStats(jitstdout()); - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "Distribution of total memory allocated per method (in KB):\n"); - memAllocHist.dump(fout); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("Distribution of total memory allocated per method (in KB):\n"); + memAllocHist.dump(jitstdout()); - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "Distribution of total memory used per method (in KB):\n"); - memUsedHist.dump(fout); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("Distribution of total memory used per method (in KB):\n"); + memUsedHist.dump(jitstdout()); } #endif // MEASURE_MEM_ALLOC @@ -1685,29 +1683,29 @@ void Compiler::compShutdown() if (JitConfig.DisplayLoopHoistStats() != 0) #endif // DEBUG { - PrintAggregateLoopHoistStats(jitstdout); + PrintAggregateLoopHoistStats(jitstdout()); } #endif // LOOP_HOIST_STATS #if TRACK_ENREG_STATS if (JitConfig.JitEnregStats() != 0) { - s_enregisterStats.Dump(fout); + s_enregisterStats.Dump(jitstdout()); } #endif // TRACK_ENREG_STATS #if MEASURE_PTRTAB_SIZE - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "GC pointer table stats\n"); - fprintf(fout, "---------------------------------------------------\n"); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("GC pointer table stats\n"); + jitprintf("---------------------------------------------------\n"); - fprintf(fout, "Reg pointer descriptor size (internal): %8u (avg %4u per method)\n", GCInfo::s_gcRegPtrDscSize, - GCInfo::s_gcRegPtrDscSize / genMethodCnt); + jitprintf("Reg pointer descriptor size (internal): %8u (avg %4u per method)\n", GCInfo::s_gcRegPtrDscSize, + GCInfo::s_gcRegPtrDscSize / genMethodCnt); - fprintf(fout, "Total pointer table size: %8u (avg %4u per method)\n", GCInfo::s_gcTotalPtrTabSize, - GCInfo::s_gcTotalPtrTabSize / genMethodCnt); + jitprintf("Total pointer table size: %8u (avg %4u per method)\n", GCInfo::s_gcTotalPtrTabSize, + GCInfo::s_gcTotalPtrTabSize / genMethodCnt); #endif // MEASURE_PTRTAB_SIZE @@ -1715,37 +1713,37 @@ void Compiler::compShutdown() if (genMethodCnt != 0) { - fprintf(fout, "\n"); - fprintf(fout, "A total of %6u methods compiled", genMethodCnt); + jitprintf("\n"); + jitprintf("A total of %6u methods compiled", genMethodCnt); #if DISPLAY_SIZES if (genMethodICnt || genMethodNCnt) { - fprintf(fout, " (%u interruptible, %u non-interruptible)", genMethodICnt, genMethodNCnt); + jitprintf(" (%u interruptible, %u non-interruptible)", genMethodICnt, genMethodNCnt); } #endif // DISPLAY_SIZES - fprintf(fout, ".\n"); + jitprintf(".\n"); } #endif // MEASURE_NODE_SIZE || MEASURE_BLOCK_SIZE || MEASURE_PTRTAB_SIZE || DISPLAY_SIZES #if EMITTER_STATS - emitterStats(fout); + emitterStats(jitstdout()); #endif #if MEASURE_FATAL - fprintf(fout, "\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, "Fatal errors stats\n"); - fprintf(fout, "---------------------------------------------------\n"); - fprintf(fout, " badCode: %u\n", fatal_badCode); - fprintf(fout, " noWay: %u\n", fatal_noWay); - fprintf(fout, " implLimitation: %u\n", fatal_implLimitation); - fprintf(fout, " NOMEM: %u\n", fatal_NOMEM); - fprintf(fout, " noWayAssertBody: %u\n", fatal_noWayAssertBody); + jitprintf("\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf("Fatal errors stats\n"); + jitprintf("---------------------------------------------------\n"); + jitprintf(" badCode: %u\n", fatal_badCode); + jitprintf(" noWay: %u\n", fatal_noWay); + jitprintf(" implLimitation: %u\n", fatal_implLimitation); + jitprintf(" NOMEM: %u\n", fatal_NOMEM); + jitprintf(" noWayAssertBody: %u\n", fatal_noWayAssertBody); #ifdef DEBUG - fprintf(fout, " noWayAssertBodyArgs: %u\n", fatal_noWayAssertBodyArgs); + jitprintf(" noWayAssertBodyArgs: %u\n", fatal_noWayAssertBodyArgs); #endif // DEBUG - fprintf(fout, " NYI: %u\n", fatal_NYI); + jitprintf(" NYI: %u\n", fatal_NYI); #endif // MEASURE_FATAL } @@ -1754,14 +1752,14 @@ void Compiler::compShutdown() */ /* static */ -void Compiler::compDisplayStaticSizes(FILE* fout) +void Compiler::compDisplayStaticSizes() { #if MEASURE_NODE_SIZE - GenTree::DumpNodeSizes(fout); + GenTree::DumpNodeSizes(); #endif #if EMITTER_STATS - emitterStaticStats(fout); + emitterStaticStats(); #endif } @@ -5177,7 +5175,7 @@ void Compiler::compCompile(void** methodCodePtr, uint32_t* methodCodeSize, JitFl #if TRACK_LSRA_STATS if (JitConfig.DisplayLsraStats() == 2) { - m_pLinearScan->dumpLsraStatsCsv(jitstdout); + m_pLinearScan->dumpLsraStatsCsv(jitstdout()); } #endif // TRACK_LSRA_STATS @@ -5282,6 +5280,13 @@ PhaseStatus Compiler::placeLoopAlignInstructions() weight_t minBlockSoFar = BB_MAX_WEIGHT; BasicBlock* bbHavingAlign = nullptr; BasicBlock::loopNumber currentAlignedLoopNum = BasicBlock::NOT_IN_LOOP; + bool visitedLoopNum[BasicBlock::MAX_LOOP_NUM]; + memset(visitedLoopNum, false, sizeof(visitedLoopNum)); + +#ifdef DEBUG + unsigned visitedBlockForLoopNum[BasicBlock::MAX_LOOP_NUM]; + memset(visitedBlockForLoopNum, 0, sizeof(visitedBlockForLoopNum)); +#endif if ((fgFirstBB != nullptr) && fgFirstBB->isLoopAlign()) { @@ -5304,7 +5309,7 @@ PhaseStatus Compiler::placeLoopAlignInstructions() } } - // If there is a unconditional jump (which is not part of callf/always pair) + // If there is an unconditional jump (which is not part of callf/always pair) if (opts.compJitHideAlignBehindJmp && (block->bbJumpKind == BBJ_ALWAYS) && !block->isBBCallAlwaysPairTail()) { // Track the lower weight blocks @@ -5358,12 +5363,19 @@ PhaseStatus Compiler::placeLoopAlignInstructions() madeChanges = true; unmarkedLoopAlign = true; } - else if ((block->bbNatLoopNum != BasicBlock::NOT_IN_LOOP) && (block->bbNatLoopNum == loopTop->bbNatLoopNum)) + else if ((loopTop->bbNatLoopNum != BasicBlock::NOT_IN_LOOP) && visitedLoopNum[loopTop->bbNatLoopNum]) { +#ifdef DEBUG + char buffer[100]; + sprintf_s(buffer, 100, "loop block " FMT_BB " appears before top of loop", + visitedBlockForLoopNum[loopTop->bbNatLoopNum]); +#endif + // In some odd cases we may see blocks within the loop before we see the // top block of the loop. Just bail on aligning such loops. // - loopTop->unmarkLoopAlign(this DEBUG_ARG("loop block appears before top of loop")); + + loopTop->unmarkLoopAlign(this DEBUG_ARG(buffer)); madeChanges = true; unmarkedLoopAlign = true; } @@ -5398,6 +5410,20 @@ PhaseStatus Compiler::placeLoopAlignInstructions() break; } } + + if (block->bbNatLoopNum != BasicBlock::NOT_IN_LOOP) + { +#ifdef DEBUG + if (!visitedLoopNum[block->bbNatLoopNum]) + { + // Record the first block for which bbNatLoopNum was seen for + // debugging purpose. + visitedBlockForLoopNum[block->bbNatLoopNum] = block->bbNum; + } +#endif + // If this block is part of loop, mark the loopNum as visited. + visitedLoopNum[block->bbNatLoopNum] = true; + } } assert(loopsToProcess == 0); @@ -5879,7 +5905,7 @@ int Compiler::compCompile(CORINFO_MODULE_HANDLE classPtr, } #endif // FUNC_INFO_LOGGING - // if (s_compMethodsCount==0) setvbuf(jitstdout, NULL, _IONBF, 0); + // if (s_compMethodsCount==0) setvbuf(jitstdout(), NULL, _IONBF, 0); if (compIsForInlining()) { @@ -6363,7 +6389,7 @@ void Compiler::compCompileFinish() if (s_dspMemStats || verbose) { printf("\nAllocations for %s (MethodHash=%08x)\n", info.compFullName, info.compMethodHash()); - compArenaAllocator->dumpMemStats(jitstdout); + compArenaAllocator->dumpMemStats(jitstdout()); } #endif // DEBUG #endif // MEASURE_MEM_ALLOC diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h index 645e54258a637a..8339d6d274f4f5 100644 --- a/src/coreclr/jit/compiler.h +++ b/src/coreclr/jit/compiler.h @@ -10345,7 +10345,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX InlineInfo* inlineInfo); void compDone(); - static void compDisplayStaticSizes(FILE* fout); + static void compDisplayStaticSizes(); //------------ Some utility functions -------------- diff --git a/src/coreclr/jit/disasm.cpp b/src/coreclr/jit/disasm.cpp index e1926a3f640b7b..fd5c98eb068810 100644 --- a/src/coreclr/jit/disasm.cpp +++ b/src/coreclr/jit/disasm.cpp @@ -1478,12 +1478,12 @@ void DisAssembler::disAsmCode(BYTE* hotCodePtr, size_t hotCodeSize, BYTE* coldCo } #else // !DEBUG // NOTE: non-DEBUG builds always use jitstdout currently! - disAsmFile = jitstdout; + disAsmFile = jitstdout(); #endif // !DEBUG if (disAsmFile == nullptr) { - disAsmFile = jitstdout; + disAsmFile = jitstdout(); } // As this writes to a common file, this is not reentrant. @@ -1519,7 +1519,7 @@ void DisAssembler::disAsmCode(BYTE* hotCodePtr, size_t hotCodeSize, BYTE* coldCo DisasmBuffer(disAsmFile, /* printIt */ true); fprintf(disAsmFile, "\n"); - if (disAsmFile != jitstdout) + if (disAsmFile != jitstdout()) { fclose(disAsmFile); } diff --git a/src/coreclr/jit/ee_il_dll.cpp b/src/coreclr/jit/ee_il_dll.cpp index bcf7c7be401c21..57c52855ea8393 100644 --- a/src/coreclr/jit/ee_il_dll.cpp +++ b/src/coreclr/jit/ee_il_dll.cpp @@ -31,8 +31,6 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX /*****************************************************************************/ -FILE* jitstdout = nullptr; - ICorJitHost* g_jitHost = nullptr; bool g_jitInitialized = false; @@ -72,15 +70,28 @@ extern "C" DLLEXPORT void jitStartup(ICorJitHost* jitHost) assert(!JitConfig.isInitialized()); JitConfig.initialize(jitHost); +#ifdef FEATURE_TRACELOGGING + JitTelemetry::NotifyDllProcessAttach(); +#endif + Compiler::compStartup(); + + g_jitInitialized = true; +} + +static FILE* volatile s_jitstdout; + +static FILE* jitstdoutInit() +{ const WCHAR* jitStdOutFile = JitConfig.JitStdOutFile(); + FILE* file = nullptr; if (jitStdOutFile != nullptr) { - jitstdout = _wfopen(jitStdOutFile, W("a")); - assert(jitstdout != nullptr); + file = _wfopen(jitStdOutFile, W("a")); + assert(file != nullptr); } #if !defined(HOST_UNIX) - if (jitstdout == nullptr) + if (file == nullptr) { int stdoutFd = _fileno(procstdout()); // Check fileno error output(s) -1 may overlap with errno result @@ -89,46 +100,61 @@ extern "C" DLLEXPORT void jitStartup(ICorJitHost* jitHost) // or bogus and avoid making further calls. if ((stdoutFd != -1) && (stdoutFd != -2) && (errno != EINVAL)) { - int jitstdoutFd = _dup(_fileno(procstdout())); + int jitstdoutFd = _dup(stdoutFd); // Check the error status returned by dup. if (jitstdoutFd != -1) { _setmode(jitstdoutFd, _O_TEXT); - jitstdout = _fdopen(jitstdoutFd, "w"); - assert(jitstdout != nullptr); + file = _fdopen(jitstdoutFd, "w"); + assert(file != nullptr); // Prevent the FILE* from buffering its output in order to avoid calls to // `fflush()` throughout the code. - setvbuf(jitstdout, nullptr, _IONBF, 0); + setvbuf(file, nullptr, _IONBF, 0); } } } #endif // !HOST_UNIX - // If jitstdout is still null, fallback to whatever procstdout() was - // initially set to. - if (jitstdout == nullptr) + if (file == nullptr) { - jitstdout = procstdout(); + file = procstdout(); } -#ifdef FEATURE_TRACELOGGING - JitTelemetry::NotifyDllProcessAttach(); -#endif - Compiler::compStartup(); + FILE* observed = InterlockedCompareExchangeT(&s_jitstdout, file, nullptr); - g_jitInitialized = true; + if (observed != nullptr) + { + if (file != procstdout()) + { + fclose(file); + } + + return observed; + } + + return file; } -#ifndef DEBUG +FILE* jitstdout() +{ + FILE* file = s_jitstdout; + if (file != nullptr) + { + return file; + } + + return jitstdoutInit(); +} + +// Like printf/logf, but only outputs to jitstdout -- skips call back into EE. void jitprintf(const char* fmt, ...) { va_list vl; va_start(vl, fmt); - vfprintf(jitstdout, fmt, vl); + vfprintf(jitstdout(), fmt, vl); va_end(vl); } -#endif void jitShutdown(bool processIsTerminating) { @@ -139,14 +165,15 @@ void jitShutdown(bool processIsTerminating) Compiler::compShutdown(); - if (jitstdout != procstdout()) + FILE* file = s_jitstdout; + if ((file != nullptr) && (file != procstdout())) { // When the process is terminating, the fclose call is unnecessary and is also prone to // crashing since the UCRT itself often frees the backing memory earlier on in the // termination sequence. if (!processIsTerminating) { - fclose(jitstdout); + fclose(file); } } diff --git a/src/coreclr/jit/emit.cpp b/src/coreclr/jit/emit.cpp index c85724f9240b98..8a671b7a757b88 100644 --- a/src/coreclr/jit/emit.cpp +++ b/src/coreclr/jit/emit.cpp @@ -215,7 +215,7 @@ unsigned emitter::emitInt32CnsCnt; unsigned emitter::emitNegCnsCnt; unsigned emitter::emitPow2CnsCnt; -void emitterStaticStats(FILE* fout) +void emitterStaticStats() { // The IG buffer size depends on whether we are storing a debug info pointer or not. For our purposes // here, do not include that. @@ -227,6 +227,8 @@ void emitterStaticStats(FILE* fout) insGroup* igDummy = nullptr; + FILE* fout = jitstdout(); + fprintf(fout, "\n"); fprintf(fout, "insGroup:\n"); fprintf(fout, "Offset / size of igNext = %3zu / %2zu\n", offsetof(insGroup, igNext), diff --git a/src/coreclr/jit/emitarm.cpp b/src/coreclr/jit/emitarm.cpp index a0dc786782b0bc..784f797bc5efed 100644 --- a/src/coreclr/jit/emitarm.cpp +++ b/src/coreclr/jit/emitarm.cpp @@ -4850,6 +4850,7 @@ void emitter::emitIns_Call(EmitCallType callType, dispIns(id); appendToCurIG(id); + emitLastMemBarrier = nullptr; // Cannot optimize away future memory barriers } /***************************************************************************** diff --git a/src/coreclr/jit/emitarm64.cpp b/src/coreclr/jit/emitarm64.cpp index b4e81322e2a696..ef1220e325e47c 100644 --- a/src/coreclr/jit/emitarm64.cpp +++ b/src/coreclr/jit/emitarm64.cpp @@ -8886,6 +8886,7 @@ void emitter::emitIns_Call(EmitCallType callType, dispIns(id); appendToCurIG(id); + emitLastMemBarrier = nullptr; // Cannot optimize away future memory barriers } /***************************************************************************** @@ -16615,6 +16616,15 @@ emitter::RegisterOrder emitter::IsOptimizableLdrStrWithPair( emitAttr prevSize = emitLastIns->idOpSize(); ssize_t prevImm = emitGetInsSC(emitLastIns); + // If we have this format, the 'imm' and/or 'prevImm' are not scaled(encoded), + // therefore we cannot proceed. + // TODO: In this context, 'imm' and 'prevImm' are assumed to be scaled(encoded). + // They should never be scaled(encoded) until its about to be written to the buffer. + if (fmt == IF_LS_2C || lastInsFmt == IF_LS_2C) + { + return eRO_none; + } + // Signed, *raw* immediate value fits in 7 bits, so for LDP/ STP the raw value is from -64 to +63. // For LDR/ STR, there are 9 bits, so we need to limit the range explicitly in software. if ((imm < -64) || (imm > 63) || (prevImm < -64) || (prevImm > 63)) diff --git a/src/coreclr/jit/emitxarch.cpp b/src/coreclr/jit/emitxarch.cpp index 980d40a47ac318..1c48d1c52f0bb2 100644 --- a/src/coreclr/jit/emitxarch.cpp +++ b/src/coreclr/jit/emitxarch.cpp @@ -5485,6 +5485,13 @@ void emitter::emitInsRMW(instruction ins, emitAttr attr, GenTreeStoreInd* storeI { assert(!src->isContained()); // there must be one non-contained src + if (addr->isContained() && addr->OperIs(GT_LCL_ADDR)) + { + GenTreeLclVarCommon* lclVar = addr->AsLclVarCommon(); + emitIns_S_R(ins, attr, src->GetRegNum(), lclVar->GetLclNum(), lclVar->GetLclOffs()); + return; + } + // ind, reg id = emitNewInstrAmd(attr, offset); emitHandleMemOp(storeInd, id, emitInsModeFormat(ins, IF_ARD_RRD), ins); diff --git a/src/coreclr/jit/error.cpp b/src/coreclr/jit/error.cpp index 01e6f734b89f89..06635f5d582a1d 100644 --- a/src/coreclr/jit/error.cpp +++ b/src/coreclr/jit/error.cpp @@ -387,7 +387,7 @@ int logf(const char* fmt, ...) { // if the EE refuses to log it, we try to send it to stdout va_start(args, fmt); - written = vflogf(jitstdout, fmt, args); + written = vflogf(jitstdout(), fmt, args); va_end(args); } #if 0 // Enable this only when you need it @@ -448,7 +448,7 @@ void gcDump_logf(const char* fmt, ...) { // if the EE refuses to log it, we try to send it to stdout va_start(args, fmt); - vflogf(jitstdout, fmt, args); + vflogf(jitstdout(), fmt, args); va_end(args); } #if 0 // Enable this only when you need it diff --git a/src/coreclr/jit/fgdiagnostic.cpp b/src/coreclr/jit/fgdiagnostic.cpp index e79bdb0e46368a..6dbc5e9a654845 100644 --- a/src/coreclr/jit/fgdiagnostic.cpp +++ b/src/coreclr/jit/fgdiagnostic.cpp @@ -674,7 +674,7 @@ FILE* Compiler::fgOpenFlowGraphFile(bool* wbDontClose, Phases phase, PhasePositi } else if (strcmp(filename, "stdout") == 0) { - fgxFile = jitstdout; + fgxFile = jitstdout(); *wbDontClose = true; } else if (strcmp(filename, "stderr") == 0) diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index d02ce4714d8a2e..7c34c51571e66d 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -481,10 +481,12 @@ void GenTree::ReportOperBashing(FILE* f) #if MEASURE_NODE_SIZE -void GenTree::DumpNodeSizes(FILE* fp) +void GenTree::DumpNodeSizes() { // Dump the sizes of the various GenTree flavors + FILE* fp = jitstdout(); + fprintf(fp, "Small tree node size = %zu bytes\n", TREE_NODE_SZ_SMALL); fprintf(fp, "Large tree node size = %zu bytes\n", TREE_NODE_SZ_LARGE); fprintf(fp, "\n"); @@ -18532,10 +18534,12 @@ CORINFO_CLASS_HANDLE Compiler::gtGetFieldClassHandle(CORINFO_FIELD_HANDLE fieldH { JITDUMP("Field's current class not available\n"); } + + return fieldClass; } } - return fieldClass; + return NO_CLASS_HANDLE; } //------------------------------------------------------------------------ @@ -19637,8 +19641,8 @@ GenTree* Compiler::gtNewSimdBinOpNode( } else { - assert(op2->TypeIs(type, simdBaseType, genActualType(simdBaseType)) || - (op2->TypeIs(TYP_SIMD12) && type == TYP_SIMD16)); + assert((genActualType(op2) == genActualType(type)) || (genActualType(op2) == genActualType(simdBaseType)) || + (op2->TypeIs(TYP_SIMD12) && (type == TYP_SIMD16))); } NamedIntrinsic intrinsic = NI_Illegal; diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index 27688c3d41790d..109da6a15c30d8 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -2311,7 +2311,7 @@ struct GenTree void SetIndirExceptionFlags(Compiler* comp); #if MEASURE_NODE_SIZE - static void DumpNodeSizes(FILE* fp); + static void DumpNodeSizes(); #endif #ifdef DEBUG diff --git a/src/coreclr/jit/host.h b/src/coreclr/jit/host.h index c99a0601e499b0..0ccefae924e637 100644 --- a/src/coreclr/jit/host.h +++ b/src/coreclr/jit/host.h @@ -3,6 +3,8 @@ /*****************************************************************************/ +void jitprintf(const char* fmt, ...); + #ifdef DEBUG #undef printf @@ -44,7 +46,6 @@ extern "C" void ANALYZER_NORETURN __cdecl assertAbort(const char* why, const cha // Re-define printf in Release to use jitstdout (can be overwritten with DOTNET_JitStdOutFile=file) #undef printf #define printf jitprintf -void jitprintf(const char* fmt, ...); #undef assert #define assert(p) (void)0 @@ -55,7 +56,7 @@ void jitprintf(const char* fmt, ...); #define _HOST_H_ /*****************************************************************************/ -extern FILE* jitstdout; +FILE* jitstdout(); inline FILE* procstdout() { diff --git a/src/coreclr/jit/hwintrinsicxarch.cpp b/src/coreclr/jit/hwintrinsicxarch.cpp index 488f65b5ac008d..065999982a87a9 100644 --- a/src/coreclr/jit/hwintrinsicxarch.cpp +++ b/src/coreclr/jit/hwintrinsicxarch.cpp @@ -3602,17 +3602,19 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic, op2 = impSIMDPopStack(); op1 = impSIMDPopStack(); - if (unusedVal1) + // Consume operands we won't use, in case they have side effects. + // + if (unusedVal1 && !(*val1)->IsVectorZero()) { impAppendTree(gtUnusedValNode(*val1), CHECK_SPILL_ALL, impCurStmtDI); } - if (unusedVal2) + if (unusedVal2 && !(*val2)->IsVectorZero()) { impAppendTree(gtUnusedValNode(*val2), CHECK_SPILL_ALL, impCurStmtDI); } - if (unusedVal3) + if (unusedVal3 && !(*val3)->IsVectorZero()) { impAppendTree(gtUnusedValNode(*val3), CHECK_SPILL_ALL, impCurStmtDI); } diff --git a/src/coreclr/jit/inline.cpp b/src/coreclr/jit/inline.cpp index bae5755707594d..0ded3ef3482a70 100644 --- a/src/coreclr/jit/inline.cpp +++ b/src/coreclr/jit/inline.cpp @@ -480,7 +480,7 @@ void InlineContext::DumpData(unsigned indent) { const char* inlineReason = InlGetObservationString(m_Observation); printf("%*s%u,\"%s\",\"%s\",", indent, "", GetOrdinal(), inlineReason, calleeName); - m_Policy->DumpData(jitstdout); + m_Policy->DumpData(jitstdout()); printf("\n"); } diff --git a/src/coreclr/jit/layout.cpp b/src/coreclr/jit/layout.cpp index 113414ddfd7f7a..918fd4ab6521d4 100644 --- a/src/coreclr/jit/layout.cpp +++ b/src/coreclr/jit/layout.cpp @@ -421,6 +421,7 @@ void ClassLayout::InitializeGCPtrs(Compiler* compiler) // // Return value: // true if at least one GC ByRef, false otherwise. +// bool ClassLayout::HasGCByRef() const { unsigned slots = GetSlotCount(); @@ -435,6 +436,39 @@ bool ClassLayout::HasGCByRef() const return false; } +//------------------------------------------------------------------------ +// IntersectsGCPtr: check if the specified interval intersects with a GC +// pointer. +// +// Parameters: +// offset - The start offset of the interval +// size - The size of the interval +// +// Return value: +// True if it does. +// +bool ClassLayout::IntersectsGCPtr(unsigned offset, unsigned size) const +{ + if (!HasGCPtr()) + { + return false; + } + + unsigned startSlot = offset / TARGET_POINTER_SIZE; + unsigned endSlot = (offset + size - 1) / TARGET_POINTER_SIZE; + assert((startSlot < GetSlotCount()) && (endSlot < GetSlotCount())); + + for (unsigned i = startSlot; i <= endSlot; i++) + { + if (IsGCPtr(i)) + { + return true; + } + } + + return false; +} + //------------------------------------------------------------------------ // AreCompatible: check if 2 layouts are the same for copying. // diff --git a/src/coreclr/jit/layout.h b/src/coreclr/jit/layout.h index 0e9d6ed65d03d3..59ecaa9405485d 100644 --- a/src/coreclr/jit/layout.h +++ b/src/coreclr/jit/layout.h @@ -216,6 +216,8 @@ class ClassLayout } } + bool IntersectsGCPtr(unsigned offset, unsigned size) const; + static bool AreCompatible(const ClassLayout* layout1, const ClassLayout* layout2); private: diff --git a/src/coreclr/jit/lower.cpp b/src/coreclr/jit/lower.cpp index 81df694f05e1ff..2e454e64c14eb1 100644 --- a/src/coreclr/jit/lower.cpp +++ b/src/coreclr/jit/lower.cpp @@ -7484,6 +7484,28 @@ bool Lowering::CheckMultiRegLclVar(GenTreeLclVar* lclNode, int registerCount) if (registerCount == varDsc->lvFieldCnt) { canEnregisterAsMultiReg = true; + +#ifdef FEATURE_SIMD + // TYP_SIMD12 breaks the above invariant that "we won't have + // matching reg and field counts"; for example, consider + // + // * STORE_LCL_VAR(CALL) + // * RETURN(LCL_VAR) + // + // These return in two GPR registers, while the fields of the + // local are stored in SIMD and GPR register, so registerCount + // == varDsc->lvFieldCnt == 2. But the backend cannot handle + // this. + + for (int i = 0; i < varDsc->lvFieldCnt; i++) + { + if (comp->lvaGetDesc(varDsc->lvFieldLclStart + i)->TypeGet() == TYP_SIMD12) + { + canEnregisterAsMultiReg = false; + break; + } + } +#endif } } } diff --git a/src/coreclr/jit/lowerarmarch.cpp b/src/coreclr/jit/lowerarmarch.cpp index 86a5b2b1ab4cb4..2536d44aa00c56 100644 --- a/src/coreclr/jit/lowerarmarch.cpp +++ b/src/coreclr/jit/lowerarmarch.cpp @@ -1900,6 +1900,10 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node) { use.ReplaceWith(tmp2); } + else + { + tmp2->SetUnusedValue(); + } BlockRange().Remove(node); return tmp2->gtNext; diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index eba7bdb93e20fb..150ad04a55d99f 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -3754,17 +3754,14 @@ GenTree* Lowering::LowerHWIntrinsicGetElement(GenTreeHWIntrinsic* node) // We want to optimize GetElement down to an Indir where possible as // this unlocks additional containment opportunities for various nodes - var_types newType; - GenTree* newBase; - GenTree* newIndex; - uint32_t newScale; - int32_t newOffset; + GenTree* newBase; + GenTree* newIndex; + uint32_t newScale; + int32_t newOffset; GenTreeIndir* indir = op1->AsIndir(); GenTree* addr = indir->Addr(); - newType = simdBaseType; - if (addr->OperIsAddrMode()) { // We have an existing addressing mode, so we want to try and @@ -3860,7 +3857,8 @@ GenTree* Lowering::LowerHWIntrinsicGetElement(GenTreeHWIntrinsic* node) new (comp, GT_LEA) GenTreeAddrMode(addr->TypeGet(), newBase, newIndex, newScale, newOffset); BlockRange().InsertBefore(node, newAddr); - GenTreeIndir* newIndir = comp->gtNewIndir(newType, newAddr, (indir->gtFlags & GTF_IND_FLAGS)); + GenTreeIndir* newIndir = + comp->gtNewIndir(JITtype2varType(simdBaseJitType), newAddr, (indir->gtFlags & GTF_IND_FLAGS)); BlockRange().InsertBefore(node, newIndir); LIR::Use use; @@ -3868,6 +3866,10 @@ GenTree* Lowering::LowerHWIntrinsicGetElement(GenTreeHWIntrinsic* node) { use.ReplaceWith(newIndir); } + else + { + newIndir->SetUnusedValue(); + } BlockRange().Remove(op1); BlockRange().Remove(node); @@ -3907,8 +3909,8 @@ GenTree* Lowering::LowerHWIntrinsicGetElement(GenTreeHWIntrinsic* node) if (lclDsc->lvDoNotEnregister && (lclOffs <= 0xFFFF) && ((lclOffs + elemSize) <= lclDsc->lvExactSize())) { - GenTree* lclFld = - comp->gtNewLclFldNode(lclVar->GetLclNum(), simdBaseType, static_cast(lclOffs)); + GenTree* lclFld = comp->gtNewLclFldNode(lclVar->GetLclNum(), JITtype2varType(simdBaseJitType), + static_cast(lclOffs)); BlockRange().InsertBefore(node, lclFld); LIR::Use use; @@ -3916,6 +3918,10 @@ GenTree* Lowering::LowerHWIntrinsicGetElement(GenTreeHWIntrinsic* node) { use.ReplaceWith(lclFld); } + else + { + lclFld->SetUnusedValue(); + } BlockRange().Remove(op1); BlockRange().Remove(op2); @@ -4158,6 +4164,11 @@ GenTree* Lowering::LowerHWIntrinsicGetElement(GenTreeHWIntrinsic* node) { use.ReplaceWith(cast); } + else + { + node->ClearUnusedValue(); + cast->SetUnusedValue(); + } next = LowerNode(cast); } @@ -4737,6 +4748,10 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node) { use.ReplaceWith(tmp1); } + else + { + tmp1->SetUnusedValue(); + } BlockRange().Remove(node); return LowerNode(tmp1); @@ -5267,6 +5282,10 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node) { use.ReplaceWith(tmp1); } + else + { + tmp1->SetUnusedValue(); + } BlockRange().Remove(node); return tmp1->gtNext; @@ -5306,7 +5325,8 @@ GenTree* Lowering::LowerHWIntrinsicToScalar(GenTreeHWIntrinsic* node) GenTreeIndir* indir = op1->AsIndir(); - GenTreeIndir* newIndir = comp->gtNewIndir(simdBaseType, indir->Addr(), (indir->gtFlags & GTF_IND_FLAGS)); + GenTreeIndir* newIndir = + comp->gtNewIndir(JITtype2varType(simdBaseJitType), indir->Addr(), (indir->gtFlags & GTF_IND_FLAGS)); BlockRange().InsertBefore(node, newIndir); LIR::Use use; @@ -5314,6 +5334,10 @@ GenTree* Lowering::LowerHWIntrinsicToScalar(GenTreeHWIntrinsic* node) { use.ReplaceWith(newIndir); } + else + { + newIndir->SetUnusedValue(); + } BlockRange().Remove(op1); BlockRange().Remove(node); @@ -5334,7 +5358,8 @@ GenTree* Lowering::LowerHWIntrinsicToScalar(GenTreeHWIntrinsic* node) if (lclDsc->lvDoNotEnregister && (lclOffs <= 0xFFFF) && ((lclOffs + elemSize) <= lclDsc->lvExactSize())) { - GenTree* lclFld = comp->gtNewLclFldNode(lclVar->GetLclNum(), simdBaseType, lclVar->GetLclOffs()); + GenTree* lclFld = + comp->gtNewLclFldNode(lclVar->GetLclNum(), JITtype2varType(simdBaseJitType), lclVar->GetLclOffs()); BlockRange().InsertBefore(node, lclFld); LIR::Use use; @@ -5342,6 +5367,10 @@ GenTree* Lowering::LowerHWIntrinsicToScalar(GenTreeHWIntrinsic* node) { use.ReplaceWith(lclFld); } + else + { + lclFld->SetUnusedValue(); + } BlockRange().Remove(op1); BlockRange().Remove(node); @@ -5426,6 +5455,11 @@ GenTree* Lowering::LowerHWIntrinsicToScalar(GenTreeHWIntrinsic* node) { use.ReplaceWith(cast); } + else + { + node->ClearUnusedValue(); + cast->SetUnusedValue(); + } next = LowerNode(cast); } @@ -6470,7 +6504,7 @@ void Lowering::ContainCheckStoreIndir(GenTreeStoreInd* node) case NI_AVX2_ConvertToUInt32: { // These intrinsics are "ins reg/mem, xmm" - isContainable = varTypeIsIntegral(simdBaseType); + isContainable = varTypeIsIntegral(simdBaseType) && (genTypeSize(src) == genTypeSize(node)); break; } @@ -6534,7 +6568,8 @@ void Lowering::ContainCheckStoreIndir(GenTreeStoreInd* node) size_t numArgs = hwintrinsic->GetOperandCount(); GenTree* lastOp = hwintrinsic->Op(numArgs); - isContainable = HWIntrinsicInfo::isImmOp(intrinsicId, lastOp) && lastOp->IsCnsIntOrI(); + isContainable = HWIntrinsicInfo::isImmOp(intrinsicId, lastOp) && lastOp->IsCnsIntOrI() && + (genTypeSize(simdBaseType) == genTypeSize(node)); if (isContainable && (intrinsicId == NI_SSE2_Extract)) { @@ -7922,6 +7957,9 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre // The memory form of this already takes a pointer and should be treated like a MemoryLoad supportsGeneralLoads = !childNode->OperIsHWIntrinsic(); } + + supportsGeneralLoads = + supportsGeneralLoads && (genTypeSize(childNode) >= genTypeSize(parentNode->GetSimdBaseType())); break; } @@ -8101,7 +8139,16 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre case NI_Vector128_ToScalar: case NI_Vector256_ToScalar: case NI_Vector512_ToScalar: + case NI_SSE2_ConvertToInt32: + case NI_SSE2_ConvertToUInt32: + case NI_SSE2_X64_ConvertToInt64: + case NI_SSE2_X64_ConvertToUInt64: + case NI_SSE2_Extract: + case NI_SSE41_Extract: + case NI_SSE41_X64_Extract: case NI_AVX_ExtractVector128: + case NI_AVX2_ConvertToInt32: + case NI_AVX2_ConvertToUInt32: case NI_AVX2_ExtractVector128: case NI_AVX512F_ExtractVector128: case NI_AVX512F_ExtractVector256: @@ -8144,15 +8191,24 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre return false; } + case NI_Vector128_get_Zero: + case NI_Vector256_get_Zero: + { + // These are only containable as part of Sse41.Insert + return false; + } + case NI_SSE3_MoveAndDuplicate: case NI_AVX2_BroadcastScalarToVector128: case NI_AVX2_BroadcastScalarToVector256: case NI_AVX512F_BroadcastScalarToVector512: { - var_types baseType = hwintrinsic->GetSimdBaseType(); - if (varTypeIsSmall(baseType)) + var_types parentBaseType = parentNode->GetSimdBaseType(); + var_types childBaseType = hwintrinsic->GetSimdBaseType(); + + if (varTypeIsSmall(parentBaseType) || (genTypeSize(parentBaseType) != genTypeSize(childBaseType))) { - // early return if the base type is not embedded broadcast compatible. + // early return if either base type is not embedded broadcast compatible. return false; } @@ -8160,7 +8216,7 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre if (intrinsicId == NI_SSE3_MoveAndDuplicate) { // NI_SSE3_MoveAndDuplicate is for Vector128 only. - assert(baseType == TYP_DOUBLE); + assert(childBaseType == TYP_DOUBLE); } if (comp->compOpportunisticallyDependsOn(InstructionSet_AVX512F_VL) && @@ -8193,6 +8249,15 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre case NI_AVX_BroadcastScalarToVector128: case NI_AVX_BroadcastScalarToVector256: { + var_types parentBaseType = parentNode->GetSimdBaseType(); + var_types childBaseType = hwintrinsic->GetSimdBaseType(); + + if (varTypeIsSmall(parentBaseType) || (genTypeSize(parentBaseType) != genTypeSize(childBaseType))) + { + // early return if either base type is not embedded broadcast compatible. + return false; + } + return parentNode->OperIsEmbBroadcastCompatible(); } @@ -8332,8 +8397,15 @@ void Lowering::TryFoldCnsVecForEmbeddedBroadcast(GenTreeHWIntrinsic* parentNode, BlockRange().InsertBefore(broadcastNode, createScalar); BlockRange().InsertBefore(createScalar, constScalar); LIR::Use use; - BlockRange().TryGetUse(childNode, &use); - use.ReplaceWith(broadcastNode); + if (BlockRange().TryGetUse(childNode, &use)) + { + use.ReplaceWith(broadcastNode); + } + else + { + broadcastNode->SetUnusedValue(); + } + BlockRange().Remove(childNode); LowerNode(createScalar); LowerNode(broadcastNode); diff --git a/src/coreclr/jit/lsra.cpp b/src/coreclr/jit/lsra.cpp index 5c8fe4aae77889..0a33ba3faba9bd 100644 --- a/src/coreclr/jit/lsra.cpp +++ b/src/coreclr/jit/lsra.cpp @@ -1421,7 +1421,7 @@ PhaseStatus LinearScan::doLinearScan() #endif ) { - dumpLsraStats(jitstdout); + dumpLsraStats(jitstdout()); } #endif // TRACK_LSRA_STATS diff --git a/src/coreclr/jit/morph.cpp b/src/coreclr/jit/morph.cpp index f29d6a57148136..3deada8eec085b 100644 --- a/src/coreclr/jit/morph.cpp +++ b/src/coreclr/jit/morph.cpp @@ -8919,6 +8919,14 @@ GenTree* Compiler::fgMorphSmpOp(GenTree* tree, MorphAddrContext* mac, bool* optA break; #endif + case GT_COMMA: + if (op2->OperIsStore() || (op2->OperGet() == GT_COMMA && op2->TypeGet() == TYP_VOID) || fgIsThrow(op2)) + { + typ = tree->gtType = TYP_VOID; + } + + break; + default: break; } @@ -10770,6 +10778,12 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) break; } + // Must be working with the same types of vectors. + if (hwop1->TypeGet() != node->TypeGet()) + { + break; + } + if (toScalar != nullptr) { DEBUG_DESTROY_NODE(toScalar); @@ -10793,8 +10807,6 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) } #if defined(TARGET_XARCH) - case NI_AVX512F_Add: - case NI_AVX512BW_Add: case NI_AVX512F_And: case NI_AVX512DQ_And: case NI_AVX512F_AndNot: @@ -10836,13 +10848,6 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) switch (intrinsicId) { - case NI_AVX512F_Add: - case NI_AVX512BW_Add: - { - maskIntrinsicId = NI_AVX512F_AddMask; - break; - } - case NI_AVX512F_And: case NI_AVX512DQ_And: { diff --git a/src/coreclr/jit/optimizer.cpp b/src/coreclr/jit/optimizer.cpp index 684f71b898f313..2ca5526476b82d 100644 --- a/src/coreclr/jit/optimizer.cpp +++ b/src/coreclr/jit/optimizer.cpp @@ -8875,9 +8875,15 @@ GenTree* Compiler::optRemoveRangeCheck(GenTreeBoundsChk* check, GenTree* comma, } #endif - // Extract side effects + // TODO-Bug: We really should be extracting all side effects from the + // length and index here, but the length typically involves a GT_ARR_LENGTH + // that we would preserve. Usually, as part of proving that the range check + // passes, we have also proven that the ARR_LENGTH is non-faulting. We need + // a good way to communicate to this function that it is ok to ignore side + // effects of the ARR_LENGTH. GenTree* sideEffList = nullptr; - gtExtractSideEffList(check, &sideEffList, GTF_ASG); + gtExtractSideEffList(check->GetArrayLength(), &sideEffList, GTF_ASG); + gtExtractSideEffList(check->GetIndex(), &sideEffList); if (sideEffList != nullptr) { @@ -9031,9 +9037,9 @@ void Compiler::optRemoveRedundantZeroInits() CompAllocator allocator(getAllocator(CMK_ZeroInit)); LclVarRefCounts refCounts(allocator); BitVecTraits bitVecTraits(lvaCount, this); - BitVec zeroInitLocals = BitVecOps::MakeEmpty(&bitVecTraits); - bool hasGCSafePoint = false; - bool canThrow = false; + BitVec zeroInitLocals = BitVecOps::MakeEmpty(&bitVecTraits); + bool hasGCSafePoint = false; + bool hasImplicitControlFlow = false; assert(fgNodeThreading == NodeThreading::AllTrees); @@ -9044,6 +9050,8 @@ void Compiler::optRemoveRedundantZeroInits() CompAllocator allocator(getAllocator(CMK_ZeroInit)); LclVarRefCounts defsInBlock(allocator); bool removedTrackedDefs = false; + bool hasEHSuccs = block->HasPotentialEHSuccs(this); + for (Statement* stmt = block->FirstNonPhiDef(); stmt != nullptr;) { Statement* next = stmt->GetNextStmt(); @@ -9054,10 +9062,7 @@ void Compiler::optRemoveRedundantZeroInits() hasGCSafePoint = true; } - if ((tree->gtFlags & GTF_EXCEPT) != 0) - { - canThrow = true; - } + hasImplicitControlFlow |= hasEHSuccs && ((tree->gtFlags & GTF_EXCEPT) != 0); switch (tree->gtOper) { @@ -9203,7 +9208,8 @@ void Compiler::optRemoveRedundantZeroInits() } } - if (!removedExplicitZeroInit && isEntire && (!canThrow || !lclDsc->lvLiveInOutOfHndlr)) + if (!removedExplicitZeroInit && isEntire && + (!hasImplicitControlFlow || (lclDsc->lvTracked && !lclDsc->lvLiveInOutOfHndlr))) { // If compMethodRequiresPInvokeFrame() returns true, lower may later // insert a call to CORINFO_HELP_INIT_PINVOKE_FRAME which is a gc-safe point. diff --git a/src/coreclr/jit/promotion.cpp b/src/coreclr/jit/promotion.cpp index e2c4e797a3c9d8..52163f4db0cceb 100644 --- a/src/coreclr/jit/promotion.cpp +++ b/src/coreclr/jit/promotion.cpp @@ -621,6 +621,38 @@ class LocalUses bool EvaluateReplacement( Compiler* comp, unsigned lclNum, const Access& access, unsigned inducedCount, weight_t inducedCountWtd) { + // Verify that this replacement has proper GC ness compared to the + // layout. While reinterpreting GC fields to integers can be considered + // UB, there are scenarios where it can happen safely: + // + // * The user code could have guarded the access with a dynamic check + // that it doesn't contain a GC pointer, so that the access is actually + // in dead code. This happens e.g. in span functions in SPC. + // + // * For byrefs, reinterpreting as an integer could be ok in a + // restricted scope due to pinning. + // + // In theory we could allow these promotions in the restricted scope, + // but currently physical promotion works on a function-wide basis. + + LclVarDsc* lcl = comp->lvaGetDesc(lclNum); + ClassLayout* layout = lcl->GetLayout(); + if (layout->IntersectsGCPtr(access.Offset, genTypeSize(access.AccessType))) + { + if (((access.Offset % TARGET_POINTER_SIZE) != 0) || + (layout->GetGCPtrType(access.Offset / TARGET_POINTER_SIZE) != access.AccessType)) + { + return false; + } + } + else + { + if (varTypeIsGC(access.AccessType)) + { + return false; + } + } + unsigned countOverlappedCallArg = 0; unsigned countOverlappedStoredFromCall = 0; @@ -678,9 +710,8 @@ class LocalUses // Now look at the overlapping struct uses that promotion will make more expensive. - unsigned countReadBacks = 0; - weight_t countReadBacksWtd = 0; - LclVarDsc* lcl = comp->lvaGetDesc(lclNum); + unsigned countReadBacks = 0; + weight_t countReadBacksWtd = 0; // For parameters or OSR locals we always need one read back. if (lcl->lvIsParam || lcl->lvIsOSRLocal) { @@ -2309,8 +2340,36 @@ void ReplaceVisitor::ReadBackAfterCall(GenTreeCall* call, GenTree* user) // // If the remainder of the struct local is dying, then we expect that this // entire struct local is now dying, since all field accesses are going to be -// replaced with other locals. The exception is if there is a queued read -// back for any of the fields. +// replaced with other locals. +// +// There are two exceptions to the above: +// +// 1) If there is a queued readback for any of the fields, then there is +// live state in the struct local, so it is not dying. +// +// 2) If there are further uses of the local in the same statement then we cannot +// actually act on the last-use information we would provide here. That's because +// uses of locals occur at the user and we do not model that here. In the real model +// there are cases where we do not have any place to insert any IR between the two uses. +// For example, consider: +// +// ▌ CALL void Program:Foo(Program+S,Program+S) +// ├──▌ LCL_VAR struct V01 loc0 +// └──▌ LCL_VAR struct V01 loc0 +// +// If V01 is promoted fully then both uses of V01 are last uses here; but +// replacing the IR with +// +// ▌ CALL void Program:Foo(Program+S,Program+S) +// ├──▌ LCL_VAR struct V01 loc0 (last use) +// └──▌ COMMA struct +// ├──▌ STORE_LCL_FLD int V01 loc0 [+0] +// │ └──▌ LCL_VAR int V02 tmp0 +// └──▌ LCL_VAR struct V01 loc0 (last use) +// +// would be illegal since the created store overlaps with the first local, +// and does not take into account that both uses occur simultaneously at +// the position of the CALL node. // bool ReplaceVisitor::IsPromotedStructLocalDying(GenTreeLclVarCommon* lcl) { @@ -2331,6 +2390,15 @@ bool ReplaceVisitor::IsPromotedStructLocalDying(GenTreeLclVarCommon* lcl) } } + for (GenTree* cur = lcl->gtNext; cur != nullptr; cur = cur->gtNext) + { + assert(cur->OperIsAnyLocal()); + if (cur->TypeIs(TYP_STRUCT) && (cur->AsLclVarCommon()->GetLclNum() == lcl->GetLclNum())) + { + return false; + } + } + return true; } @@ -2546,7 +2614,7 @@ void ReplaceVisitor::WriteBackBeforeCurrentStatement(unsigned lcl, unsigned offs GenTree* readBack = Promotion::CreateWriteBack(m_compiler, lcl, rep); Statement* stmt = m_compiler->fgNewStmtFromTree(readBack); - JITDUMP("Writing back %s before " FMT_STMT "\n", rep.Description, stmt->GetID()); + JITDUMP("Writing back %s before " FMT_STMT "\n", rep.Description, m_currentStmt->GetID()); DISPSTMT(stmt); m_compiler->fgInsertStmtBefore(m_currentBlock, m_currentStmt, stmt); ClearNeedsWriteBack(rep); diff --git a/src/coreclr/jit/utils.cpp b/src/coreclr/jit/utils.cpp index 3ec8e71c159e46..2e1c0a52a3d81b 100644 --- a/src/coreclr/jit/utils.cpp +++ b/src/coreclr/jit/utils.cpp @@ -2734,7 +2734,7 @@ float FloatingPointUtils::maximumNumber(float x, float y) // // It propagates NaN inputs back to the caller and // otherwise returns the lesser of the inputs. It -// treats +0 as lesser than -0 as per the specification. +// treats +0 as greater than -0 as per the specification. // // Arguments: // val1 - left operand @@ -2763,7 +2763,7 @@ double FloatingPointUtils::minimum(double val1, double val2) // // It propagates NaN inputs back to the caller and // otherwise returns the input with a lesser magnitude. -// It treats +0 as lesser than -0 as per the specification. +// It treats +0 as greater than -0 as per the specification. // // Arguments: // x - left operand @@ -2856,7 +2856,7 @@ double FloatingPointUtils::minimumNumber(double x, double y) // // It propagates NaN inputs back to the caller and // otherwise returns the lesser of the inputs. It -// treats +0 as lesser than -0 as per the specification. +// treats +0 as greater than -0 as per the specification. // // Arguments: // val1 - left operand @@ -2885,7 +2885,7 @@ float FloatingPointUtils::minimum(float val1, float val2) // // It propagates NaN inputs back to the caller and // otherwise returns the input with a lesser magnitude. -// It treats +0 as lesser than -0 as per the specification. +// It treats +0 as greater than -0 as per the specification. // // Arguments: // x - left operand diff --git a/src/coreclr/jit/valuenum.cpp b/src/coreclr/jit/valuenum.cpp index 2ea2ee7845b8b2..6943c3c5e07e26 100644 --- a/src/coreclr/jit/valuenum.cpp +++ b/src/coreclr/jit/valuenum.cpp @@ -2860,6 +2860,81 @@ ValueNum ValueNumStore::VNForMapPhysicalSelect( return result; } +typedef JitHashTable, bool> ValueNumSet; + +class SmallValueNumSet +{ + union { + ValueNum m_inlineElements[4]; + ValueNumSet* m_set; + }; + unsigned m_numElements = 0; + +public: + unsigned Count() + { + return m_numElements; + } + + template + void ForEach(Func func) + { + if (m_numElements <= ArrLen(m_inlineElements)) + { + for (unsigned i = 0; i < m_numElements; i++) + { + func(m_inlineElements[i]); + } + } + else + { + for (ValueNum vn : ValueNumSet::KeyIteration(m_set)) + { + func(vn); + } + } + } + + void Add(Compiler* comp, ValueNum vn) + { + if (m_numElements <= ArrLen(m_inlineElements)) + { + for (unsigned i = 0; i < m_numElements; i++) + { + if (m_inlineElements[i] == vn) + { + return; + } + } + + if (m_numElements < ArrLen(m_inlineElements)) + { + m_inlineElements[m_numElements] = vn; + m_numElements++; + } + else + { + ValueNumSet* set = new (comp, CMK_ValueNumber) ValueNumSet(comp->getAllocator(CMK_ValueNumber)); + for (ValueNum oldVn : m_inlineElements) + { + set->Set(oldVn, true); + } + + set->Set(vn, true); + + m_set = set; + m_numElements++; + assert(m_numElements == set->GetCount()); + } + } + else + { + m_set->Set(vn, true, ValueNumSet::SetKind::Overwrite); + m_numElements = m_set->GetCount(); + } + } +}; + //------------------------------------------------------------------------------ // VNForMapSelectInner: Select value from a map and record loop memory dependencies. // @@ -2874,10 +2949,10 @@ ValueNum ValueNumStore::VNForMapPhysicalSelect( // ValueNum ValueNumStore::VNForMapSelectInner(ValueNumKind vnk, var_types type, ValueNum map, ValueNum index) { - int budget = m_mapSelectBudget; - bool usedRecursiveVN = false; - ArrayStack memoryDependencies(m_alloc); - ValueNum result = VNForMapSelectWork(vnk, type, map, index, &budget, &usedRecursiveVN, &memoryDependencies); + int budget = m_mapSelectBudget; + bool usedRecursiveVN = false; + SmallValueNumSet memoryDependencies; + ValueNum result = VNForMapSelectWork(vnk, type, map, index, &budget, &usedRecursiveVN, memoryDependencies); // The remaining budget should always be between [0..m_mapSelectBudget] assert((budget >= 0) && (budget <= m_mapSelectBudget)); @@ -2888,11 +2963,9 @@ ValueNum ValueNumStore::VNForMapSelectInner(ValueNumKind vnk, var_types type, Va if ((m_pComp->compCurBB != nullptr) && (m_pComp->compCurTree != nullptr) && m_pComp->compCurBB->bbNatLoopNum != BasicBlock::NOT_IN_LOOP) { - for (int i = 0; i < memoryDependencies.Height(); i++) - { - m_pComp->optRecordLoopMemoryDependence(m_pComp->compCurTree, m_pComp->compCurBB, - memoryDependencies.Bottom(i)); - } + memoryDependencies.ForEach([this](ValueNum vn) { + m_pComp->optRecordLoopMemoryDependence(m_pComp->compCurTree, m_pComp->compCurBB, vn); + }); } return result; @@ -2903,19 +2976,16 @@ ValueNum ValueNumStore::VNForMapSelectInner(ValueNumKind vnk, var_types type, Va // cache entry. // // Arguments: -// alloc - Allocator to use if memory is required. -// deps - Array stack containing the memory dependencies. -// startIndex - Start index into 'deps' of memory dependencies. +// comp - Compiler instance +// set - Set of memory dependencies to store in the entry. // -void ValueNumStore::MapSelectWorkCacheEntry::SetMemoryDependencies(CompAllocator alloc, - ArrayStack& deps, - unsigned startIndex) +void ValueNumStore::MapSelectWorkCacheEntry::SetMemoryDependencies(Compiler* comp, SmallValueNumSet& set) { - m_numMemoryDependencies = deps.Height() - startIndex; + m_numMemoryDependencies = set.Count(); ValueNum* arr; if (m_numMemoryDependencies > ArrLen(m_inlineMemoryDependencies)) { - m_memoryDependencies = new (alloc) ValueNum[m_numMemoryDependencies]; + m_memoryDependencies = new (comp, CMK_ValueNumber) ValueNum[m_numMemoryDependencies]; arr = m_memoryDependencies; } @@ -2924,27 +2994,29 @@ void ValueNumStore::MapSelectWorkCacheEntry::SetMemoryDependencies(CompAllocator arr = m_inlineMemoryDependencies; } - for (unsigned i = 0; i < m_numMemoryDependencies; i++) - { - arr[i] = deps.Bottom(startIndex + i); - } + size_t i = 0; + set.ForEach([&i, arr](ValueNum vn) { + arr[i] = vn; + i++; + }); } //------------------------------------------------------------------------------ // GetMemoryDependencies: Push all of the memory dependencies cached in this -// entry into the specified array stack. +// entry into the specified set. // // Arguments: -// result - Array stack to push memory dependencies into. +// comp - Compiler instance +// result - Set to add memory dependencies to. // -void ValueNumStore::MapSelectWorkCacheEntry::GetMemoryDependencies(ArrayStack& result) +void ValueNumStore::MapSelectWorkCacheEntry::GetMemoryDependencies(Compiler* comp, SmallValueNumSet& result) { ValueNum* arr = m_numMemoryDependencies <= ArrLen(m_inlineMemoryDependencies) ? m_inlineMemoryDependencies : m_memoryDependencies; for (unsigned i = 0; i < m_numMemoryDependencies; i++) { - result.Push(arr[i]); + result.Add(comp, arr[i]); } } @@ -2959,7 +3031,7 @@ void ValueNumStore::MapSelectWorkCacheEntry::GetMemoryDependencies(ArrayStack* memoryDependencies) +ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, + var_types type, + ValueNum map, + ValueNum index, + int* pBudget, + bool* pUsedRecursiveVN, + SmallValueNumSet& memoryDependencies) { TailCall: // This label allows us to directly implement a tail call by setting up the arguments, and doing a goto to here. @@ -2997,13 +3069,12 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, assert(selLim == 0 || m_numMapSels < selLim); #endif - int firstMemoryDependency = memoryDependencies->Height(); MapSelectWorkCacheEntry entry; VNDefFuncApp<2> fstruct(VNF_MapSelect, map, index); if (GetMapSelectWorkCache()->Lookup(fstruct, &entry)) { - entry.GetMemoryDependencies(*memoryDependencies); + entry.GetMemoryDependencies(m_pComp, memoryDependencies); return entry.Result; } @@ -3029,6 +3100,8 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, return RecursiveVN; } + SmallValueNumSet recMemoryDependencies; + VNFuncApp funcApp; if (GetVNFunc(map, &funcApp)) { @@ -3047,7 +3120,7 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, funcApp.m_args[0], map, funcApp.m_args[1], funcApp.m_args[2], index, funcApp.m_args[2]); #endif - memoryDependencies->Push(funcApp.m_args[0]); + memoryDependencies.Add(m_pComp, funcApp.m_args[0]); return funcApp.m_args[2]; } @@ -3191,7 +3264,7 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, bool allSame = true; ValueNum argRest = phiFuncApp.m_args[1]; ValueNum sameSelResult = VNForMapSelectWork(vnk, type, phiArgVN, index, pBudget, - pUsedRecursiveVN, memoryDependencies); + pUsedRecursiveVN, recMemoryDependencies); // It is possible that we just now exceeded our budget, if so we need to force an early exit // and stop calling VNForMapSelectWork @@ -3233,7 +3306,7 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, { bool usedRecursiveVN = false; ValueNum curResult = VNForMapSelectWork(vnk, type, phiArgVN, index, pBudget, - &usedRecursiveVN, memoryDependencies); + &usedRecursiveVN, recMemoryDependencies); *pUsedRecursiveVN |= usedRecursiveVN; if (sameSelResult == ValueNumStore::RecursiveVN) @@ -3261,11 +3334,14 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, if (!*pUsedRecursiveVN) { entry.Result = sameSelResult; - entry.SetMemoryDependencies(m_alloc, *memoryDependencies, firstMemoryDependency); + entry.SetMemoryDependencies(m_pComp, recMemoryDependencies); GetMapSelectWorkCache()->Set(fstruct, entry); } + recMemoryDependencies.ForEach( + [this, &memoryDependencies](ValueNum vn) { memoryDependencies.Add(m_pComp, vn); }); + return sameSelResult; } // Otherwise, fall through to creating the select(phi(m1, m2), x) function application. @@ -3294,11 +3370,13 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, fapp->m_args[1] = fstruct.m_args[1]; entry.Result = c->m_baseVN + offsetWithinChunk; - entry.SetMemoryDependencies(m_alloc, *memoryDependencies, firstMemoryDependency); + entry.SetMemoryDependencies(m_pComp, recMemoryDependencies); GetMapSelectWorkCache()->Set(fstruct, entry); } + recMemoryDependencies.ForEach([this, &memoryDependencies](ValueNum vn) { memoryDependencies.Add(m_pComp, vn); }); + return entry.Result; } @@ -7891,7 +7969,7 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(var_types type, #endif { // Handle `x ^ x == 0` - return arg0VN; + return VNZeroForType(type); } default: diff --git a/src/coreclr/jit/valuenum.h b/src/coreclr/jit/valuenum.h index 8417579ea2ff7a..04fed7bfbc1f67 100644 --- a/src/coreclr/jit/valuenum.h +++ b/src/coreclr/jit/valuenum.h @@ -684,13 +684,13 @@ class ValueNumStore ValueNum VNForMapSelectInner(ValueNumKind vnk, var_types type, ValueNum map, ValueNum index); // A method that does the work for VNForMapSelect and may call itself recursively. - ValueNum VNForMapSelectWork(ValueNumKind vnk, - var_types type, - ValueNum map, - ValueNum index, - int* pBudget, - bool* pUsedRecursiveVN, - ArrayStack* loopMemoryDependencies); + ValueNum VNForMapSelectWork(ValueNumKind vnk, + var_types type, + ValueNum map, + ValueNum index, + int* pBudget, + bool* pUsedRecursiveVN, + class SmallValueNumSet& loopMemoryDependencies); // A specialized version of VNForFunc that is used for VNF_MapStore and provides some logging when verbose is set ValueNum VNForMapStore(ValueNum map, ValueNum index, ValueNum value); @@ -1821,8 +1821,8 @@ class ValueNumStore public: ValueNum Result; - void SetMemoryDependencies(CompAllocator alloc, ArrayStack& deps, unsigned startIndex); - void GetMemoryDependencies(ArrayStack& deps); + void SetMemoryDependencies(Compiler* comp, class SmallValueNumSet& deps); + void GetMemoryDependencies(Compiler* comp, class SmallValueNumSet& deps); }; typedef JitHashTable, VNDefFuncAppKeyFuncs<2>, MapSelectWorkCacheEntry> MapSelectWorkCache; diff --git a/src/coreclr/md/compiler/mdutil.cpp b/src/coreclr/md/compiler/mdutil.cpp index 8fb1551ef7ceea..05b56a25875bb0 100644 --- a/src/coreclr/md/compiler/mdutil.cpp +++ b/src/coreclr/md/compiler/mdutil.cpp @@ -265,11 +265,7 @@ HRESULT LOADEDMODULES::FindCachedReadOnlyEntry( { // If the name matches... LPCWSTR pszName = pRegMeta->GetNameOfDBFile(); - #ifdef FEATURE_CASE_SENSITIVE_FILESYSTEM - if (u16_strcmp(szName, pszName) == 0) - #else if (SString::_wcsicmp(szName, pszName) == 0) - #endif { ULONG cRefs; @@ -299,11 +295,7 @@ HRESULT LOADEDMODULES::FindCachedReadOnlyEntry( { // If the name matches... LPCWSTR pszName = pRegMeta->GetNameOfDBFile(); - #ifdef FEATURE_CASE_SENSITIVE_FILESYSTEM - if (u16_strcmp(szName, pszName) == 0) - #else if (SString::_wcsicmp(szName, pszName) == 0) - #endif { ULONG cRefs; diff --git a/src/coreclr/nativeaot/Bootstrap/main.cpp b/src/coreclr/nativeaot/Bootstrap/main.cpp index cc78cf8d6710a9..c2ff85b50e81fd 100644 --- a/src/coreclr/nativeaot/Bootstrap/main.cpp +++ b/src/coreclr/nativeaot/Bootstrap/main.cpp @@ -93,7 +93,7 @@ static char& __unbox_z = __stop___unbox; #endif // _MSC_VER -extern "C" bool RhInitialize(); +extern "C" bool RhInitialize(bool isDll); extern "C" void RhSetRuntimeInitializationCallback(int (*fPtr)()); extern "C" bool RhRegisterOSModule(void * pModule, @@ -164,7 +164,13 @@ extern "C" void __managed__Startup(); static int InitializeRuntime() { - if (!RhInitialize()) + if (!RhInitialize( +#ifdef NATIVEAOT_DLL + /* isDll */ true +#else + /* isDll */ false +#endif + )) return -1; void * osModule = PalGetModuleHandleFromPointer((void*)&NATIVEAOT_ENTRYPOINT); diff --git a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets index 647aee4993d960..da6c90642f6f13 100644 --- a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets +++ b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets @@ -50,6 +50,7 @@ Text="RuntimeIdentifier is required for native compilation. Try running dotnet publish with the -r option value specified." /> + @@ -94,7 +95,10 @@ + + diff --git a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Unix.targets b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Unix.targets index 409fcb654e919d..1f5b2cd681095c 100644 --- a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Unix.targets +++ b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Unix.targets @@ -50,6 +50,8 @@ The .NET Foundation licenses this file to you under the MIT license. libeventpipe-disabled libeventpipe-enabled + + true @@ -121,7 +123,7 @@ The .NET Foundation licenses this file to you under the MIT license. - + diff --git a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.targets b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.targets index a4f34ef2225483..e9462399741c5e 100644 --- a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.targets +++ b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.targets @@ -263,7 +263,7 @@ The .NET Foundation licenses this file to you under the MIT license. - + diff --git a/src/coreclr/nativeaot/Directory.Build.props b/src/coreclr/nativeaot/Directory.Build.props index ebfa725e4efd2c..005d6ae997adab 100644 --- a/src/coreclr/nativeaot/Directory.Build.props +++ b/src/coreclr/nativeaot/Directory.Build.props @@ -25,6 +25,9 @@ false v4.0.30319 + + $(ProductVersion) + $(ProductVersion) $(NoWarn),0419,0649,CA2249,CA1830 diff --git a/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs b/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs index 3c9d6c86ffc323..5c11243bbad99b 100644 --- a/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs +++ b/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs @@ -62,12 +62,12 @@ internal static class InternalCalls [RuntimeExport("RhCollect")] internal static void RhCollect(int generation, InternalGCCollectionMode mode, bool lowMemoryP = false) { - RhpCollect(generation, mode, lowMemoryP); + RhpCollect(generation, mode, lowMemoryP ? Interop.BOOL.TRUE : Interop.BOOL.FALSE); } [DllImport(Redhawk.BaseName)] [UnmanagedCallConv(CallConvs = new Type[] { typeof(CallConvCdecl) })] - private static extern void RhpCollect(int generation, InternalGCCollectionMode mode, bool lowMemoryP); + private static extern void RhpCollect(int generation, InternalGCCollectionMode mode, Interop.BOOL lowMemoryP); [RuntimeExport("RhGetGcTotalMemory")] internal static long RhGetGcTotalMemory() diff --git a/src/coreclr/nativeaot/Runtime/Full/CMakeLists.txt b/src/coreclr/nativeaot/Runtime/Full/CMakeLists.txt index 3cbaa6e2f253a6..f3d48797c2184a 100644 --- a/src/coreclr/nativeaot/Runtime/Full/CMakeLists.txt +++ b/src/coreclr/nativeaot/Runtime/Full/CMakeLists.txt @@ -6,7 +6,7 @@ project(Runtime) # Include auto-generated files on include path set(CMAKE_INCLUDE_CURRENT_DIR ON) -if (CLR_CMAKE_TARGET_APPLE AND NOT CLR_CMAKE_TARGET_OSX) +if (CLR_CMAKE_TARGET_APPLE) list(APPEND RUNTIME_SOURCES_ARCH_ASM ${ARCH_SOURCES_DIR}/ThunkPoolThunks.${ASM_SUFFIX} ) diff --git a/src/coreclr/nativeaot/Runtime/GCMemoryHelpers.cpp b/src/coreclr/nativeaot/Runtime/GCMemoryHelpers.cpp index 27126acbdb839f..30f2c5c5fd3e9a 100644 --- a/src/coreclr/nativeaot/Runtime/GCMemoryHelpers.cpp +++ b/src/coreclr/nativeaot/Runtime/GCMemoryHelpers.cpp @@ -10,7 +10,6 @@ #include "PalRedhawkCommon.h" #include "CommonMacros.inl" -#include "GCMemoryHelpers.h" #include "GCMemoryHelpers.inl" // This function clears a piece of memory in a GC safe way. @@ -31,11 +30,26 @@ COOP_PINVOKE_CDECL_HELPER(void *, RhpGcSafeZeroMemory, (void * mem, size_t size) return mem; } +#if defined(TARGET_X86) || defined(TARGET_AMD64) + // + // Memory writes are already ordered + // + #define GCHeapMemoryBarrier() +#else + #define GCHeapMemoryBarrier() MemoryBarrier() +#endif + // Move memory, in a way that is compatible with a move onto the heap, but // does not require the destination pointer to be on the heap. COOP_PINVOKE_HELPER(void, RhBulkMoveWithWriteBarrier, (uint8_t* pDest, uint8_t* pSrc, size_t cbDest)) { + // It is possible that the bulk write is publishing object references accessible so far only + // by the current thread to shared memory. + // The memory model requires that writes performed by current thread are observable no later + // than the writes that will actually publish the references. + GCHeapMemoryBarrier(); + if (pDest <= pSrc || pSrc + cbDest <= pDest) InlineForwardGCSafeCopy(pDest, pSrc, cbDest); else @@ -43,8 +57,3 @@ COOP_PINVOKE_HELPER(void, RhBulkMoveWithWriteBarrier, (uint8_t* pDest, uint8_t* InlinedBulkWriteBarrier(pDest, cbDest); } - -void REDHAWK_CALLCONV RhpBulkWriteBarrier(void* pMemStart, uint32_t cbMemSize) -{ - InlinedBulkWriteBarrier(pMemStart, cbMemSize); -} diff --git a/src/coreclr/nativeaot/Runtime/MiscHelpers.cpp b/src/coreclr/nativeaot/Runtime/MiscHelpers.cpp index ec2fabcc540f1f..6df37cf23b9d36 100644 --- a/src/coreclr/nativeaot/Runtime/MiscHelpers.cpp +++ b/src/coreclr/nativeaot/Runtime/MiscHelpers.cpp @@ -35,7 +35,6 @@ #include "MethodTable.inl" #include "CommonMacros.inl" #include "volatile.h" -#include "GCMemoryHelpers.h" #include "GCMemoryHelpers.inl" #include "yieldprocessornormalized.h" #include "RhConfig.h" diff --git a/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S b/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S index 2bab323e65abca..79ffed2b05210d 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S +++ b/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S @@ -46,7 +46,7 @@ OFFSETOF__Thread__m_alloc_context__alloc_limit = OFFSETOF__Thread__m_rgbAll add x2, x2, x12 ldr x13, [x1, #OFFSETOF__Thread__m_alloc_context__alloc_limit] cmp x2, x13 - bhi RhpNewFast_RarePath + bhi LOCAL_LABEL(RhpNewFast_RarePath) // Update the alloc pointer to account for the allocation. str x2, [x1, #OFFSETOF__Thread__m_alloc_context__alloc_ptr] @@ -57,7 +57,7 @@ OFFSETOF__Thread__m_alloc_context__alloc_limit = OFFSETOF__Thread__m_rgbAll mov x0, x12 ret -RhpNewFast_RarePath: +LOCAL_LABEL(RhpNewFast_RarePath): mov x1, #0 b C_FUNC(RhpNewObject) LEAF_END RhpNewFast, _TEXT @@ -88,12 +88,12 @@ RhpNewFast_RarePath: bl C_FUNC(RhpGcAlloc) // Set the new objects MethodTable pointer on success. - cbz x0, NewOutOfMemory + cbz x0, LOCAL_LABEL(NewOutOfMemory) POP_COOP_PINVOKE_FRAME EPILOG_RETURN -NewOutOfMemory: +LOCAL_LABEL(NewOutOfMemory): // This is the OOM failure path. We are going to tail-call to a managed helper that will throw // an out of memory exception that the caller of this allocator understands. @@ -113,7 +113,7 @@ NewOutOfMemory: movz x2, MAX_STRING_LENGTH & 0xFFFF movk x2, MAX_STRING_LENGTH >> 16, lsl 16 cmp x1, x2 - bhi StringSizeOverflow + bhi LOCAL_LABEL(StringSizeOverflow) // Compute overall allocation size (align(base size + (element size * elements), 8)). mov w2, #STRING_COMPONENT_SIZE @@ -139,7 +139,7 @@ NewOutOfMemory: add x2, x2, x12 ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_limit] cmp x2, x12 - bhi C_FUNC(RhpNewArrayRare) + bhi LOCAL_LABEL(RhNewString_Rare) // Reload new object address into r12. ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_ptr] @@ -156,7 +156,7 @@ NewOutOfMemory: ret -StringSizeOverflow: +LOCAL_LABEL(StringSizeOverflow): // We get here if the length of the final string object can not be represented as an unsigned // 32-bit value. We are going to tail-call to a managed helper that will throw // an OOM exception that the caller of this allocator understands. @@ -164,6 +164,9 @@ StringSizeOverflow: // x0 holds MethodTable pointer already mov x1, #1 // Indicate that we should throw OverflowException b C_FUNC(RhExceptionHandling_FailedAllocation) + +LOCAL_LABEL(RhNewString_Rare): + b C_FUNC(RhpNewArrayRare) LEAF_END RhNewString, _Text // Allocate one dimensional, zero based array (SZARRAY). @@ -177,7 +180,7 @@ StringSizeOverflow: // case (32 dimensional MdArray) is less than 0xffff, and thus the product fits in 64 bits. mov x2, #0x7FFFFFFF cmp x1, x2 - bhi ArraySizeOverflow + bhi LOCAL_LABEL(ArraySizeOverflow) ldrh w2, [x0, #OFFSETOF__MethodTable__m_usComponentSize] umull x2, w1, w2 @@ -204,7 +207,7 @@ StringSizeOverflow: add x2, x2, x12 ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_limit] cmp x2, x12 - bhi C_FUNC(RhpNewArrayRare) + bhi LOCAL_LABEL(RhpNewArray_Rare) // Reload new object address into x12. ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_ptr] @@ -221,7 +224,7 @@ StringSizeOverflow: ret -ArraySizeOverflow: +LOCAL_LABEL(ArraySizeOverflow): // We get here if the size of the final array object can not be represented as an unsigned // 32-bit value. We are going to tail-call to a managed helper that will throw // an overflow exception that the caller of this allocator understands. @@ -229,6 +232,9 @@ ArraySizeOverflow: // x0 holds MethodTable pointer already mov x1, #1 // Indicate that we should throw OverflowException b C_FUNC(RhExceptionHandling_FailedAllocation) + +LOCAL_LABEL(RhpNewArray_Rare): + b C_FUNC(RhpNewArrayRare) LEAF_END RhpNewArray, _TEXT // Allocate one dimensional, zero based array (SZARRAY) using the slow path that calls a runtime helper. @@ -254,12 +260,12 @@ ArraySizeOverflow: bl C_FUNC(RhpGcAlloc) // Set the new objects MethodTable pointer and length on success. - cbz x0, ArrayOutOfMemory + cbz x0, LOCAL_LABEL(ArrayOutOfMemory) POP_COOP_PINVOKE_FRAME EPILOG_RETURN -ArrayOutOfMemory: +LOCAL_LABEL(ArrayOutOfMemory): // This is the OOM failure path. We are going to tail-call to a managed helper that will throw // an out of memory exception that the caller of this allocator understands. diff --git a/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S b/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S index d0425171e1d191..7c04f15ad3b858 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S +++ b/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S @@ -275,7 +275,7 @@ // where the tail-calling thread had saved LR, which may not match where we have saved LR. ldr x1, [x2, #OFFSETOF__Thread__m_pvHijackedReturnAddress] - cbz x1, NotHijacked + cbz x1, LOCAL_LABEL(NotHijacked) ldr x3, [x2, #OFFSETOF__Thread__m_ppvHijackedReturnAddressLocation] @@ -286,13 +286,13 @@ add x12, sp, #(STACKSIZEOF_ExInfo + SIZEOF__PAL_LIMITED_CONTEXT) // re-compute SP at callsite cmp x3, x12 // if (m_ppvHijackedReturnAddressLocation < SP at callsite) - blo TailCallWasHijacked + blo LOCAL_LABEL(TailCallWasHijacked) // normal case where a valid return address location is hijacked str x1, [x3] - b ClearThreadState + b LOCAL_LABEL(ClearThreadState) -TailCallWasHijacked: +LOCAL_LABEL(TailCallWasHijacked): // Abnormal case where the return address location is now invalid because we ended up here via a tail // call. In this case, our hijacked return address should be the correct caller of this method. @@ -302,13 +302,13 @@ TailCallWasHijacked: str lr, [sp, #(rsp_offsetof_Context + OFFSETOF__PAL_LIMITED_CONTEXT__LR)] str lr, [sp, #(rsp_offsetof_Context + OFFSETOF__PAL_LIMITED_CONTEXT__IP)] -ClearThreadState: +LOCAL_LABEL(ClearThreadState): // clear the Thread's hijack state str xzr, [x2, #OFFSETOF__Thread__m_ppvHijackedReturnAddressLocation] str xzr, [x2, #OFFSETOF__Thread__m_pvHijackedReturnAddress] -NotHijacked: +LOCAL_LABEL(NotHijacked): add x1, sp, #rsp_offsetof_ExInfo // x1 <- ExInfo* str xzr, [x1, #OFFSETOF__ExInfo__m_exception] // pExInfo->m_exception = null @@ -429,13 +429,13 @@ NotHijacked: add x12, x5, #OFFSETOF__Thread__m_ThreadStateFlags -ClearRetry_Catch: +LOCAL_LABEL(ClearRetry_Catch): ldxr w4, [x12] bic w4, w4, #TSF_DoNotTriggerGc stxr w6, w4, [x12] - cbz w6, ClearSuccess_Catch - b ClearRetry_Catch -ClearSuccess_Catch: + cbz w6, LOCAL_LABEL(ClearSuccess_Catch) + b LOCAL_LABEL(ClearRetry_Catch) +LOCAL_LABEL(ClearSuccess_Catch): // // set preserved regs to the values expected by the funclet @@ -487,21 +487,21 @@ ClearSuccess_Catch: ldr x3, [sp, #rsp_offset_x3] // x3 <- current ExInfo* ldr x2, [x2, #OFFSETOF__REGDISPLAY__SP] // x2 <- resume SP value -PopExInfoLoop: +LOCAL_LABEL(PopExInfoLoop): ldr x3, [x3, #OFFSETOF__ExInfo__m_pPrevExInfo] // x3 <- next ExInfo - cbz x3, DonePopping // if (pExInfo == null) { we're done } + cbz x3, LOCAL_LABEL(DonePopping) // if (pExInfo == null) { we're done } cmp x3, x2 - blt PopExInfoLoop // if (pExInfo < resume SP} { keep going } + blt LOCAL_LABEL(PopExInfoLoop) // if (pExInfo < resume SP} { keep going } -DonePopping: +LOCAL_LABEL(DonePopping): str x3, [x1, #OFFSETOF__Thread__m_pExInfoStackHead] // store the new head on the Thread PREPARE_EXTERNAL_VAR_INDIRECT_W RhpTrapThreads, 3 - tbz x3, #TrapThreadsFlags_AbortInProgress_Bit, NoAbort + tbz x3, #TrapThreadsFlags_AbortInProgress_Bit, LOCAL_LABEL(NoAbort) ldr x3, [sp, #rsp_offset_is_not_handling_thread_abort] - cbnz x3, NoAbort + cbnz x3, LOCAL_LABEL(NoAbort) // It was the ThreadAbortException, so rethrow it // reset SP @@ -510,7 +510,7 @@ DonePopping: mov sp, x2 b C_FUNC(RhpThrowHwEx) -NoAbort: +LOCAL_LABEL(NoAbort): // reset SP and jump to continuation address mov sp, x2 br x0 @@ -564,13 +564,13 @@ NoAbort: add x12, x2, #OFFSETOF__Thread__m_ThreadStateFlags -ClearRetry: +LOCAL_LABEL(ClearRetry): ldxr w4, [x12] bic w4, w4, #TSF_DoNotTriggerGc stxr w3, w4, [x12] - cbz w3, ClearSuccess - b ClearRetry -ClearSuccess: + cbz w3, LOCAL_LABEL(ClearSuccess) + b LOCAL_LABEL(ClearRetry) +LOCAL_LABEL(ClearSuccess): // // set preserved regs to the values expected by the funclet @@ -602,13 +602,13 @@ ClearSuccess: ldr x2, [sp, rsp_FinallyFunclet_offset_thread] add x12, x2, #OFFSETOF__Thread__m_ThreadStateFlags -SetRetry: +LOCAL_LABEL(SetRetry): ldxr w1, [x12] orr w1, w1, #TSF_DoNotTriggerGc stxr w3, w1, [x12] - cbz w3, SetSuccess - b SetRetry -SetSuccess: + cbz w3, LOCAL_LABEL(SetSuccess) + b LOCAL_LABEL(SetRetry) +LOCAL_LABEL(SetSuccess): ldp d8, d9, [sp, #0x00] ldp d10, d11, [sp, #0x10] @@ -707,13 +707,13 @@ SetSuccess: add x12, x5, #OFFSETOF__Thread__m_ThreadStateFlags -ClearRetry_Propagate: +LOCAL_LABEL(ClearRetry_Propagate): ldxr w4, [x12] bic w4, w4, #TSF_DoNotTriggerGc stxr w6, w4, [x12] - cbz w6, ClearSuccess_Propagate - b ClearRetry_Propagate -ClearSuccess_Propagate: + cbz w6, LOCAL_LABEL(ClearSuccess_Propagate) + b LOCAL_LABEL(ClearRetry_Propagate) +LOCAL_LABEL(ClearSuccess_Propagate): // // set preserved regs to the values expected by the funclet @@ -749,13 +749,13 @@ ClearSuccess_Propagate: ldr x3, [sp, #rsp_offset_x3] // x3 <- current ExInfo* ldr x2, [x2, #OFFSETOF__REGDISPLAY__SP] // x2 <- resume SP value -Propagate_PopExInfoLoop: +LOCAL_LABEL(Propagate_PopExInfoLoop): ldr x3, [x3, #OFFSETOF__ExInfo__m_pPrevExInfo] // x3 <- next ExInfo - cbz x3, Propagate_DonePopping // if (pExInfo == null) { we're done } + cbz x3, LOCAL_LABEL(Propagate_DonePopping) // if (pExInfo == null) { we're done } cmp x3, x2 - blt Propagate_PopExInfoLoop // if (pExInfo < resume SP} { keep going } + blt LOCAL_LABEL(Propagate_PopExInfoLoop) // if (pExInfo < resume SP} { keep going } -Propagate_DonePopping: +LOCAL_LABEL(Propagate_DonePopping): str x3, [x1, #OFFSETOF__Thread__m_pExInfoStackHead] // store the new head on the Thread // restore preemptive mode diff --git a/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S b/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S index e27834bae6fedd..abe7555b761134 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S +++ b/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S @@ -127,10 +127,10 @@ NESTED_ENTRY RhpGcProbeHijack, _TEXT, NoHandler FixupHijackedCallstack PREPARE_EXTERNAL_VAR_INDIRECT_W RhpTrapThreads, 3 - tbnz x3, #TrapThreadsFlags_TrapThreads_Bit, WaitForGC + tbnz x3, #TrapThreadsFlags_TrapThreads_Bit, LOCAL_LABEL(WaitForGC) ret -WaitForGC: +LOCAL_LABEL(WaitForGC): orr x12, x12, DEFAULT_FRAME_SAVE_FLAGS + PTFF_SAVE_X0 + PTFF_SAVE_X1 b C_FUNC(RhpWaitForGC) NESTED_END RhpGcProbeHijack @@ -144,11 +144,11 @@ NESTED_ENTRY RhpWaitForGC, _TEXT, NoHandler bl C_FUNC(RhpWaitForGC2) ldr x2, [sp, #OFFSETOF__PInvokeTransitionFrame__m_Flags] - tbnz x2, #PTFF_THREAD_ABORT_BIT, ThrowThreadAbort + tbnz x2, #PTFF_THREAD_ABORT_BIT, LOCAL_LABEL(ThrowThreadAbort) POP_PROBE_FRAME EPILOG_RETURN -ThrowThreadAbort: +LOCAL_LABEL(ThrowThreadAbort): POP_PROBE_FRAME mov w0, #STATUS_REDHAWK_THREAD_ABORT mov x1, lr // return address as exception PC @@ -159,8 +159,10 @@ NESTED_END RhpWaitForGC LEAF_ENTRY RhpGcPoll PREPARE_EXTERNAL_VAR_INDIRECT_W RhpTrapThreads, 0 - cbnz w0, C_FUNC(RhpGcPollRare) // TrapThreadsFlags_None = 0 + cbnz w0, LOCAL_LABEL(RhpGcPoll_Rare) // TrapThreadsFlags_None = 0 ret +LOCAL_LABEL(RhpGcPoll_Rare): + b C_FUNC(RhpGcPollRare) LEAF_END RhpGcPoll NESTED_ENTRY RhpGcPollRare, _TEXT, NoHandler diff --git a/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S b/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S index d00ffb3a4a9978..835466c3b9e7e4 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S +++ b/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S @@ -224,9 +224,11 @@ LEAF_END RhpByRefAssignRefArm64, _TEXT PREPARE_EXTERNAL_VAR_INDIRECT g_highest_address, x12 ccmp x14, x12, #0x2, hs - blo C_FUNC(RhpAssignRefArm64) + bhs LOCAL_LABEL(NotInHeap) -NotInHeap: + b C_FUNC(RhpAssignRefArm64) + +LOCAL_LABEL(NotInHeap): ALTERNATE_ENTRY RhpCheckedAssignRefAVLocation str x15, [x14], 8 ret @@ -293,44 +295,44 @@ LEAF_END RhpAssignRef, _TEXT #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT PREPARE_EXTERNAL_VAR_INDIRECT_W g_cpuFeatures, 16 - tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, CmpXchgRetry + tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(CmpXchgRetry) #endif mov x10, x2 ALTERNATE_ENTRY RhpCheckedLockCmpXchgAVLocation casal x10, x1, [x0] // exchange cmp x2, x10 - bne CmpXchgNoUpdate + bne LOCAL_LABEL(CmpXchgNoUpdate) #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - b DoCardsCmpXchg -CmpXchgRetry: + b LOCAL_LABEL(DoCardsCmpXchg) +LOCAL_LABEL(CmpXchgRetry): // Check location value is what we expect. ALTERNATE_ENTRY RhpCheckedLockCmpXchgAVLocation2 ldaxr x10, [x0] cmp x10, x2 - bne CmpXchgNoUpdate + bne LOCAL_LABEL(CmpXchgNoUpdate) // Current value matches comparand, attempt to update with the new value. stlxr w12, x1, [x0] - cbnz w12, CmpXchgRetry + cbnz w12, LOCAL_LABEL(CmpXchgRetry) #endif -DoCardsCmpXchg: +LOCAL_LABEL(DoCardsCmpXchg): // We have successfully updated the value of the objectref so now we need a GC write barrier. // The following barrier code takes the destination in x0 and the value in x1 so the arguments are // already correctly set up. INSERT_CHECKED_WRITE_BARRIER_CORE x0, x1 -CmpXchgNoUpdate: +LOCAL_LABEL(CmpXchgNoUpdate): // x10 still contains the original value. mov x0, x10 #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, NoBarrierCmpXchg + tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(NoBarrierCmpXchg) InterlockedOperationBarrier -NoBarrierCmpXchg: +LOCAL_LABEL(NoBarrierCmpXchg): #endif ret lr @@ -357,25 +359,25 @@ NoBarrierCmpXchg: #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT PREPARE_EXTERNAL_VAR_INDIRECT_W g_cpuFeatures, 16 - tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, ExchangeRetry + tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(ExchangeRetry) #endif ALTERNATE_ENTRY RhpCheckedXchgAVLocation swpal x1, x10, [x0] // exchange #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - b DoCardsXchg -ExchangeRetry: + b LOCAL_LABEL(DoCardsXchg) +LOCAL_LABEL(ExchangeRetry): // Read the existing memory location. ALTERNATE_ENTRY RhpCheckedXchgAVLocation2 ldaxr x10, [x0] // Attempt to update with the new value. stlxr w12, x1, [x0] - cbnz w12, ExchangeRetry + cbnz w12, LOCAL_LABEL(ExchangeRetry) #endif -DoCardsXchg: +LOCAL_LABEL(DoCardsXchg): // We have successfully updated the value of the objectref so now we need a GC write barrier. // The following barrier code takes the destination in x0 and the value in x1 so the arguments are // already correctly set up. @@ -386,9 +388,9 @@ DoCardsXchg: mov x0, x10 #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, NoBarrierXchg + tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(NoBarrierXchg) InterlockedOperationBarrier -NoBarrierXchg: +LOCAL_LABEL(NoBarrierXchg): #endif ret diff --git a/src/coreclr/nativeaot/Runtime/gcrhenv.cpp b/src/coreclr/nativeaot/Runtime/gcrhenv.cpp index 3d0990962b7c99..3ec488605c1b33 100644 --- a/src/coreclr/nativeaot/Runtime/gcrhenv.cpp +++ b/src/coreclr/nativeaot/Runtime/gcrhenv.cpp @@ -42,7 +42,6 @@ #include "daccess.h" -#include "GCMemoryHelpers.h" #include "interoplibinterface.h" #include "holder.h" diff --git a/src/coreclr/nativeaot/Runtime/portable.cpp b/src/coreclr/nativeaot/Runtime/portable.cpp index d45b3d062d00e3..8b425bfe2dff12 100644 --- a/src/coreclr/nativeaot/Runtime/portable.cpp +++ b/src/coreclr/nativeaot/Runtime/portable.cpp @@ -31,7 +31,6 @@ #include "MethodTable.inl" #include "ObjectLayout.h" -#include "GCMemoryHelpers.h" #include "GCMemoryHelpers.inl" #if defined(USE_PORTABLE_HELPERS) diff --git a/src/coreclr/nativeaot/Runtime/startup.cpp b/src/coreclr/nativeaot/Runtime/startup.cpp index 32cbab53cb8304..5db04aa27766dd 100644 --- a/src/coreclr/nativeaot/Runtime/startup.cpp +++ b/src/coreclr/nativeaot/Runtime/startup.cpp @@ -297,6 +297,10 @@ static void UninitDLL() Thread* g_threadPerformingShutdown = NULL; #endif +#if defined(_WIN32) && defined(FEATURE_PERFTRACING) +bool g_safeToShutdownTracing; +#endif + static void __cdecl OnProcessExit() { #ifdef _WIN32 @@ -309,8 +313,16 @@ static void __cdecl OnProcessExit() #endif #ifdef FEATURE_PERFTRACING - EventPipe_Shutdown(); - DiagnosticServer_Shutdown(); +#ifdef _WIN32 + // We forgo shutting down event pipe if it wouldn't be safe and could lead to a hang. + // If there was an active trace session, the trace will likely be corrupted without + // orderly shutdown. See https://github.com/dotnet/runtime/issues/89346. + if (g_safeToShutdownTracing) +#endif + { + EventPipe_Shutdown(); + DiagnosticServer_Shutdown(); + } #endif } @@ -348,7 +360,7 @@ void RuntimeThreadShutdown(void* thread) #endif } -extern "C" bool RhInitialize() +extern "C" bool RhInitialize(bool isDll) { if (!PalInit()) return false; @@ -357,6 +369,10 @@ extern "C" bool RhInitialize() atexit(&OnProcessExit); #endif +#if defined(_WIN32) && defined(FEATURE_PERFTRACING) + g_safeToShutdownTracing = !isDll; +#endif + if (!InitDLL(PalGetModuleHandleFromPointer((void*)&RhInitialize))) return false; diff --git a/src/coreclr/nativeaot/Runtime/threadstore.cpp b/src/coreclr/nativeaot/Runtime/threadstore.cpp index 67a6949fd7fb06..2e8369f9175fc5 100644 --- a/src/coreclr/nativeaot/Runtime/threadstore.cpp +++ b/src/coreclr/nativeaot/Runtime/threadstore.cpp @@ -24,7 +24,6 @@ #include "yieldprocessornormalized.h" #include "slist.inl" -#include "GCMemoryHelpers.h" EXTERN_C volatile uint32_t RhpTrapThreads; volatile uint32_t RhpTrapThreads = (uint32_t)TrapThreadsFlags::None; diff --git a/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc b/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc index ef6d393fd248b1..bde1d517b7e823 100644 --- a/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc +++ b/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc @@ -3,6 +3,11 @@ #define INVALIDGCVALUE 0xCCCCCCCD +// Enforce subsections via symbols to workaround bugs in Xcode 15 linker. +#if defined(__APPLE__) +.subsections_via_symbols +#endif + #if defined(__APPLE__) #define C_FUNC(name) _##name #define EXTERNAL_C_FUNC(name) C_FUNC(name) diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml b/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml index d5fbde8e348dc3..229085a10afaa0 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml @@ -800,6 +800,10 @@ CP0001 T:System.Diagnostics.DebugAnnotations + + CP0001 + T:System.Diagnostics.DebuggerGuidedStepThroughAttribute + CP0001 T:System.MDArray @@ -864,6 +868,10 @@ CP0001 T:System.Reflection.RuntimeAssemblyName + + CP0001 + T:System.Runtime.CompilerServices.EagerStaticClassConstructionAttribute + CP0001 T:System.Runtime.CompilerServices.ForceDictionaryLookupsAttribute diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs index 8a2f33c93f1f8c..c2421fc6b4ceb9 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs @@ -103,12 +103,17 @@ private bool WriteHeader(RhFailFastReason reason, ulong crashingThreadId, string if (!WriteValue("version"u8, "1.0.0"u8)) return false; - if (!WriteValue("runtime"u8, new ReadOnlySpan(RuntimeImports.RhGetRuntimeVersion(out int cbLength), cbLength))) + static void Dummy() { } + + if (!WriteHexValue("runtime_base"u8, (ulong)RuntimeImports.RhGetOSModuleFromPointer((nint)(void*)(delegate*)&Dummy))) return false; if (!WriteIntValue("runtime_type"u8, (int)RuntimeType.NativeAOT)) return false; + if (!WriteValue("runtime_version"u8, new ReadOnlySpan(RuntimeImports.RhGetRuntimeVersion(out int cbLength), cbLength))) + return false; + CrashReason crashReason = reason switch { RhFailFastReason.EnvironmentFailFast => CrashReason.EnvironmentFailFast, diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/GC.NativeAot.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/GC.NativeAot.cs index 5ebcfc6c0771c1..3d263d5de6b63d 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/GC.NativeAot.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/GC.NativeAot.cs @@ -869,7 +869,6 @@ public static TimeSpan GetTotalPauseDuration() return new TimeSpan(RuntimeImports.RhGetTotalPauseDuration()); } - [System.Runtime.Versioning.RequiresPreviewFeaturesAttribute("RefreshMemoryLimit is in preview.")] public static void RefreshMemoryLimit() { ulong heapHardLimit = (AppContext.GetData("GCHeapHardLimit") as ulong?) ?? ulong.MaxValue; diff --git a/src/coreclr/nativeaot/System.Private.TypeLoader/src/Internal/Runtime/TypeLoader/TypeLoaderEnvironment.MetadataSignatureParsing.cs b/src/coreclr/nativeaot/System.Private.TypeLoader/src/Internal/Runtime/TypeLoader/TypeLoaderEnvironment.MetadataSignatureParsing.cs index 07674dded3541c..1a1cd76cc4ddfd 100644 --- a/src/coreclr/nativeaot/System.Private.TypeLoader/src/Internal/Runtime/TypeLoader/TypeLoaderEnvironment.MetadataSignatureParsing.cs +++ b/src/coreclr/nativeaot/System.Private.TypeLoader/src/Internal/Runtime/TypeLoader/TypeLoaderEnvironment.MetadataSignatureParsing.cs @@ -173,12 +173,23 @@ internal static NativeParser GetNativeParserForSignature(RuntimeSignature signat private bool CompareTypeSigWithType(ref NativeParser parser, TypeManagerHandle moduleHandle, Handle typeHandle) { - while (typeHandle.HandleType == HandleType.TypeSpecification) + while (typeHandle.HandleType == HandleType.TypeSpecification + || typeHandle.HandleType == HandleType.ModifiedType) { - typeHandle = typeHandle - .ToTypeSpecificationHandle(_metadataReader) - .GetTypeSpecification(_metadataReader) - .Signature; + if (typeHandle.HandleType == HandleType.TypeSpecification) + { + typeHandle = typeHandle + .ToTypeSpecificationHandle(_metadataReader) + .GetTypeSpecification(_metadataReader) + .Signature; + } + else + { + typeHandle = typeHandle + .ToModifiedTypeHandle(_metadataReader) + .GetModifiedType(_metadataReader) + .Type; + } } // startOffset lets us backtrack to the TypeSignatureKind for external types since the TypeLoader diff --git a/src/coreclr/pal/inc/unixasmmacros.inc b/src/coreclr/pal/inc/unixasmmacros.inc index 658a65bb4b35aa..120b26543e3faa 100644 --- a/src/coreclr/pal/inc/unixasmmacros.inc +++ b/src/coreclr/pal/inc/unixasmmacros.inc @@ -3,6 +3,11 @@ #define INVALIDGCVALUE 0xCCCCCCCD +// Enforce subsections via symbols to workaround bugs in Xcode 15 linker. +#if defined(__APPLE__) +.subsections_via_symbols +#endif + #if defined(__APPLE__) #define C_FUNC(name) _##name #define EXTERNAL_C_FUNC(name) C_FUNC(name) diff --git a/src/coreclr/tools/Common/CommandLineHelpers.cs b/src/coreclr/tools/Common/CommandLineHelpers.cs index 205592c1c91dca..3fb977a3047a6a 100644 --- a/src/coreclr/tools/Common/CommandLineHelpers.cs +++ b/src/coreclr/tools/Common/CommandLineHelpers.cs @@ -210,7 +210,7 @@ public static void MakeReproPackage(string makeReproPath, string outputFilePath, foreach (CliOption option in res.CommandResult.Command.Options) { OptionResult optionResult = res.GetResult(option); - if (optionResult is null || option.Name == "make-repro-path") + if (optionResult is null || option.Name == "--make-repro-path") { continue; } @@ -233,7 +233,7 @@ public static void MakeReproPackage(string makeReproPath, string outputFilePath, } foreach (string inputFile in dictionary.Values) { - rspFile.Add($"--{option.Name}:{ConvertFromOriginalPathToReproPackagePath(input: true, inputFile)}"); + rspFile.Add($"{option.Name}:{ConvertFromOriginalPathToReproPackagePath(input: true, inputFile)}"); } } else @@ -241,7 +241,7 @@ public static void MakeReproPackage(string makeReproPath, string outputFilePath, foreach (string optInList in values) { if (!string.IsNullOrEmpty(optInList)) - rspFile.Add($"--{option.Name}:{optInList}"); + rspFile.Add($"{option.Name}:{optInList}"); } } } @@ -254,11 +254,11 @@ public static void MakeReproPackage(string makeReproPath, string outputFilePath, // if output option is used, overwrite the path to the repro package stringVal = ConvertFromOriginalPathToReproPackagePath(input: false, stringVal); } - rspFile.Add($"--{option.Name}:{stringVal}"); + rspFile.Add($"{option.Name}:{stringVal}"); } else { - rspFile.Add($"--{option.Name}:{val}"); + rspFile.Add($"{option.Name}:{val}"); } } } diff --git a/src/coreclr/tools/Common/TypeSystem/Common/MetadataVirtualMethodAlgorithm.cs b/src/coreclr/tools/Common/TypeSystem/Common/MetadataVirtualMethodAlgorithm.cs index 8d2b8a4e3fd3fe..330296b18dbfa4 100644 --- a/src/coreclr/tools/Common/TypeSystem/Common/MetadataVirtualMethodAlgorithm.cs +++ b/src/coreclr/tools/Common/TypeSystem/Common/MetadataVirtualMethodAlgorithm.cs @@ -614,6 +614,8 @@ private static MethodDesc ResolveInterfaceMethodToVirtualMethodOnType(MethodDesc { Debug.Assert(!interfaceMethod.Signature.IsStatic); + // This would be a default interface method resolution. The algorithm below would sort of work, but doesn't handle + // things like diamond cases and it's better not to let it resolve as such. if (currentType.IsInterface) return null; @@ -781,7 +783,7 @@ private static DefaultInterfaceMethodResolution ResolveInterfaceMethodToDefaultI // If we're asking about an interface, include the interface in the list. consideredInterfaces = new DefType[currentType.RuntimeInterfaces.Length + 1]; Array.Copy(currentType.RuntimeInterfaces, consideredInterfaces, currentType.RuntimeInterfaces.Length); - consideredInterfaces[consideredInterfaces.Length - 1] = (DefType)currentType.InstantiateAsOpen(); + consideredInterfaces[consideredInterfaces.Length - 1] = currentType.IsGenericDefinition ? (DefType)currentType.InstantiateAsOpen() : currentType; } foreach (MetadataType runtimeInterface in consideredInterfaces) @@ -921,6 +923,11 @@ public static IEnumerable EnumAllVirtualSlots(MetadataType type) /// MethodDesc of the resolved virtual static method, null when not found (runtime lookup must be used) public static MethodDesc ResolveInterfaceMethodToStaticVirtualMethodOnType(MethodDesc interfaceMethod, MetadataType currentType) { + // This would be a default interface method resolution. The algorithm below would sort of work, but doesn't handle + // things like diamond cases and it's better not to let it resolve as such. + if (currentType.IsInterface) + return null; + // Search for match on a per-level in the type hierarchy for (MetadataType typeToCheck = currentType; typeToCheck != null; typeToCheck = typeToCheck.MetadataBaseType) { diff --git a/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs b/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs index c791a509f02a99..107170e743fbc7 100644 --- a/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs +++ b/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs @@ -68,7 +68,7 @@ protected override object CreateObject(IntPtr externalComObject, CreateObjectFla Debug.Assert(flags == CreateObjectFlags.UniqueInstance); var iid = ICLRMetaHost.IID; - if (Marshal.QueryInterface(externalComObject, ref iid, out IntPtr hostPtr) != 0) + if (Marshal.QueryInterface(externalComObject, in iid, out IntPtr hostPtr) != 0) { throw new ArgumentException("Expected ICLRMetaHost COM interface"); } @@ -284,7 +284,7 @@ private CoCreateWrapperCache() { } Debug.Assert(flags == CreateObjectFlags.UniqueInstance); var iid = new Guid("AA544D42-28CB-11d3-BD22-0000F80849BD"); - if (Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ppv) != 0) + if (Marshal.QueryInterface(externalComObject, in iid, out IntPtr ppv) != 0) { return null; } diff --git a/src/coreclr/tools/aot/Directory.Build.props b/src/coreclr/tools/aot/Directory.Build.props deleted file mode 100644 index 5a5e0e9914b730..00000000000000 --- a/src/coreclr/tools/aot/Directory.Build.props +++ /dev/null @@ -1,6 +0,0 @@ - - - - true - - diff --git a/src/coreclr/tools/aot/Directory.Build.targets b/src/coreclr/tools/aot/Directory.Build.targets new file mode 100644 index 00000000000000..4f855d71288f72 --- /dev/null +++ b/src/coreclr/tools/aot/Directory.Build.targets @@ -0,0 +1,6 @@ + + + + true + + diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Dataflow/CompilerGeneratedState.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Dataflow/CompilerGeneratedState.cs index 8150abeca61b1e..6a5e2ed77cdf08 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Dataflow/CompilerGeneratedState.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/Dataflow/CompilerGeneratedState.cs @@ -158,8 +158,10 @@ referencedMethod.OwningType is MetadataType generatedType && break; case ILOpcode.stsfld: + case ILOpcode.ldsfld: { // Same as above, but stsfld instead of a call to the constructor + // Ldsfld may also trigger a cctor that creates a closure environment FieldDesc? field = methodBody.GetObject(reader.ReadILToken()) as FieldDesc; if (field == null) continue; @@ -417,6 +419,7 @@ void MapGeneratedTypeTypeParameters( break; case ILOpcode.stsfld: + case ILOpcode.ldsfld: { if (body.GetObject(reader.ReadILToken()) is FieldDesc { OwningType: MetadataType owningType } && compilerGeneratedType == owningType.GetTypeDefinition()) diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs index 1e9f7679b44cc3..eb64be16023150 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs @@ -372,16 +372,8 @@ public sealed override IEnumerable GetConditionalSt DefType defType = _type.GetClosestDefType(); - // Interfaces don't have vtables and we don't need to track their slot use. - // The only exception are those interfaces that provide IDynamicInterfaceCastable implementations; - // those have slots and we dispatch on them. - bool needsDependenciesForVirtualMethodImpls = !defType.IsInterface - || ((MetadataType)defType).IsDynamicInterfaceCastableImplementation(); - // If we're producing a full vtable, none of the dependencies are conditional. - needsDependenciesForVirtualMethodImpls &= !factory.VTable(defType).HasFixedSlots; - - if (needsDependenciesForVirtualMethodImpls) + if (!factory.VTable(defType).HasFixedSlots) { bool isNonInterfaceAbstractType = !defType.IsInterface && ((MetadataType)defType).IsAbstract; @@ -436,13 +428,23 @@ public sealed override IEnumerable GetConditionalSt ((System.Collections.IStructuralEquatable)defType.RuntimeInterfaces).Equals(_type.RuntimeInterfaces, EqualityComparer.Default)); + // Interfaces don't have vtables and we don't need to track their instance method slot use. + // The only exception are those interfaces that provide IDynamicInterfaceCastable implementations; + // those have slots and we dispatch on them. + bool needsDependenciesForInstanceInterfaceMethodImpls = !defType.IsInterface + || ((MetadataType)defType).IsDynamicInterfaceCastableImplementation(); + // Add conditional dependencies for interface methods the type implements. For example, if the type T implements // interface IFoo which has a method M1, add a dependency on T.M1 dependent on IFoo.M1 being called, since it's // possible for any IFoo object to actually be an instance of T. + DefType defTypeDefinition = (DefType)defType.GetTypeDefinition(); DefType[] defTypeRuntimeInterfaces = defType.RuntimeInterfaces; + DefType[] defTypeDefinitionRuntimeInterfaces = defTypeDefinition.RuntimeInterfaces; + Debug.Assert(defTypeDefinitionRuntimeInterfaces.Length == defTypeRuntimeInterfaces.Length); for (int interfaceIndex = 0; interfaceIndex < defTypeRuntimeInterfaces.Length; interfaceIndex++) { DefType interfaceType = defTypeRuntimeInterfaces[interfaceIndex]; + DefType interfaceDefinitionType = defTypeDefinitionRuntimeInterfaces[interfaceIndex]; Debug.Assert(interfaceType.IsInterface); @@ -456,11 +458,25 @@ public sealed override IEnumerable GetConditionalSt bool isStaticInterfaceMethod = interfaceMethod.Signature.IsStatic; + if (!isStaticInterfaceMethod && !needsDependenciesForInstanceInterfaceMethodImpls) + continue; + + MethodDesc interfaceMethodDefinition = interfaceMethod; + if (interfaceType != interfaceDefinitionType) + interfaceMethodDefinition = factory.TypeSystemContext.GetMethodForInstantiatedType(interfaceMethodDefinition.GetTypicalMethodDefinition(), (InstantiatedType)interfaceDefinitionType); + MethodDesc implMethod = isStaticInterfaceMethod ? - defType.ResolveInterfaceMethodToStaticVirtualMethodOnType(interfaceMethod) : - defType.ResolveInterfaceMethodToVirtualMethodOnType(interfaceMethod); + defTypeDefinition.ResolveInterfaceMethodToStaticVirtualMethodOnType(interfaceMethodDefinition) : + defTypeDefinition.ResolveInterfaceMethodToVirtualMethodOnType(interfaceMethodDefinition); if (implMethod != null) { + TypeDesc implType = defType; + while (!implType.HasSameTypeDefinition(implMethod.OwningType)) + implType = implType.BaseType; + + if (!implType.IsTypeDefinition) + implMethod = factory.TypeSystemContext.GetMethodForInstantiatedType(implMethod.GetTypicalMethodDefinition(), (InstantiatedType)implType); + if (isStaticInterfaceMethod) { Debug.Assert(!implMethod.IsVirtual); @@ -499,12 +515,7 @@ public sealed override IEnumerable GetConditionalSt // Is the implementation provided by a default interface method? // If so, add a dependency on the entrypoint directly since nobody else is going to do that // (interface types have an empty vtable, modulo their generic dictionary). - TypeDesc interfaceOnDefinition = defType.GetTypeDefinition().RuntimeInterfaces[interfaceIndex]; - MethodDesc interfaceMethodDefinition = interfaceMethod; - if (!interfaceType.IsTypeDefinition) - interfaceMethodDefinition = factory.TypeSystemContext.GetMethodForInstantiatedType(interfaceMethod.GetTypicalMethodDefinition(), (InstantiatedType)interfaceOnDefinition); - - var resolution = defType.GetTypeDefinition().ResolveInterfaceMethodToDefaultImplementationOnType(interfaceMethodDefinition, out implMethod); + var resolution = defTypeDefinition.ResolveInterfaceMethodToDefaultImplementationOnType(interfaceMethodDefinition, out implMethod); if (resolution == DefaultInterfaceMethodResolution.DefaultImplementation) { DefType providingInterfaceDefinitionType = (DefType)implMethod.OwningType; diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/InterfaceDispatchMapNode.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/InterfaceDispatchMapNode.cs index 38104d7ab015c9..c2ac4568c748c9 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/InterfaceDispatchMapNode.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/InterfaceDispatchMapNode.cs @@ -73,7 +73,7 @@ public static bool MightHaveInterfaceDispatchMap(TypeDesc type, NodeFactory fact if (!type.IsArray && !type.IsDefType) return false; - // Interfaces don't have a dispatch map because we dispatch them based on the + // Interfaces don't have a dispatch map for instance methods because we dispatch them based on the // dispatch map of the implementing class. // The only exception are IDynamicInterfaceCastable scenarios that dispatch // using the interface dispatch map. @@ -83,8 +83,9 @@ public static bool MightHaveInterfaceDispatchMap(TypeDesc type, NodeFactory fact // wasn't marked as [DynamicInterfaceCastableImplementation]" and "we couldn't find an // implementation". We don't want to use the custom attribute for that at runtime because // that's reflection and this should work without reflection. - if (type.IsInterface) - return ((MetadataType)type).IsDynamicInterfaceCastableImplementation(); + bool isInterface = type.IsInterface; + if (isInterface && ((MetadataType)type).IsDynamicInterfaceCastableImplementation()) + return true; DefType declType = type.GetClosestDefType(); @@ -112,6 +113,11 @@ public static bool MightHaveInterfaceDispatchMap(TypeDesc type, NodeFactory fact Debug.Assert(declMethod.IsVirtual); + // Only static methods get placed in dispatch maps of interface types (modulo + // IDynamicInterfaceCastable we already handled above). + if (isInterface && !declMethod.Signature.IsStatic) + continue; + if (interfaceOnDefinitionType != null) declMethod = factory.TypeSystemContext.GetMethodForInstantiatedType(declMethod.GetTypicalMethodDefinition(), interfaceOnDefinitionType); @@ -154,6 +160,10 @@ private void EmitDispatchMap(ref ObjectDataBuilder builder, NodeFactory factory) var staticImplementations = new List<(int InterfaceIndex, int InterfaceMethodSlot, int ImplMethodSlot, int Context)>(); var staticDefaultImplementations = new List<(int InterfaceIndex, int InterfaceMethodSlot, int ImplMethodSlot, int Context)>(); + bool isInterface = declType.IsInterface; + bool needsEntriesForInstanceInterfaceMethodImpls = !isInterface + || ((MetadataType)declType).IsDynamicInterfaceCastableImplementation(); + // Resolve all the interfaces, but only emit non-static and non-default implementations for (int interfaceIndex = 0; interfaceIndex < declTypeRuntimeInterfaces.Length; interfaceIndex++) { @@ -166,6 +176,10 @@ private void EmitDispatchMap(ref ObjectDataBuilder builder, NodeFactory factory) for (int interfaceMethodSlot = 0; interfaceMethodSlot < virtualSlots.Count; interfaceMethodSlot++) { MethodDesc declMethod = virtualSlots[interfaceMethodSlot]; + + if (!declMethod.Signature.IsStatic && !needsEntriesForInstanceInterfaceMethodImpls) + continue; + if(!interfaceType.IsTypeDefinition) declMethod = factory.TypeSystemContext.GetMethodForInstantiatedType(declMethod.GetTypicalMethodDefinition(), (InstantiatedType)interfaceDefinitionType); @@ -244,9 +258,17 @@ private void EmitDispatchMap(ref ObjectDataBuilder builder, NodeFactory factory) // For default interface methods, the generic context is acquired by indexing // into the interface list of the owning type. Debug.Assert(providingInterfaceDefinitionType != null); - int indexOfInterface = Array.IndexOf(declTypeDefinitionRuntimeInterfaces, providingInterfaceDefinitionType); - Debug.Assert(indexOfInterface >= 0); - genericContext = StaticVirtualMethodContextSource.ContextFromFirstInterface + indexOfInterface; + if (declTypeDefinition.HasSameTypeDefinition(providingInterfaceDefinitionType) && + providingInterfaceDefinitionType == declTypeDefinition.InstantiateAsOpen()) + { + genericContext = StaticVirtualMethodContextSource.ContextFromThisClass; + } + else + { + int indexOfInterface = Array.IndexOf(declTypeDefinitionRuntimeInterfaces, providingInterfaceDefinitionType); + Debug.Assert(indexOfInterface >= 0); + genericContext = StaticVirtualMethodContextSource.ContextFromFirstInterface + indexOfInterface; + } } staticDefaultImplementations.Add(( interfaceIndex, diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/SealedVTableNode.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/SealedVTableNode.cs index bb67f884264dd3..a8460e80d0b413 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/SealedVTableNode.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/SealedVTableNode.cs @@ -108,17 +108,21 @@ public bool BuildSealedVTableSlots(NodeFactory factory, bool relocsOnly) _sealedVTableEntries = new List(); - // Interfaces don't have any virtual slots with the exception of interfaces that provide + // Interfaces don't have any instance virtual slots with the exception of interfaces that provide // IDynamicInterfaceCastable implementation. // Normal interface don't need one because the dispatch is done at the class level. // For IDynamicInterfaceCastable, we don't have an implementing class. - if (_type.IsInterface && !((MetadataType)_type).IsDynamicInterfaceCastableImplementation()) - return true; + bool isInterface = declType.IsInterface; + bool needsEntriesForInstanceInterfaceMethodImpls = !isInterface + || ((MetadataType)declType).IsDynamicInterfaceCastableImplementation(); IReadOnlyList virtualSlots = factory.VTable(declType).Slots; for (int i = 0; i < virtualSlots.Count; i++) { + if (!virtualSlots[i].Signature.IsStatic && !needsEntriesForInstanceInterfaceMethodImpls) + continue; + MethodDesc implMethod = declType.FindVirtualFunctionTargetMethodOnObjectType(virtualSlots[i]); if (implMethod.CanMethodBeInSealedVTable()) @@ -143,6 +147,10 @@ public bool BuildSealedVTableSlots(NodeFactory factory, bool relocsOnly) for (int interfaceMethodSlot = 0; interfaceMethodSlot < virtualSlots.Count; interfaceMethodSlot++) { MethodDesc declMethod = virtualSlots[interfaceMethodSlot]; + + if (!declMethod.Signature.IsStatic && !needsEntriesForInstanceInterfaceMethodImpls) + continue; + if (!interfaceType.IsTypeDefinition) declMethod = factory.TypeSystemContext.GetMethodForInstantiatedType(declMethod.GetTypicalMethodDefinition(), (InstantiatedType)interfaceDefinitionType); diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs index 038c1ee1f38dc0..ffd286cce759f2 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs @@ -11,11 +11,11 @@ namespace ILCompiler { public class RootingHelpers { - public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc type, string reason) + public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc type, bool rootBaseTypes, string reason) { try { - RootType(rootProvider, type, reason); + RootType(rootProvider, type, rootBaseTypes, reason); return true; } catch (TypeSystemException) @@ -24,7 +24,7 @@ public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc ty } } - public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, string reason) + public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, bool rootBaseTypes, string reason) { rootProvider.AddReflectionRoot(type, reason); @@ -40,13 +40,13 @@ public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, rootProvider.AddReflectionRoot(type, reason); } - // Also root base types. This is so that we make methods on the base types callable. - // This helps in cases like "class Foo : Bar { }" where we discover new - // generic instantiations. - TypeDesc baseType = type.BaseType; - if (baseType != null) + if (rootBaseTypes) { - RootType(rootProvider, baseType.NormalizeInstantiation(), reason); + TypeDesc baseType = type.BaseType; + if (baseType != null) + { + RootType(rootProvider, baseType.NormalizeInstantiation(), rootBaseTypes, reason); + } } if (type.IsDefType) diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs index e680bf80f2dfa7..bd3e1069881361 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs @@ -351,7 +351,7 @@ protected override void GetMetadataDependenciesDueToReflectability(ref Dependenc var rootProvider = new RootingServiceProvider(factory, dependencies.Add); foreach (TypeDesc t in mdType.Module.GetAllTypes()) { - RootingHelpers.TryRootType(rootProvider, t, reason); + RootingHelpers.TryRootType(rootProvider, t, rootBaseTypes: false, reason); } } } @@ -1093,7 +1093,7 @@ private void ProcessAttribute(TypeDesc type, XPathNavigator nav) string internalValue = GetAttribute(nav, "internal"); if (!string.IsNullOrEmpty(internalValue)) { - if (!IsRemoveAttributeInstances(internalValue) || !nav.IsEmptyElement) + if (!IsRemoveAttributeInstances(internalValue)) { LogWarning(nav, DiagnosticId.UnrecognizedInternalAttribute, internalValue); } diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/VirtualMethodCallHelper.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/VirtualMethodCallHelper.cs index a9de1fce5e3a52..ec398d37433966 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/VirtualMethodCallHelper.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/VirtualMethodCallHelper.cs @@ -93,9 +93,8 @@ private static int GetNumberOfSlotsInCurrentType(NodeFactory factory, TypeDesc i { if (implType.IsInterface) { - // We normally don't need to ask about vtable slots of interfaces. It's not wrong to ask - // that question, but we currently only ask it for IDynamicInterfaceCastable implementations. - Debug.Assert(((MetadataType)implType).IsDynamicInterfaceCastableImplementation()); + // Interface types don't have physically assigned virtual slots, so the number of slots + // is always 0. They may have sealed slots. return (implType.HasGenericDictionarySlot() && countDictionarySlots) ? 1 : 0; } diff --git a/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs b/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs index 2c6b01849db031..6d263eddc1eb5f 100644 --- a/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs +++ b/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs @@ -71,7 +71,7 @@ private void ProcessAssemblyDirective(IRootingServiceProvider rootProvider, XEle foreach (TypeDesc type in ((EcmaModule)assembly).GetAllTypes()) { - RootingHelpers.TryRootType(rootProvider, type, "RD.XML root"); + RootingHelpers.TryRootType(rootProvider, type, rootBaseTypes: true, "RD.XML root"); } } @@ -103,7 +103,7 @@ private static void ProcessTypeDirective(IRootingServiceProvider rootProvider, M if (dynamicDegreeAttribute.Value != "Required All") throw new NotSupportedException($"\"{dynamicDegreeAttribute.Value}\" is not a supported value for the \"Dynamic\" attribute of the \"Type\" Runtime Directive. Supported values are \"Required All\"."); - RootingHelpers.RootType(rootProvider, type, "RD.XML root"); + RootingHelpers.RootType(rootProvider, type, rootBaseTypes: true, "RD.XML root"); } var marshalStructureDegreeAttribute = typeElement.Attribute("MarshalStructure"); diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs index 07147b1110323d..57c3de36dd7574 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs @@ -29,7 +29,17 @@ public static IEnumerable Generics () return TestNamesBySuiteName (); } - public static IEnumerable LinkXml() + public static IEnumerable InlineArrays () + { + return TestNamesBySuiteName(); + } + + public static IEnumerable Libraries() + { + return TestNamesBySuiteName(); + } + + public static IEnumerable LinkXml() { return TestNamesBySuiteName(); } diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs index 883582cc5d9df2..f8d7fb5c69d02b 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs @@ -31,6 +31,20 @@ public void Generics (string t) } [Theory] + [MemberData(nameof(TestDatabase.InlineArrays), MemberType = typeof(TestDatabase))] + public void InlineArrays(string t) + { + Run(t); + } + + [Theory] + [MemberData(nameof(TestDatabase.Libraries), MemberType = typeof(TestDatabase))] + public void Libraries(string t) + { + Run(t); + } + + [Theory] [MemberData (nameof (TestDatabase.LinkXml), MemberType = typeof (TestDatabase))] public void LinkXml (string t) { diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs index 7280796f5abc4e..393837de75f076 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs @@ -123,9 +123,12 @@ public void Verify () throw new NotImplementedException ($"Don't know how to check member of type {originalMember.GetType ()}"); } - // Filter out all members which are not from the main assembly - // The Kept attributes are "optional" for non-main assemblies - string mainModuleName = originalAssembly.Name.Name; + // Verify anything not in the main assembly + VerifyLinkingOfOtherAssemblies(this.originalAssembly); + + // Filter out all members which are not from the main assembly + // The Kept attributes are "optional" for non-main assemblies + string mainModuleName = originalAssembly.Name.Name; List externalMembers = linkedMembers.Where (m => GetModuleName (m.Value.Entity) != mainModuleName).Select (m => m.Key).ToList (); foreach (var externalMember in externalMembers) { linkedMembers.Remove (externalMember); @@ -136,7 +139,7 @@ public void Verify () false, "Linked output includes unexpected member:\n " + string.Join ("\n ", linkedMembers.Values.Select (e => e.Entity.GetDisplayName ()))); - } + } static bool IsCompilerGeneratedMemberName (string memberName) { @@ -304,12 +307,23 @@ static bool ShouldIncludeType (TypeDesc type) static bool ShouldIncludeMethod (MethodDesc method) => ShouldIncludeType (method.OwningType) && ShouldIncludeEntityByDisplayName (method); } + private static MetadataType? GetOwningType (TypeSystemEntity? entity) + { + return entity switch + { + MetadataType type => type.ContainingType as MetadataType, + MethodDesc method => method.OwningType as MetadataType, + PropertyPseudoDesc prop => prop.OwningType, + EventPseudoDesc e => e.OwningType, + _ => null + }; + } + private static string? GetModuleName (TypeSystemEntity entity) { return entity switch { MetadataType type => type.Module.ToString (), - MethodDesc { OwningType: MetadataType owningType } => owningType.Module.ToString (), - _ => null + _ => GetOwningType(entity)?.Module.ToString() }; } @@ -1338,38 +1352,38 @@ private static bool HasActiveKeptDerivedAttribute (ICustomAttributeProvider prov return GetActiveKeptDerivedAttributes (provider).Any (); } - private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) + internal void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) { var checks = BuildOtherAssemblyCheckTable (original); - // TODO - // For now disable the code below by simply removing all checks - checks.Clear (); - try { foreach (var assemblyName in checks.Keys) { - var linkedAssembly = ResolveLinkedAssembly (assemblyName); + var linkedMembersInAssembly = ResolveLinkedMembersForAssembly (assemblyName); + var originalTargetAssembly = ResolveOriginalsAssembly(assemblyName); foreach (var checkAttrInAssembly in checks[assemblyName]) { var attributeTypeName = checkAttrInAssembly.AttributeType.Name; switch (attributeTypeName) { case nameof (KeptAllTypesAndMembersInAssemblyAttribute): - VerifyKeptAllTypesAndMembersInAssembly (linkedAssembly); + VerifyKeptAllTypesAndMembersInAssembly (assemblyName, linkedMembersInAssembly); continue; case nameof (KeptAttributeInAssemblyAttribute): - VerifyKeptAttributeInAssembly (checkAttrInAssembly, linkedAssembly); + // VerifyKeptAttributeInAssembly (checkAttrInAssembly, linkedAssembly); continue; case nameof (RemovedAttributeInAssembly): - VerifyRemovedAttributeInAssembly (checkAttrInAssembly, linkedAssembly); + // VerifyRemovedAttributeInAssembly (checkAttrInAssembly, linkedAssembly); continue; default: break; } - var expectedTypeName = checkAttrInAssembly.ConstructorArguments[1].Value.ToString ()!; - TypeDefinition? linkedType = linkedAssembly.MainModule.GetType (expectedTypeName); + var expectedTypeName = checkAttrInAssembly.ConstructorArguments[1].Value.ToString ()!; + var expectedType = originalTargetAssembly.MainModule.GetType(expectedTypeName); + linkedMembersInAssembly.TryGetValue(new AssemblyQualifiedToken(expectedType), out LinkedEntity? linkedTypeEntity); + MetadataType? linkedType = linkedTypeEntity?.Entity as MetadataType; - if (linkedType == null && linkedAssembly.MainModule.HasExportedTypes) { +#if false + if (linkedType == null && linkedAssembly.MainModule.HasExportedTypes) { ExportedType? exportedType = linkedAssembly.MainModule.ExportedTypes .FirstOrDefault (exported => exported.FullName == expectedTypeName); @@ -1381,6 +1395,7 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) linkedType = exportedType?.Resolve (); } +#endif switch (attributeTypeName) { case nameof (RemovedTypeInAssemblyAttribute): @@ -1392,6 +1407,7 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) if (linkedType == null) Assert.Fail ($"Type `{expectedTypeName}' should have been kept in assembly {assemblyName}"); break; +#if false case nameof (RemovedInterfaceOnTypeInAssemblyAttribute): if (linkedType == null) Assert.Fail ($"Type `{expectedTypeName}' should have been kept in assembly {assemblyName}"); @@ -1444,11 +1460,15 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) Assert.Fail ($"Type `{expectedTypeName}` should have been kept in assembly {assemblyName}"); VerifyExpectedInstructionSequenceOnMemberInAssembly (checkAttrInAssembly, linkedType); break; - default: + default: UnhandledOtherAssemblyAssertion (expectedTypeName, checkAttrInAssembly, linkedType); break; - } - } +#else + default: + break; +#endif + } + } } } catch (AssemblyResolutionException e) { Assert.Fail ($"Failed to resolve linked assembly `{e.AssemblyReference.Name}`. It must not exist in the output."); @@ -1740,54 +1760,62 @@ protected virtual bool TryVerifyKeptMemberInAssemblyAsMethod (string memberName, private void VerifyKeptReferencesInAssembly (CustomAttribute inAssemblyAttribute) { - var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); +#if false + var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); var expectedReferenceNames = ((CustomAttributeArgument[]) inAssemblyAttribute.ConstructorArguments[1].Value).Select (attr => (string) attr.Value).ToList (); for (int i = 0; i < expectedReferenceNames.Count; i++) if (expectedReferenceNames[i].EndsWith (".dll")) expectedReferenceNames[i] = expectedReferenceNames[i].Substring (0, expectedReferenceNames[i].LastIndexOf (".")); Assert.Equal (assembly.MainModule.AssemblyReferences.Select (asm => asm.Name), expectedReferenceNames); +#endif } private void VerifyKeptResourceInAssembly (CustomAttribute inAssemblyAttribute) { - var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); +#if false + var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); var resourceName = inAssemblyAttribute.ConstructorArguments[1].Value.ToString (); Assert.Contains (resourceName, assembly.MainModule.Resources.Select (r => r.Name)); +#endif } private void VerifyRemovedResourceInAssembly (CustomAttribute inAssemblyAttribute) { - var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); +#if false + var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); var resourceName = inAssemblyAttribute.ConstructorArguments[1].Value.ToString (); Assert.DoesNotContain (resourceName, assembly.MainModule.Resources.Select (r => r.Name)); +#endif } - private void VerifyKeptAllTypesAndMembersInAssembly (AssemblyDefinition linked) + private void VerifyKeptAllTypesAndMembersInAssembly (string assemblyName, Dictionary linkedMembers) { - var original = ResolveOriginalsAssembly (linked.MainModule.Assembly.Name.Name); + var original = ResolveOriginalsAssembly (assemblyName); if (original == null) - Assert.Fail ($"Failed to resolve original assembly {linked.MainModule.Assembly.Name.Name}"); + Assert.Fail ($"Failed to resolve original assembly {assemblyName}"); - var originalTypes = original.AllDefinedTypes ().ToDictionary (t => t.FullName); - var linkedTypes = linked.AllDefinedTypes ().ToDictionary (t => t.FullName); + var originalTypes = original.AllDefinedTypes ().ToDictionary (t => new AssemblyQualifiedToken(t)); + var linkedTypes = linkedMembers.Where(t => t.Value.Entity is TypeDesc).ToDictionary(); var missingInLinked = originalTypes.Keys.Except (linkedTypes.Keys); - Assert.True (missingInLinked.Any (), $"Expected all types to exist in the linked assembly, but one or more were missing"); + Assert.False (missingInLinked.Any (), $"Expected all types to exist in the linked assembly {assemblyName}, but one or more were missing"); foreach (var originalKvp in originalTypes) { var linkedType = linkedTypes[originalKvp.Key]; + TypeDesc linkedTypeDesc = (TypeDesc)linkedType.Entity; - var originalMembers = originalKvp.Value.AllMembers ().Select (m => m.FullName); - var linkedMembers = linkedType.AllMembers ().Select (m => m.FullName); + // NativeAOT field trimming is very different (it basically doesn't trim fields, not in the same way trimmer does) + var originalMembers = originalKvp.Value.AllMembers ().Where(m => m is not FieldDefinition).Select (m => new AssemblyQualifiedToken(m)); + var linkedMembersOnType = linkedMembers.Where(t => GetOwningType(t.Value.Entity) == linkedTypeDesc).Select(t => t.Key); - var missingMembersInLinked = originalMembers.Except (linkedMembers); + var missingMembersInLinked = originalMembers.Except (linkedMembersOnType); - Assert.True (missingMembersInLinked.Any (), $"Expected all members of `{originalKvp.Key}`to exist in the linked assembly, but one or more were missing"); + Assert.False (missingMembersInLinked.Any (), $"Expected all members of `{linkedTypeDesc.GetDisplayName()}`to exist in the linked assembly, but one or more were missing"); } } @@ -1823,6 +1851,11 @@ private static Dictionary> BuildOtherAssemblyCheck foreach (var typeWithRemoveInAssembly in original.AllDefinedTypes ()) { foreach (var attr in typeWithRemoveInAssembly.CustomAttributes.Where (IsTypeInOtherAssemblyAssertion)) { var assemblyName = (string) attr.ConstructorArguments[0].Value; + + Tool? toolTarget = (Tool?)(int?)attr.GetPropertyValue("Tool"); + if (toolTarget is not null && !toolTarget.Value.HasFlag(Tool.NativeAot)) + continue; + if (!checks.TryGetValue (assemblyName, out List? checksForAssembly)) checks[assemblyName] = checksForAssembly = new List (); @@ -1833,14 +1866,13 @@ private static Dictionary> BuildOtherAssemblyCheck return checks; } - protected AssemblyDefinition ResolveLinkedAssembly (string assemblyName) + private Dictionary ResolveLinkedMembersForAssembly (string assemblyName) { - //var cleanAssemblyName = assemblyName; - //if (assemblyName.EndsWith (".exe") || assemblyName.EndsWith (".dll")) - //cleanAssemblyName = System.IO.Path.GetFileNameWithoutExtension (assemblyName); - //return _linkedResolver.Resolve (new AssemblyNameReference (cleanAssemblyName, null), _linkedReaderParameters); - // TODO - adapt to Native AOT - return ResolveOriginalsAssembly (assemblyName); + var cleanAssemblyName = assemblyName; + if (assemblyName.EndsWith(".exe") || assemblyName.EndsWith(".dll")) + cleanAssemblyName = System.IO.Path.GetFileNameWithoutExtension(assemblyName); + + return this.linkedMembers.Where(e => GetModuleName(e.Value.Entity) == cleanAssemblyName).ToDictionary(); } protected AssemblyDefinition ResolveOriginalsAssembly (string assemblyName) diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs index 275d035c66843e..5fe3c9adf30d40 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs @@ -103,7 +103,7 @@ public ILScanResults Trim (ILCompilerOptions options, ILogWriter logWriter) new ManifestResourceBlockingPolicy (logger, options.FeatureSwitches, new Dictionary>()), logFile: null, new NoStackTraceEmissionPolicy (), - new NoDynamicInvokeThunkGenerationPolicy (), + new DefaultDynamicInvokeThunkGenerationPolicy (), new FlowAnnotations (logger, ilProvider, compilerGeneratedState), UsageBasedMetadataGenerationOptions.ReflectionILScanning, options: default, diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/TestCaseCompilationMetadataProvider.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/TestCaseCompilationMetadataProvider.cs index e97a51a5c2c737..6441776f33ad01 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/TestCaseCompilationMetadataProvider.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/TestCaseCompilationMetadataProvider.cs @@ -129,9 +129,11 @@ public virtual IEnumerable GetCommonReferencedAssemblies (NPath workingD yield return Path.Combine (referenceDir, "mscorlib.dll"); yield return Path.Combine (referenceDir, "System.Collections.dll"); + yield return Path.Combine (referenceDir, "System.Collections.Immutable.dll"); yield return Path.Combine (referenceDir, "System.ComponentModel.TypeConverter.dll"); yield return Path.Combine (referenceDir, "System.Console.dll"); yield return Path.Combine (referenceDir, "System.Linq.Expressions.dll"); + yield return Path.Combine (referenceDir, "System.Memory.dll"); yield return Path.Combine (referenceDir, "System.ObjectModel.dll"); yield return Path.Combine (referenceDir, "System.Runtime.dll"); yield return Path.Combine (referenceDir, "System.Runtime.Extensions.dll"); diff --git a/src/coreclr/utilcode/stresslog.cpp b/src/coreclr/utilcode/stresslog.cpp index c55c5afe9249c8..90ad5900473ed7 100644 --- a/src/coreclr/utilcode/stresslog.cpp +++ b/src/coreclr/utilcode/stresslog.cpp @@ -12,9 +12,9 @@ #include "switches.h" #include "stresslog.h" #include "clrhost.h" +#include "ex.h" #define DONOT_DEFINE_ETW_CALLBACK #include "eventtracebase.h" -#include "ex.h" #if !defined(STRESS_LOG_READONLY) #ifdef HOST_WINDOWS diff --git a/src/coreclr/vm/amd64/jithelpers_fast.S b/src/coreclr/vm/amd64/jithelpers_fast.S index 32890b471b26c1..72f91a18061579 100644 --- a/src/coreclr/vm/amd64/jithelpers_fast.S +++ b/src/coreclr/vm/amd64/jithelpers_fast.S @@ -32,16 +32,16 @@ LEAF_ENTRY JIT_CheckedWriteBarrier, _TEXT // See if this is in GCHeap PREPARE_EXTERNAL_VAR g_lowest_address, rax cmp rdi, [rax] - // jb NotInHeap + // jb LOCAL_LABEL(NotInHeap) .byte 0x72, 0x12 PREPARE_EXTERNAL_VAR g_highest_address, rax cmp rdi, [rax] - // jnb NotInHeap + // jnb LOCAL_LABEL(NotInHeap) .byte 0x73, 0x06 jmp [rip + C_FUNC(JIT_WriteBarrier_Loc)] - NotInHeap: + LOCAL_LABEL(NotInHeap): // See comment above about possible AV mov [rdi], rsi ret @@ -85,16 +85,16 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT add rax, r10 cmp byte ptr [rax], 0x0 .byte 0x75, 0x06 - // jne CheckCardTable + // jne LOCAL_LABEL(CheckCardTable) mov byte ptr [rax], 0xFF NOP_3_BYTE // padding for alignment of constant // Check the lower and upper ephemeral region bounds - CheckCardTable: + LOCAL_LABEL(CheckCardTable): cmp rsi, r11 .byte 0x72,0x3D - // jb Exit + // jb LOCAL_LABEL(Exit) NOP_3_BYTE // padding for alignment of constant @@ -102,7 +102,7 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT cmp rsi, r10 .byte 0x73,0x2B - // jae Exit + // jae LOCAL_LABEL(Exit) nop // padding for alignment of constant @@ -112,10 +112,10 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT shr rdi, 0x0B cmp byte ptr [rdi + rax], 0xFF .byte 0x75, 0x02 - // jne UpdateCardTable + // jne LOCAL_LABEL(UpdateCardTable) REPRET - UpdateCardTable: + LOCAL_LABEL(UpdateCardTable): mov byte ptr [rdi + rax], 0xFF #ifdef FEATURE_MANUALLY_MANAGED_CARD_BUNDLES @@ -126,17 +126,17 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT cmp byte ptr [rdi + rax], 0xFF .byte 0x75, 0x02 - // jne UpdateCardBundle_WriteWatch_PostGrow64 + // jne LOCAL_LABEL(UpdateCardBundle_WriteWatch_PostGrow64) REPRET - UpdateCardBundle_WriteWatch_PostGrow64: + LOCAL_LABEL(UpdateCardBundle_WriteWatch_PostGrow64): mov byte ptr [rdi + rax], 0xFF #endif ret .balign 16 - Exit: + LOCAL_LABEL(Exit): REPRET NOP_3_BYTE @@ -184,7 +184,7 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT // Check the lower and upper ephemeral region bounds cmp rsi, rax - // jb Exit + // jb LOCAL_LABEL(Exit) .byte 0x72, 0x36 nop // padding for alignment of constant @@ -192,7 +192,7 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT movabs r8, 0xF0F0F0F0F0F0F0F0 cmp rsi, r8 - // jae Exit + // jae LOCAL_LABEL(Exit) .byte 0x73, 0x26 nop // padding for alignment of constant @@ -203,10 +203,10 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT shr rdi, 0Bh cmp byte ptr [rdi + rax], 0FFh .byte 0x75, 0x02 - // jne UpdateCardTable + // jne LOCAL_LABEL(UpdateCardTable) REPRET - UpdateCardTable: + LOCAL_LABEL(UpdateCardTable): mov byte ptr [rdi + rax], 0FFh #ifdef FEATURE_MANUALLY_MANAGED_CARD_BUNDLES @@ -220,17 +220,17 @@ LEAF_ENTRY JIT_WriteBarrier, _TEXT cmp byte ptr [rdi + rax], 0FFh .byte 0x75, 0x02 - // jne UpdateCardBundle + // jne LOCAL_LABEL(UpdateCardBundle) REPRET - UpdateCardBundle: + LOCAL_LABEL(UpdateCardBundle): mov byte ptr [rdi + rax], 0FFh #endif ret .balign 16 - Exit: + LOCAL_LABEL(Exit): REPRET #endif @@ -277,30 +277,30 @@ LEAF_ENTRY JIT_ByRefWriteBarrier, _TEXT // See if this is in GCHeap PREPARE_EXTERNAL_VAR g_lowest_address, rax cmp rdi, [rax] - jb NotInHeap_ByRefWriteBarrier + jb LOCAL_LABEL(NotInHeap_ByRefWriteBarrier) PREPARE_EXTERNAL_VAR g_highest_address, rax cmp rdi, [rax] - jnb NotInHeap_ByRefWriteBarrier + jnb LOCAL_LABEL(NotInHeap_ByRefWriteBarrier) #ifdef WRITE_BARRIER_CHECK // **ALSO update the shadow GC heap if that is enabled** // Do not perform the work if g_GCShadow is 0 PREPARE_EXTERNAL_VAR g_GCShadow, rax cmp qword ptr [rax], 0 - je NoShadow_ByRefWriteBarrier + je LOCAL_LABEL(NoShadow_ByRefWriteBarrier) // If we end up outside of the heap don't corrupt random memory mov r10, rdi PREPARE_EXTERNAL_VAR g_lowest_address, rax sub r10, [rax] - jb NoShadow_ByRefWriteBarrier + jb LOCAL_LABEL(NoShadow_ByRefWriteBarrier) // Check that our adjusted destination is somewhere in the shadow gc PREPARE_EXTERNAL_VAR g_GCShadow, rax add r10, [rax] PREPARE_EXTERNAL_VAR g_GCShadowEnd, rax cmp r10, [rax] - jnb NoShadow_ByRefWriteBarrier + jnb LOCAL_LABEL(NoShadow_ByRefWriteBarrier) // Write ref into real GC mov [rdi], rcx @@ -315,73 +315,73 @@ LEAF_ENTRY JIT_ByRefWriteBarrier, _TEXT mov r11, [rdi] mov rax, [r10] cmp rax, r11 - je DoneShadow_ByRefWriteBarrier + je LOCAL_LABEL(DoneShadow_ByRefWriteBarrier) movabs r11, INVALIDGCVALUE mov [r10], r11 - jmp DoneShadow_ByRefWriteBarrier + jmp LOCAL_LABEL(DoneShadow_ByRefWriteBarrier) // If we don't have a shadow GC we won't have done the write yet - NoShadow_ByRefWriteBarrier: + LOCAL_LABEL(NoShadow_ByRefWriteBarrier): mov [rdi], rcx // If we had a shadow GC then we already wrote to the real GC at the same time // as the shadow GC so we want to jump over the real write immediately above. // Additionally we know for sure that we are inside the heap and therefore don't // need to replicate the above checks. - DoneShadow_ByRefWriteBarrier: + LOCAL_LABEL(DoneShadow_ByRefWriteBarrier): #endif #ifdef FEATURE_USE_SOFTWARE_WRITE_WATCH_FOR_GC_HEAP // Update the write watch table if necessary PREPARE_EXTERNAL_VAR g_sw_ww_enabled_for_gc_heap, rax cmp byte ptr [rax], 0x0 - je CheckCardTable_ByRefWriteBarrier + je LOCAL_LABEL(CheckCardTable_ByRefWriteBarrier) mov rax, rdi shr rax, 0xC // SoftwareWriteWatch::AddressToTableByteIndexShift PREPARE_EXTERNAL_VAR g_sw_ww_table, r10 add rax, qword ptr [r10] cmp byte ptr [rax], 0x0 - jne CheckCardTable_ByRefWriteBarrier + jne LOCAL_LABEL(CheckCardTable_ByRefWriteBarrier) mov byte ptr [rax], 0xFF #endif - CheckCardTable_ByRefWriteBarrier: + LOCAL_LABEL(CheckCardTable_ByRefWriteBarrier): // See if we can just quick out PREPARE_EXTERNAL_VAR g_ephemeral_low, rax cmp rcx, [rax] - jb Exit_ByRefWriteBarrier + jb LOCAL_LABEL(Exit_ByRefWriteBarrier) PREPARE_EXTERNAL_VAR g_ephemeral_high, rax cmp rcx, [rax] - jnb Exit_ByRefWriteBarrier + jnb LOCAL_LABEL(Exit_ByRefWriteBarrier) mov rax, rcx PREPARE_EXTERNAL_VAR g_region_shr, rcx mov cl, [rcx] test cl, cl - je SkipCheck_ByRefWriteBarrier + je LOCAL_LABEL(SkipCheck_ByRefWriteBarrier) // check if the source is in gen 2 - then it's not an ephemeral pointer shr rax, cl PREPARE_EXTERNAL_VAR g_region_to_generation_table, r10 mov r10, [r10] cmp byte ptr [rax + r10], 0x82 - je Exit_ByRefWriteBarrier + je LOCAL_LABEL(Exit_ByRefWriteBarrier) // check if the destination happens to be in gen 0 mov rax, rdi shr rax, cl cmp byte ptr [rax + r10], 0 - je Exit_ByRefWriteBarrier - SkipCheck_ByRefWriteBarrier: + je LOCAL_LABEL(Exit_ByRefWriteBarrier) + LOCAL_LABEL(SkipCheck_ByRefWriteBarrier): PREPARE_EXTERNAL_VAR g_card_table, r10 mov r10, [r10] PREPARE_EXTERNAL_VAR g_region_use_bitwise_write_barrier, rax cmp byte ptr [rax], 0 - je CheckCardTableByte_ByRefWriteBarrier + je LOCAL_LABEL(CheckCardTableByte_ByRefWriteBarrier) // compute card table bit mov ecx, edi @@ -400,15 +400,15 @@ LEAF_ENTRY JIT_ByRefWriteBarrier, _TEXT shr rcx, 0xB // Check if this card table bit is already set test byte ptr [rcx + r10], al - je SetCardTableBit_ByRefWriteBarrier + je LOCAL_LABEL(SetCardTableBit_ByRefWriteBarrier) REPRET - SetCardTableBit_ByRefWriteBarrier: + LOCAL_LABEL(SetCardTableBit_ByRefWriteBarrier): lock or byte ptr [rcx + r10], al - jmp CheckCardBundle_ByRefWriteBarrier + jmp LOCAL_LABEL(CheckCardBundle_ByRefWriteBarrier) - CheckCardTableByte_ByRefWriteBarrier: + LOCAL_LABEL(CheckCardTableByte_ByRefWriteBarrier): // move current rdi value into rcx and then increment the pointers mov rcx, rdi add rsi, 0x8 @@ -416,12 +416,12 @@ LEAF_ENTRY JIT_ByRefWriteBarrier, _TEXT shr rcx, 0xB cmp byte ptr [rcx + r10], 0xFF - jne SetCardTableByte_ByRefWriteBarrier + jne LOCAL_LABEL(SetCardTableByte_ByRefWriteBarrier) REPRET - SetCardTableByte_ByRefWriteBarrier: + LOCAL_LABEL(SetCardTableByte_ByRefWriteBarrier): mov byte ptr [rcx + r10], 0xFF - CheckCardBundle_ByRefWriteBarrier: + LOCAL_LABEL(CheckCardBundle_ByRefWriteBarrier): #ifdef FEATURE_MANUALLY_MANAGED_CARD_BUNDLES // Shift rcx by 0x0A more to get the card bundle byte (we shifted by 0x0B already) @@ -433,17 +433,17 @@ LEAF_ENTRY JIT_ByRefWriteBarrier, _TEXT // Check if this bundle byte is dirty cmp byte ptr [rcx], 0xFF - jne UpdateCardBundle_ByRefWriteBarrier + jne LOCAL_LABEL(UpdateCardBundle_ByRefWriteBarrier) REPRET - UpdateCardBundle_ByRefWriteBarrier: + LOCAL_LABEL(UpdateCardBundle_ByRefWriteBarrier): mov byte ptr [rcx], 0xFF #endif ret .balign 16 - NotInHeap_ByRefWriteBarrier: + LOCAL_LABEL(NotInHeap_ByRefWriteBarrier): // If WRITE_BARRIER_CHECK then we won't have already done the mov and should do it here // If !WRITE_BARRIER_CHECK we want _NotInHeap and _Leave to be the same and have both // 16 byte aligned. @@ -451,7 +451,7 @@ LEAF_ENTRY JIT_ByRefWriteBarrier, _TEXT // rcx is [rsi] mov [rdi], rcx #endif - Exit_ByRefWriteBarrier: + LOCAL_LABEL(Exit_ByRefWriteBarrier): // Increment the pointers before leaving add rdi, 0x8 add rsi, 0x8 diff --git a/src/coreclr/vm/appdomain.hpp b/src/coreclr/vm/appdomain.hpp index 2103ba6c654593..7633fb9f2881bf 100644 --- a/src/coreclr/vm/appdomain.hpp +++ b/src/coreclr/vm/appdomain.hpp @@ -2451,12 +2451,24 @@ class SystemDomain : public BaseDomain } static FrozenObjectHeapManager* GetFrozenObjectHeapManager() { - WRAPPER_NO_CONTRACT; - if (m_FrozenObjectHeapManager == NULL) + CONTRACTL + { + THROWS; + MODE_COOPERATIVE; + } + CONTRACTL_END; + + if (VolatileLoad(&m_FrozenObjectHeapManager) == nullptr) { LazyInitFrozenObjectsHeap(); } - return m_FrozenObjectHeapManager; + return VolatileLoad(&m_FrozenObjectHeapManager); + } + static FrozenObjectHeapManager* GetFrozenObjectHeapManagerNoThrow() + { + LIMITED_METHOD_CONTRACT; + + return VolatileLoad(&m_FrozenObjectHeapManager); } #endif // DACCESS_COMPILE diff --git a/src/coreclr/vm/arm/cgencpu.h b/src/coreclr/vm/arm/cgencpu.h index d31700e3477a40..6538cea705a56a 100644 --- a/src/coreclr/vm/arm/cgencpu.h +++ b/src/coreclr/vm/arm/cgencpu.h @@ -996,7 +996,7 @@ inline BOOL ClrFlushInstructionCache(LPCVOID pCodeAddr, size_t sizeOfCode, bool // Precode to shuffle this and retbuf for closed delegates over static methods with return buffer struct ThisPtrRetBufPrecode { - static const int Type = 0x46; + static const int Type = 0x01; // mov r12, r0 // mov r0, r1 diff --git a/src/coreclr/vm/arm/stubs.cpp b/src/coreclr/vm/arm/stubs.cpp index 634eea810a31df..36eaeb51cdc5de 100644 --- a/src/coreclr/vm/arm/stubs.cpp +++ b/src/coreclr/vm/arm/stubs.cpp @@ -671,6 +671,16 @@ void HelperMethodFrame::UpdateRegDisplay(const PREGDISPLAY pRD) pRD->pCurrentContext->R10 = (DWORD)(pUnwoundState->captureR4_R11[6]); pRD->pCurrentContext->R11 = (DWORD)(pUnwoundState->captureR4_R11[7]); + pRD->pCurrentContextPointers->R4 = &pRD->pCurrentContext->R4; + pRD->pCurrentContextPointers->R5 = &pRD->pCurrentContext->R5; + pRD->pCurrentContextPointers->R6 = &pRD->pCurrentContext->R6; + pRD->pCurrentContextPointers->R7 = &pRD->pCurrentContext->R7; + pRD->pCurrentContextPointers->R8 = &pRD->pCurrentContext->R8; + pRD->pCurrentContextPointers->R9 = &pRD->pCurrentContext->R9; + pRD->pCurrentContextPointers->R10 = &pRD->pCurrentContext->R10; + pRD->pCurrentContextPointers->R11 = &pRD->pCurrentContext->R11; + pRD->pCurrentContextPointers->Lr = &pRD->pCurrentContext->Lr; + return; } #endif // DACCESS_COMPILE diff --git a/src/coreclr/vm/arm64/asmhelpers.S b/src/coreclr/vm/arm64/asmhelpers.S index cdbe24ec427a98..89dab80461c356 100644 --- a/src/coreclr/vm/arm64/asmhelpers.S +++ b/src/coreclr/vm/arm64/asmhelpers.S @@ -329,7 +329,9 @@ WRITE_BARRIER_ENTRY JIT_CheckedWriteBarrier // branch below is not taken. ccmp x14, x12, #0x2, hs - blo C_FUNC(JIT_WriteBarrier) + bhs LOCAL_LABEL(NotInHeap) + + b C_FUNC(JIT_WriteBarrier) LOCAL_LABEL(NotInHeap): str x15, [x14], 8 diff --git a/src/coreclr/vm/arm64/stubs.cpp b/src/coreclr/vm/arm64/stubs.cpp index bc3e2b9609caeb..4ae26363fd2d24 100644 --- a/src/coreclr/vm/arm64/stubs.cpp +++ b/src/coreclr/vm/arm64/stubs.cpp @@ -472,18 +472,18 @@ void HelperMethodFrame::UpdateRegDisplay(const PREGDISPLAY pRD) pRD->pCurrentContext->Fp = (DWORD64)(pUnwoundState->captureX19_X29[10]); pRD->pCurrentContext->Lr = NULL; // Unwind again to get Caller's PC - pRD->pCurrentContextPointers->X19 = pUnwoundState->ptrX19_X29[0]; - pRD->pCurrentContextPointers->X20 = pUnwoundState->ptrX19_X29[1]; - pRD->pCurrentContextPointers->X21 = pUnwoundState->ptrX19_X29[2]; - pRD->pCurrentContextPointers->X22 = pUnwoundState->ptrX19_X29[3]; - pRD->pCurrentContextPointers->X23 = pUnwoundState->ptrX19_X29[4]; - pRD->pCurrentContextPointers->X24 = pUnwoundState->ptrX19_X29[5]; - pRD->pCurrentContextPointers->X25 = pUnwoundState->ptrX19_X29[6]; - pRD->pCurrentContextPointers->X26 = pUnwoundState->ptrX19_X29[7]; - pRD->pCurrentContextPointers->X27 = pUnwoundState->ptrX19_X29[8]; - pRD->pCurrentContextPointers->X28 = pUnwoundState->ptrX19_X29[9]; - pRD->pCurrentContextPointers->Fp = pUnwoundState->ptrX19_X29[10]; - pRD->pCurrentContextPointers->Lr = NULL; + pRD->pCurrentContextPointers->X19 = &pRD->pCurrentContext->X19; + pRD->pCurrentContextPointers->X20 = &pRD->pCurrentContext->X20; + pRD->pCurrentContextPointers->X21 = &pRD->pCurrentContext->X21; + pRD->pCurrentContextPointers->X22 = &pRD->pCurrentContext->X22; + pRD->pCurrentContextPointers->X23 = &pRD->pCurrentContext->X23; + pRD->pCurrentContextPointers->X24 = &pRD->pCurrentContext->X24; + pRD->pCurrentContextPointers->X25 = &pRD->pCurrentContext->X25; + pRD->pCurrentContextPointers->X26 = &pRD->pCurrentContext->X26; + pRD->pCurrentContextPointers->X27 = &pRD->pCurrentContext->X27; + pRD->pCurrentContextPointers->X28 = &pRD->pCurrentContext->X28; + pRD->pCurrentContextPointers->Fp = &pRD->pCurrentContext->Fp; + pRD->pCurrentContextPointers->Lr = &pRD->pCurrentContext->Lr; return; } diff --git a/src/coreclr/vm/callsiteinspect.cpp b/src/coreclr/vm/callsiteinspect.cpp index dabbe89a497720..8209e41e6a7d44 100644 --- a/src/coreclr/vm/callsiteinspect.cpp +++ b/src/coreclr/vm/callsiteinspect.cpp @@ -433,7 +433,8 @@ void CallsiteInspect::PropagateOutParametersBackToCallsite( *(ARG_SLOT *)(frame->GetReturnValuePtr()) = retVal; } #ifdef ENREGISTERED_RETURNTYPE_MAXSIZE - else if (argit.HasNonStandardByvalReturn()) + else if (argit.HasNonStandardByvalReturn() + && !(flags & CallsiteDetails::HResultReturn)) { // In these cases, put the pointer to the return buffer into // the frame's return value slot. diff --git a/src/coreclr/vm/callsiteinspect.h b/src/coreclr/vm/callsiteinspect.h index 373b9347dfd9c9..4ca66eca9feba2 100644 --- a/src/coreclr/vm/callsiteinspect.h +++ b/src/coreclr/vm/callsiteinspect.h @@ -25,6 +25,7 @@ struct CallsiteDetails BeginInvoke = 0x01, EndInvoke = 0x02, Ctor = 0x04, + HResultReturn = 0x08, }; INT32 Flags; }; diff --git a/src/coreclr/vm/clrtocomcall.cpp b/src/coreclr/vm/clrtocomcall.cpp index 06d28f507249b4..c604a6c8a90116 100644 --- a/src/coreclr/vm/clrtocomcall.cpp +++ b/src/coreclr/vm/clrtocomcall.cpp @@ -364,7 +364,7 @@ UINT32 CLRToCOMEventCallWorker(ComPlusMethodFrame* pFrame, ComPlusCallMethodDesc return 0; } -CallsiteDetails CreateCallsiteDetails(_In_ FramedMethodFrame *pFrame) +static CallsiteDetails CreateCallsiteDetails(_In_ FramedMethodFrame *pFrame) { CONTRACTL { @@ -442,10 +442,20 @@ CallsiteDetails CreateCallsiteDetails(_In_ FramedMethodFrame *pFrame) SigTypeContext::InitTypeContext(pMD, actualType, &typeContext); } + // If the signature is marked preserve sig, then the return + // is required to be an HRESULT, per COM rules. We set a flag to + // indicate this state to avoid issues when a C# developer defines + // an HRESULT in C# as a ValueClass with a single int field. This + // is convenient but does violate the COM ABI. Setting the flag + // lets us permit this convention and allow either a 4 byte primitive + // or the commonly used C# type "struct HResult { int Value; }". + if (IsMiPreserveSig(pMD->GetImplAttrs())) + callsiteFlags |= CallsiteDetails::HResultReturn; + _ASSERTE(!signature.IsEmpty() && pModule != nullptr); // Create details - return CallsiteDetails{ { signature, pModule, &typeContext }, pFrame, pMD, fIsDelegate }; + return CallsiteDetails{ { signature, pModule, &typeContext }, pFrame, pMD, fIsDelegate, callsiteFlags }; } UINT32 CLRToCOMLateBoundWorker( diff --git a/src/coreclr/vm/dwbucketmanager.hpp b/src/coreclr/vm/dwbucketmanager.hpp index 167685f78a597c..82532aa5b42b4b 100644 --- a/src/coreclr/vm/dwbucketmanager.hpp +++ b/src/coreclr/vm/dwbucketmanager.hpp @@ -960,11 +960,9 @@ bool BaseBucketParamsManager::GetFileVersionInfoForModule(Module* pModule, USHOR // if we failed to get the version info from the native image then fall back to the IL image. if (!succeeded) { - LPCWSTR modulePath = pPEAssembly->GetPath().GetUnicode(); - if (modulePath != NULL && modulePath != SString::Empty() && SUCCEEDED(DwGetFileVersionInfo(modulePath, major, minor, build, revision))) - { - succeeded = true; - } + const SString& modulePath = pPEAssembly->GetPath(); + _ASSERTE(modulePath.IsNormalized()); + succeeded = !modulePath.IsEmpty() && SUCCEEDED(DwGetFileVersionInfo(modulePath.GetUnicode(), major, minor, build, revision)); } } diff --git a/src/coreclr/vm/eeconfig.cpp b/src/coreclr/vm/eeconfig.cpp index 417910eb183bc0..0f1de4d4fe788a 100644 --- a/src/coreclr/vm/eeconfig.cpp +++ b/src/coreclr/vm/eeconfig.cpp @@ -721,7 +721,8 @@ HRESULT EEConfig::sync() fTieredCompilation_CallCounting = CLRConfig::GetConfigValue(CLRConfig::INTERNAL_TC_CallCounting) != 0; DWORD tieredCompilation_ConfiguredCallCountThreshold = - CLRConfig::GetConfigValue(CLRConfig::INTERNAL_TC_CallCountThreshold); + Configuration::GetKnobDWORDValue(W("System.Runtime.TieredCompilation.CallCountThreshold"), CLRConfig::EXTERNAL_TC_CallCountThreshold); + if (tieredCompilation_ConfiguredCallCountThreshold == 0) { tieredCompilation_CallCountThreshold = 1; @@ -735,8 +736,9 @@ HRESULT EEConfig::sync() tieredCompilation_CallCountThreshold = (UINT16)tieredCompilation_ConfiguredCallCountThreshold; } - tieredCompilation_CallCountingDelayMs = CLRConfig::GetConfigValue(CLRConfig::INTERNAL_TC_CallCountingDelayMs); - + tieredCompilation_CallCountingDelayMs = + Configuration::GetKnobDWORDValue(W("System.Runtime.TieredCompilation.CallCountingDelayMs"), CLRConfig::EXTERNAL_TC_CallCountingDelayMs); + bool hasSingleProcessor = GetCurrentProcessCpuCount() == 1; if (hasSingleProcessor) { diff --git a/src/coreclr/vm/eventtrace.cpp b/src/coreclr/vm/eventtrace.cpp index b9a82cfb7c28e9..9498f00edf4329 100644 --- a/src/coreclr/vm/eventtrace.cpp +++ b/src/coreclr/vm/eventtrace.cpp @@ -3555,7 +3555,7 @@ VOID ETW::MethodLog::MethodJitted(MethodDesc *pMethodDesc, SString *namespaceOrC /*************************************************/ /* This is called by the runtime when method jitting started */ /*************************************************/ -VOID ETW::MethodLog::MethodJitting(MethodDesc *pMethodDesc, SString *namespaceOrClassName, SString *methodName, SString *methodSignature) +VOID ETW::MethodLog::MethodJitting(MethodDesc *pMethodDesc, COR_ILMETHOD_DECODER* methodDecoder, SString *namespaceOrClassName, SString *methodName, SString *methodSignature) { CONTRACTL { NOTHROW; @@ -3570,7 +3570,7 @@ VOID ETW::MethodLog::MethodJitting(MethodDesc *pMethodDesc, SString *namespaceOr CLR_JIT_KEYWORD)) { pMethodDesc->GetMethodInfo(*namespaceOrClassName, *methodName, *methodSignature); - ETW::MethodLog::SendMethodJitStartEvent(pMethodDesc, namespaceOrClassName, methodName, methodSignature); + ETW::MethodLog::SendMethodJitStartEvent(pMethodDesc, methodDecoder, namespaceOrClassName, methodName, methodSignature); } } EX_CATCH { } EX_END_CATCH(SwallowAllExceptions); } @@ -4528,7 +4528,12 @@ VOID ETW::MethodLog::SendNonDuplicateMethodDetailsEvent(MethodDesc* pMethodDesc, /*****************************************************************/ /* This routine is used to send an ETW event just before a method starts jitting*/ /*****************************************************************/ -VOID ETW::MethodLog::SendMethodJitStartEvent(MethodDesc *pMethodDesc, SString *namespaceOrClassName, SString *methodName, SString *methodSignature) +VOID ETW::MethodLog::SendMethodJitStartEvent( + MethodDesc *pMethodDesc, + COR_ILMETHOD_DECODER* methodDecoder, + SString *namespaceOrClassName, + SString *methodName, + SString *methodSignature) { CONTRACTL { THROWS; @@ -4566,15 +4571,13 @@ VOID ETW::MethodLog::SendMethodJitStartEvent(MethodDesc *pMethodDesc, SString *n ulMethodToken = (ULONG)0; } else - ulMethodToken = (ULONG)pMethodDesc->GetMemberDef(); - - if(pMethodDesc->IsIL()) { - COR_ILMETHOD_DECODER::DecoderStatus decoderstatus = COR_ILMETHOD_DECODER::FORMAT_ERROR; - COR_ILMETHOD_DECODER ILHeader(pMethodDesc->GetILHeader(), pMethodDesc->GetMDImport(), &decoderstatus); - ulMethodILSize = (ULONG)ILHeader.GetCodeSize(); + ulMethodToken = (ULONG)pMethodDesc->GetMemberDef(); } + if (methodDecoder != NULL) + ulMethodILSize = methodDecoder->GetCodeSize(); + SString tNamespace, tMethodName, tMethodSignature; if(!namespaceOrClassName|| !methodName|| !methodSignature || (methodName->IsEmpty() && namespaceOrClassName->IsEmpty() && methodSignature->IsEmpty())) { diff --git a/src/coreclr/vm/excep.cpp b/src/coreclr/vm/excep.cpp index 5be645ad496d7d..ce8b325a5f41b8 100644 --- a/src/coreclr/vm/excep.cpp +++ b/src/coreclr/vm/excep.cpp @@ -11616,7 +11616,8 @@ VOID GetAssemblyDetailInfo(SString &sType, SString sAlcName; pPEAssembly->GetAssemblyBinder()->GetNameForDiagnostics(sAlcName); - if (pPEAssembly->GetPath().IsEmpty()) + SString assemblyPath{ pPEAssembly->GetPath() }; + if (assemblyPath.IsEmpty()) { detailsUtf8.Printf("Type %s originates from '%s' in the context '%s' in a byte array", sType.GetUTF8(), @@ -11629,7 +11630,7 @@ VOID GetAssemblyDetailInfo(SString &sType, sType.GetUTF8(), sAssemblyDisplayName.GetUTF8(), sAlcName.GetUTF8(), - pPEAssembly->GetPath().GetUTF8()); + assemblyPath.GetUTF8()); } sAssemblyDetailInfo.Append(detailsUtf8.GetUnicode()); diff --git a/src/coreclr/vm/frozenobjectheap.cpp b/src/coreclr/vm/frozenobjectheap.cpp index 45492155d2089c..8f11f3c8c74d64 100644 --- a/src/coreclr/vm/frozenobjectheap.cpp +++ b/src/coreclr/vm/frozenobjectheap.cpp @@ -10,7 +10,8 @@ #define FOH_COMMIT_SIZE (64 * 1024) FrozenObjectHeapManager::FrozenObjectHeapManager(): - m_Crst(CrstFrozenObjectHeap, CRST_UNSAFE_COOPGC), + m_Crst(CrstFrozenObjectHeap, CRST_UNSAFE_ANYMODE), + m_SegmentRegistrationCrst(CrstFrozenObjectHeap), m_CurrentSegment(nullptr) { } @@ -18,7 +19,9 @@ FrozenObjectHeapManager::FrozenObjectHeapManager(): // Allocates an object of the give size (including header) on a frozen segment. // May return nullptr if object is too large (larger than FOH_COMMIT_SIZE) // in such cases caller is responsible to find a more appropriate heap to allocate it -Object* FrozenObjectHeapManager::TryAllocateObject(PTR_MethodTable type, size_t objectSize, bool publish) + +Object* FrozenObjectHeapManager::TryAllocateObject(PTR_MethodTable type, size_t objectSize, + void(*initFunc)(Object*, void*), void* pParam) { CONTRACTL { @@ -33,64 +36,83 @@ Object* FrozenObjectHeapManager::TryAllocateObject(PTR_MethodTable type, size_t #else // FEATURE_BASICFREEZE Object* obj = nullptr; - { - CrstHolder ch(&m_Crst); - - _ASSERT(type != nullptr); - _ASSERT(FOH_COMMIT_SIZE >= MIN_OBJECT_SIZE); - - // Currently we don't support frozen objects with special alignment requirements - // TODO: We should also give up on arrays of doubles on 32-bit platforms. - // (we currently never allocate them on frozen segments) - #ifdef FEATURE_64BIT_ALIGNMENT - if (type->RequiresAlign8()) - { - // Align8 objects are not supported yet - return nullptr; - } - #endif + FrozenObjectSegment* curSeg = nullptr; + uint8_t* curSegmentCurrent = nullptr; + size_t curSegSizeCommitted = 0; - // NOTE: objectSize is expected be the full size including header - _ASSERT(objectSize >= MIN_OBJECT_SIZE); - - if (objectSize > FOH_COMMIT_SIZE) + { + GCX_PREEMP(); { - // The current design doesn't allow objects larger than FOH_COMMIT_SIZE and - // since FrozenObjectHeap is just an optimization, let's not fill it with huge objects. - return nullptr; - } - - if (m_CurrentSegment == nullptr) + CrstHolder ch(&m_Crst); + + _ASSERT(type != nullptr); + _ASSERT(FOH_COMMIT_SIZE >= MIN_OBJECT_SIZE); + + // Currently we don't support frozen objects with special alignment requirements + // TODO: We should also give up on arrays of doubles on 32-bit platforms. + // (we currently never allocate them on frozen segments) +#ifdef FEATURE_64BIT_ALIGNMENT + if (type->RequiresAlign8()) + { + // Align8 objects are not supported yet + return nullptr; + } +#endif + + // NOTE: objectSize is expected be the full size including header + _ASSERT(objectSize >= MIN_OBJECT_SIZE); + + if (objectSize > FOH_COMMIT_SIZE) + { + // The current design doesn't allow objects larger than FOH_COMMIT_SIZE and + // since FrozenObjectHeap is just an optimization, let's not fill it with huge objects. + return nullptr; + } + + obj = m_CurrentSegment == nullptr ? nullptr : m_CurrentSegment->TryAllocateObject(type, objectSize); + // obj is nullptr if the current segment is full or hasn't been allocated yet + if (obj == nullptr) + { + size_t newSegmentSize = FOH_SEGMENT_DEFAULT_SIZE; + if (m_CurrentSegment != nullptr) + { + // Double the reserved size to reduce the number of frozen segments in apps with lots of frozen objects + // Use the same size in case if prevSegmentSize*2 operation overflows. + const size_t prevSegmentSize = m_CurrentSegment->m_Size; + newSegmentSize = max(prevSegmentSize, prevSegmentSize * 2); + } + + m_CurrentSegment = new FrozenObjectSegment(newSegmentSize); + m_FrozenSegments.Append(m_CurrentSegment); + + // Try again + obj = m_CurrentSegment->TryAllocateObject(type, objectSize); + + // This time it's not expected to be null + _ASSERT(obj != nullptr); + } + + if (initFunc != nullptr) + { + initFunc(obj, pParam); + } + + curSeg = m_CurrentSegment; + curSegSizeCommitted = curSeg->m_SizeCommitted; + curSegmentCurrent = curSeg->m_pCurrent; + } // end of m_Crst lock + + // Let GC know about the new segment or changes in it. + // We do it under a new lock because the main one (m_Crst) can be used by Profiler in a GC's thread + // and that might cause deadlocks since RegisterFrozenSegment may stuck on GC's lock. { - // Create the first segment on first allocation - m_CurrentSegment = new FrozenObjectSegment(FOH_SEGMENT_DEFAULT_SIZE); - m_FrozenSegments.Append(m_CurrentSegment); - _ASSERT(m_CurrentSegment != nullptr); + CrstHolder regLock(&m_SegmentRegistrationCrst); + curSeg->RegisterOrUpdate(curSegmentCurrent, curSegSizeCommitted); } - obj = m_CurrentSegment->TryAllocateObject(type, objectSize); - - // The only case where it can be null is when the current segment is full and we need - // to create a new one - if (obj == nullptr) - { - // Double the reserved size to reduce the number of frozen segments in apps with lots of frozen objects - // Use the same size in case if prevSegmentSize*2 operation overflows. - size_t prevSegmentSize = m_CurrentSegment->GetSize(); - m_CurrentSegment = new FrozenObjectSegment(max(prevSegmentSize, prevSegmentSize * 2)); - m_FrozenSegments.Append(m_CurrentSegment); - - // Try again - obj = m_CurrentSegment->TryAllocateObject(type, objectSize); + } // end of GCX_PREEMP - // This time it's not expected to be null - _ASSERT(obj != nullptr); - } - } - if (publish) - { - PublishFrozenObject(obj); - } + PublishFrozenObject(obj); return obj; #endif // !FEATURE_BASICFREEZE @@ -101,10 +123,10 @@ Object* FrozenObjectHeapManager::TryAllocateObject(PTR_MethodTable type, size_t FrozenObjectSegment::FrozenObjectSegment(size_t sizeHint) : m_pStart(nullptr), m_pCurrent(nullptr), + m_pCurrentRegistered(nullptr), m_SizeCommitted(0), m_Size(sizeHint), m_SegmentHandle(nullptr) - COMMA_INDEBUG(m_ObjectsCount(0)) { _ASSERT(m_Size > FOH_COMMIT_SIZE); _ASSERT(m_Size % FOH_COMMIT_SIZE == 0); @@ -135,34 +157,61 @@ FrozenObjectSegment::FrozenObjectSegment(size_t sizeHint) : ThrowOutOfMemory(); } + m_pStart = static_cast(committedAlloc); + m_pCurrent = m_pStart + sizeof(ObjHeader); + m_SizeCommitted = FOH_COMMIT_SIZE; + // ClrVirtualAlloc is expected to be PageSize-aligned so we can expect // DATA_ALIGNMENT alignment as well _ASSERT(IS_ALIGNED(committedAlloc, DATA_ALIGNMENT)); +} - segment_info si; - si.pvMem = committedAlloc; - si.ibFirstObject = sizeof(ObjHeader); - si.ibAllocated = si.ibFirstObject; - si.ibCommit = FOH_COMMIT_SIZE; - si.ibReserved = m_Size; - - m_SegmentHandle = GCHeapUtilities::GetGCHeap()->RegisterFrozenSegment(&si); - if (m_SegmentHandle == nullptr) +void FrozenObjectSegment::RegisterOrUpdate(uint8_t* current, size_t sizeCommited) +{ + CONTRACTL { - ClrVirtualFree(alloc, 0, MEM_RELEASE); - ThrowOutOfMemory(); + THROWS; + MODE_PREEMPTIVE; } + CONTRACTL_END - m_pStart = static_cast(committedAlloc); - m_pCurrent = m_pStart + sizeof(ObjHeader); - m_SizeCommitted = si.ibCommit; - INDEBUG(m_ObjectsCount = 0); - return; + if (m_pCurrentRegistered == nullptr) + { + segment_info si; + si.pvMem = m_pStart; + si.ibFirstObject = sizeof(ObjHeader); + si.ibAllocated = (size_t)current - (size_t)si.pvMem; + si.ibCommit = sizeCommited; + si.ibReserved = m_Size; + + assert((size_t)current >= (size_t)si.pvMem); + + // NOTE: RegisterFrozenSegment may take a GC lock inside. + m_SegmentHandle = GCHeapUtilities::GetGCHeap()->RegisterFrozenSegment(&si); + if (m_SegmentHandle == nullptr) + { + ThrowOutOfMemory(); + } + m_pCurrentRegistered = current; + } + else + { + if (current > m_pCurrentRegistered) + { + GCHeapUtilities::GetGCHeap()->UpdateFrozenSegment( + m_SegmentHandle, current, m_pStart + sizeCommited); + m_pCurrentRegistered = current; + } + else + { + // Some other thread already advanced it. + } + } } Object* FrozenObjectSegment::TryAllocateObject(PTR_MethodTable type, size_t objectSize) { - _ASSERT(m_pStart != nullptr && m_Size > 0 && m_SegmentHandle != nullptr); // Expected to be inited + _ASSERT((m_pStart != nullptr) && (m_Size > 0)); _ASSERT(IS_ALIGNED(m_pCurrent, DATA_ALIGNMENT)); _ASSERT(IS_ALIGNED(objectSize, DATA_ALIGNMENT)); _ASSERT(objectSize <= FOH_COMMIT_SIZE); @@ -194,16 +243,11 @@ Object* FrozenObjectSegment::TryAllocateObject(PTR_MethodTable type, size_t obje m_SizeCommitted += FOH_COMMIT_SIZE; } - INDEBUG(m_ObjectsCount++); - Object* object = reinterpret_cast(m_pCurrent); object->SetMethodTable(type); m_pCurrent += objectSize; - // Notify GC that we bumped the pointer and, probably, committed more memory in the reserved part - GCHeapUtilities::GetGCHeap()->UpdateFrozenSegment(m_SegmentHandle, m_pCurrent, m_pStart + m_SizeCommitted); - return object; } diff --git a/src/coreclr/vm/frozenobjectheap.h b/src/coreclr/vm/frozenobjectheap.h index d2c0bb62f134af..e191731d64dd5b 100644 --- a/src/coreclr/vm/frozenobjectheap.h +++ b/src/coreclr/vm/frozenobjectheap.h @@ -27,10 +27,12 @@ class FrozenObjectHeapManager { public: FrozenObjectHeapManager(); - Object* TryAllocateObject(PTR_MethodTable type, size_t objectSize, bool publish = true); + Object* TryAllocateObject(PTR_MethodTable type, size_t objectSize, + void(*initFunc)(Object*,void*) = nullptr, void* pParam = nullptr); private: Crst m_Crst; + Crst m_SegmentRegistrationCrst; SArray m_FrozenSegments; FrozenObjectSegment* m_CurrentSegment; @@ -43,10 +45,7 @@ class FrozenObjectSegment public: FrozenObjectSegment(size_t sizeHint); Object* TryAllocateObject(PTR_MethodTable type, size_t objectSize); - size_t GetSize() const - { - return m_Size; - } + void RegisterOrUpdate(uint8_t* current, size_t sizeCommited); private: Object* GetFirstObject() const; @@ -55,12 +54,20 @@ class FrozenObjectSegment // Start of the reserved memory, the first object starts at "m_pStart + sizeof(ObjHeader)" (its pMT) uint8_t* m_pStart; + // NOTE: To handle potential race conditions, only m_[x]Registered fields should be accessed + // externally as they guarantee that GC is aware of the current state of the segment. + // Pointer to the end of the current segment, ready to be used as a pMT for a new object // meaning that "m_pCurrent - sizeof(ObjHeader)" is the actual start of the new object (header). // // m_pCurrent <= m_SizeCommitted uint8_t* m_pCurrent; + // Last known value of m_pCurrent that GC is aware of. + // + // m_pCurrentRegistered <= m_pCurrent + uint8_t* m_pCurrentRegistered; + // Memory committed in the current segment // // m_SizeCommitted <= m_pStart + FOH_SIZE_RESERVED @@ -70,10 +77,10 @@ class FrozenObjectSegment size_t m_Size; segment_handle m_SegmentHandle; - INDEBUG(size_t m_ObjectsCount); friend class ProfilerObjectEnum; friend class ProfToEEInterfaceImpl; + friend class FrozenObjectHeapManager; }; #endif // _FROZENOBJECTHEAP_H diff --git a/src/coreclr/vm/gcenv.ee.cpp b/src/coreclr/vm/gcenv.ee.cpp index dc28e791de0cf7..8d9f676d967a8e 100644 --- a/src/coreclr/vm/gcenv.ee.cpp +++ b/src/coreclr/vm/gcenv.ee.cpp @@ -115,16 +115,28 @@ static void ScanStackRoots(Thread * pThread, promote_func* fn, ScanContext* sc) IsGCSpecialThread() || (GetThread() == ThreadSuspend::GetSuspensionThread() && ThreadStore::HoldingThreadStore())); +#if defined(FEATURE_CONSERVATIVE_GC) || defined(USE_FEF) Frame* pTopFrame = pThread->GetFrame(); Object ** topStack = (Object **)pTopFrame; - if ((pTopFrame != ((Frame*)-1)) - && (pTopFrame->GetVTablePtr() == InlinedCallFrame::GetMethodFrameVPtr())) { - // It is an InlinedCallFrame. Get SP from it. + if (InlinedCallFrame::FrameHasActiveCall(pTopFrame)) + { + // It is an InlinedCallFrame with active call. Get SP from it. InlinedCallFrame* pInlinedFrame = (InlinedCallFrame*)pTopFrame; topStack = (Object **)pInlinedFrame->GetCallSiteSP(); } +#endif // FEATURE_CONSERVATIVE_GC || USE_FEF +#ifdef USE_FEF + // We only set the stack_limit when FEF (FaultingExceptionFrame) is enabled, because without the + // FEF, the code above would have to check if hardware exception is being handled and get the limit + // from the exception frame. Since the stack_limit is strictly necessary only on Unix and FEF is + // not enabled on Window x86 only, it is sufficient to keep the stack_limit set to 0 in this case. + // See the comment on the stack_limit usage in the PromoteCarefully function for more details. sc->stack_limit = (uintptr_t)topStack; +#else // USE_FEF + // It should be set to 0 in the ScanContext constructor + _ASSERTE(sc->stack_limit == 0); +#endif // USE_FEF #ifdef FEATURE_CONSERVATIVE_GC if (g_pConfig->GetGCConservative()) diff --git a/src/coreclr/vm/gchelpers.cpp b/src/coreclr/vm/gchelpers.cpp index e3c882f623b244..8e8a464c4f8852 100644 --- a/src/coreclr/vm/gchelpers.cpp +++ b/src/coreclr/vm/gchelpers.cpp @@ -555,18 +555,18 @@ OBJECTREF TryAllocateFrozenSzArray(MethodTable* pArrayMT, INT32 cElements) #endif FrozenObjectHeapManager* foh = SystemDomain::GetFrozenObjectHeapManager(); - ArrayBase* orArray = static_cast(foh->TryAllocateObject(pArrayMT, PtrAlign(totalSize), /*publish*/ false)); + ArrayBase* orArray = static_cast( + foh->TryAllocateObject(pArrayMT, PtrAlign(totalSize), [](Object* obj, void* elemCntPtr){ + // Initialize newly allocated object before publish + static_cast(obj)->m_NumComponents = *static_cast(elemCntPtr); + }, &cElements)); + if (orArray == nullptr) { // We failed to allocate on a frozen segment, fallback to AllocateSzArray // E.g. if the array is too big to fit on a frozen segment return NULL; } - orArray->m_NumComponents = cElements; - - // Publish needs to be postponed in this case because we need to specify array length - PublishObjectAndNotify(orArray, GC_ALLOC_NO_FLAGS); - return ObjectToOBJECTREF(orArray); } @@ -968,12 +968,15 @@ STRINGREF AllocateString(DWORD cchStringLength, bool preferFrozenHeap, bool* pIs if (preferFrozenHeap) { FrozenObjectHeapManager* foh = SystemDomain::GetFrozenObjectHeapManager(); - orString = static_cast(foh->TryAllocateObject(g_pStringClass, totalSize, /* publish = */false)); + + orString = static_cast(foh->TryAllocateObject( + g_pStringClass, totalSize, [](Object* obj, void* pStrLen) { + // Initialize newly allocated object before publish + static_cast(obj)->SetStringLength(*static_cast(pStrLen)); + }, &cchStringLength)); + if (orString != nullptr) { - orString->SetStringLength(cchStringLength); - // Publish needs to be postponed in this case because we need to specify string length - PublishObjectAndNotify(orString, GC_ALLOC_NO_FLAGS); _ASSERTE(orString->GetBuffer()[cchStringLength] == W('\0')); orStringRef = ObjectToSTRINGREF(orString); *pIsFrozen = true; @@ -1139,7 +1142,7 @@ OBJECTREF TryAllocateFrozenObject(MethodTable* pObjMT) #endif // FEATURE_64BIT_ALIGNMENT FrozenObjectHeapManager* foh = SystemDomain::GetFrozenObjectHeapManager(); - Object* orObject = foh->TryAllocateObject(pObjMT, PtrAlign(pObjMT->GetBaseSize()), /*publish*/ true); + Object* orObject = foh->TryAllocateObject(pObjMT, PtrAlign(pObjMT->GetBaseSize())); return ObjectToOBJECTREF(orObject); } diff --git a/src/coreclr/vm/ilmarshalers.h b/src/coreclr/vm/ilmarshalers.h index f3c9f31628f156..61ff10ac2b2b86 100644 --- a/src/coreclr/vm/ilmarshalers.h +++ b/src/coreclr/vm/ilmarshalers.h @@ -3138,39 +3138,13 @@ class ILMngdMarshaler : public ILMarshaler void EmitClearNative(ILCodeStream* pslILEmit) override { WRAPPER_NO_CONTRACT; - ILCodeLabel* pNoManagedValueLabel = nullptr; - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pNoManagedValueLabel = pslILEmit->NewCodeLabel(); - pslILEmit->EmitLDARG(StructMarshalStubs::MANAGED_STRUCT_ARGIDX); - pslILEmit->EmitBRFALSE(pNoManagedValueLabel); - } - EmitCallMngdMarshalerMethod(pslILEmit, GetClearNativeMethod()); - - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pslILEmit->EmitLabel(pNoManagedValueLabel); - } } void EmitClearNativeContents(ILCodeStream* pslILEmit) override { WRAPPER_NO_CONTRACT; - ILCodeLabel* pNoManagedValueLabel = nullptr; - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pNoManagedValueLabel = pslILEmit->NewCodeLabel(); - pslILEmit->EmitLDARG(StructMarshalStubs::MANAGED_STRUCT_ARGIDX); - pslILEmit->EmitBRFALSE(pNoManagedValueLabel); - } - EmitCallMngdMarshalerMethod(pslILEmit, GetClearNativeContentsMethod()); - - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pslILEmit->EmitLabel(pNoManagedValueLabel); - } } bool NeedsClearCLR() override diff --git a/src/coreclr/vm/interoplibinterface_comwrappers.cpp b/src/coreclr/vm/interoplibinterface_comwrappers.cpp index 93fa6a448d4687..4b41b68b8e6151 100644 --- a/src/coreclr/vm/interoplibinterface_comwrappers.cpp +++ b/src/coreclr/vm/interoplibinterface_comwrappers.cpp @@ -5,6 +5,7 @@ // Runtime headers #include "common.h" +#include "simplerwlock.hpp" #include "rcwrefcache.h" #ifdef FEATURE_COMINTEROP_APARTMENT_SUPPORT #include "olecontexthelpers.h" @@ -257,32 +258,36 @@ namespace using Element = SHash::element_t; using Iterator = SHash::Iterator; - class LockHolder : public CrstHolder + class ReaderLock final { + SimpleReadLockHolder _lock; public: - LockHolder(_In_ ExtObjCxtCache *cache) - : CrstHolder(&cache->_lock) - { - // This cache must be locked in Cooperative mode - // since releases of wrappers can occur during a GC. - CONTRACTL - { - NOTHROW; - GC_NOTRIGGER; - MODE_COOPERATIVE; - } - CONTRACTL_END; - } + ReaderLock(_In_ ExtObjCxtCache* cache) + : _lock{ &cache->_lock } + { } + + ~ReaderLock() = default; + }; + + class WriterLock final + { + SimpleWriteLockHolder _lock; + public: + WriterLock(_In_ ExtObjCxtCache* cache) + : _lock{ &cache->_lock } + { } + + ~WriterLock() = default; }; private: friend struct InteropLibImports::RuntimeCallContext; SHash _hashMap; - Crst _lock; + SimpleRWLock _lock; ExtObjCxtRefCache* _refCache; ExtObjCxtCache() - : _lock(CrstExternalObjectContextCache, CRST_UNSAFE_COOPGC) + : _lock(COOPERATIVE, LOCK_TYPE_DEFAULT) , _refCache(GetAppDomain()->GetRCWRefCache()) { } ~ExtObjCxtCache() = default; @@ -292,7 +297,7 @@ namespace bool IsLockHeld() { WRAPPER_NO_CONTRACT; - return (_lock.OwnedByCurrentThread() != FALSE); + return (_lock.LockTaken() != FALSE); } #endif // _DEBUG @@ -343,7 +348,7 @@ namespace // Determine the count of objects to return. SIZE_T objCountMax = 0; { - LockHolder lock(this); + ReaderLock lock(this); Iterator end = _hashMap.End(); for (Iterator curr = _hashMap.Begin(); curr != end; ++curr) { @@ -365,7 +370,7 @@ namespace SIZE_T objCount = 0; if (0 < objCountMax) { - LockHolder lock(this); + ReaderLock lock(this); Iterator end = _hashMap.End(); for (Iterator curr = _hashMap.Begin(); curr != end; ++curr) { @@ -823,11 +828,22 @@ namespace if (!uniqueInstance) { bool objectFound = false; + bool tryRemove = false; { - // Query the external object cache - ExtObjCxtCache::LockHolder lock(cache); + // Perform a quick look up to determine if we know of the object and if + // we need to perform a more expensive cleanup operation below. + ExtObjCxtCache::ReaderLock lock(cache); extObjCxt = cache->Find(cacheKey); + objectFound = extObjCxt != NULL; + tryRemove = objectFound && extObjCxt->IsSet(ExternalObjectContext::Flags_Detached); + } + if (tryRemove) + { + // Perform the slower cleanup operation that may be appropriate + // if the object still exists and has been detached. + ExtObjCxtCache::WriterLock lock(cache); + extObjCxt = cache->Find(cacheKey); objectFound = extObjCxt != NULL; if (objectFound && extObjCxt->IsSet(ExternalObjectContext::Flags_Detached)) { @@ -958,7 +974,7 @@ namespace else { // Attempt to insert the new context into the cache. - ExtObjCxtCache::LockHolder lock(cache); + ExtObjCxtCache::WriterLock lock(cache); extObjCxt = cache->FindOrAdd(cacheKey, resultHolder.GetContext()); } @@ -980,7 +996,7 @@ namespace { // Failed to set the context; one must already exist. // Remove from the cache above as well. - ExtObjCxtCache::LockHolder lock(cache); + ExtObjCxtCache::WriterLock lock(cache); cache->Remove(resultHolder.GetContext()); COMPlusThrow(kNotSupportedException); diff --git a/src/coreclr/vm/invokeutil.cpp b/src/coreclr/vm/invokeutil.cpp index c4ba804a4c493f..e6ed2103393e37 100644 --- a/src/coreclr/vm/invokeutil.cpp +++ b/src/coreclr/vm/invokeutil.cpp @@ -906,9 +906,13 @@ void InvokeUtil::SetValidField(CorElementType fldType, { void* pFieldData; if (pField->IsStatic()) + { pFieldData = pField->GetCurrentStaticAddress(); + } else - pFieldData = (*((BYTE**)target)) + pField->GetOffset() + sizeof(Object); + { + pFieldData = pField->GetInstanceAddress(*target); + } if (*valueObj == NULL) InitValueClass(pFieldData, pMT); @@ -1049,9 +1053,12 @@ OBJECTREF InvokeUtil::GetFieldValue(FieldDesc* pField, TypeHandle fieldType, OBJ GCPROTECT_BEGIN(obj); // calculate the offset to the field... if (pField->IsStatic()) + { p = pField->GetCurrentStaticAddress(); - else { - p = (*((BYTE**)target)) + pField->GetOffset() + sizeof(Object); + } + else + { + p = pField->GetInstanceAddress(*target); } GCPROTECT_END(); diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index 1a3439bdd235ef..74ee2f7482e747 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -12286,7 +12286,7 @@ static CorJitResult CompileMethodWithEtwWrapper(EEJitManager *jitMgr, SString namespaceOrClassName, methodName, methodSignature; // Fire an ETW event to mark the beginning of JIT'ing - ETW::MethodLog::MethodJitting(reinterpret_cast(info->ftn), &namespaceOrClassName, &methodName, &methodSignature); + ETW::MethodLog::MethodJitting(reinterpret_cast(info->ftn), NULL, &namespaceOrClassName, &methodName, &methodSignature); CorJitResult ret = jitMgr->m_jit->compileMethod(comp, info, flags, nativeEntry, nativeSizeOfCode); diff --git a/src/coreclr/vm/method.hpp b/src/coreclr/vm/method.hpp index aaecfb20779f34..12b3c86e0f7414 100644 --- a/src/coreclr/vm/method.hpp +++ b/src/coreclr/vm/method.hpp @@ -1819,7 +1819,7 @@ class MethodDesc PCODE GetMulticoreJitCode(PrepareCodeConfig* pConfig, bool* pWasTier0); PCODE JitCompileCode(PrepareCodeConfig* pConfig); PCODE JitCompileCodeLockedEventWrapper(PrepareCodeConfig* pConfig, JitListLockEntry* pEntry); - PCODE JitCompileCodeLocked(PrepareCodeConfig* pConfig, JitListLockEntry* pLockEntry, ULONG* pSizeOfCode); + PCODE JitCompileCodeLocked(PrepareCodeConfig* pConfig, COR_ILMETHOD_DECODER* pilHeader, JitListLockEntry* pLockEntry, ULONG* pSizeOfCode); public: bool TryGenerateUnsafeAccessor(DynamicResolver** resolver, COR_ILMETHOD_DECODER** methodILDecoder); diff --git a/src/coreclr/vm/methodtable.cpp b/src/coreclr/vm/methodtable.cpp index 1016a80971b7a0..688bf29079e08d 100644 --- a/src/coreclr/vm/methodtable.cpp +++ b/src/coreclr/vm/methodtable.cpp @@ -4187,8 +4187,9 @@ void MethodTable::AllocateRegularStaticBox(FieldDesc* pField, Object** boxedStat THROWS; GC_TRIGGERS; MODE_COOPERATIVE; - CONTRACTL_END; } + CONTRACTL_END + _ASSERT(pField->IsStatic() && !pField->IsSpecialStatic() && pField->IsByValue()); // Static fields are not pinned in collectible types so we need to protect the address @@ -4222,8 +4223,8 @@ OBJECTREF MethodTable::AllocateStaticBox(MethodTable* pFieldMT, BOOL fPinned, OB THROWS; GC_TRIGGERS; MODE_COOPERATIVE; - CONTRACTL_END; } + CONTRACTL_END _ASSERTE(pFieldMT->IsValueType()); diff --git a/src/coreclr/vm/peimage.cpp b/src/coreclr/vm/peimage.cpp index 55fc458c83762c..656b42b8111c70 100644 --- a/src/coreclr/vm/peimage.cpp +++ b/src/coreclr/vm/peimage.cpp @@ -141,11 +141,11 @@ ULONG PEImage::Release() result=InterlockedDecrement(&m_refCount); if (result == 0 ) { - LOG((LF_LOADER, LL_INFO100, "PEImage: Closing Image %s\n", m_path.GetUTF8())); + LOG((LF_LOADER, LL_INFO100, "PEImage: Closing %p\n", this)); if(m_bInHashMap) { PEImageLocator locator(this); - PEImage* deleted = (PEImage *)s_Images->DeleteValue(GetPathHash(), &locator); + PEImage* deleted = (PEImage *)s_Images->DeleteValue(m_pathHash, &locator); _ASSERTE(deleted == this); } } @@ -249,12 +249,7 @@ BOOL PEImage::CompareImage(UPTR u1, UPTR u2) EX_TRY { SString path(SString::Literal, pLocator->m_pPath); - -#ifdef FEATURE_CASE_SENSITIVE_FILESYSTEM - if (pImage->GetPath().Equals(path)) -#else if (pImage->GetPath().EqualsCaseInsensitive(path)) -#endif { ret = TRUE; } @@ -623,6 +618,7 @@ void PEImage::EnumMemoryRegions(CLRDataEnumMemoryFlags flags) PEImage::PEImage(): m_path(), + m_pathHash(0), m_refCount(1), m_bInHashMap(FALSE), m_bundleFileLocation(), diff --git a/src/coreclr/vm/peimage.h b/src/coreclr/vm/peimage.h index d30e561890c79e..e7bfc11d319f12 100644 --- a/src/coreclr/vm/peimage.h +++ b/src/coreclr/vm/peimage.h @@ -132,8 +132,6 @@ class PEImage final PTR_PEImageLayout GetLoadedLayout(); PTR_PEImageLayout GetFlatLayout(); - BOOL HasPath(); - ULONG GetPathHash(); const SString& GetPath(); const SString& GetPathToLoad(); LPCWSTR GetPathForErrorMessages() { return GetPath(); } @@ -288,6 +286,7 @@ class PEImage final // ------------------------------------------------------------ SString m_path; + ULONG m_pathHash; LONG m_refCount; // means this is a unique (deduped) instance. diff --git a/src/coreclr/vm/peimage.inl b/src/coreclr/vm/peimage.inl index 6bb3c9320cb5f0..d17d5d9dd77387 100644 --- a/src/coreclr/vm/peimage.inl +++ b/src/coreclr/vm/peimage.inl @@ -288,6 +288,7 @@ inline void PEImage::Init(LPCWSTR pPath, BundleFileLocation bundleFileLocation) m_path = pPath; m_path.Normalize(); + m_pathHash = m_path.HashCaseInsensitive(); m_bundleFileLocation = bundleFileLocation; SetModuleFileNameHintForDAC(); } @@ -310,11 +311,7 @@ inline PTR_PEImage PEImage::FindByPath(LPCWSTR pPath, BOOL isInBundle /* = TRUE int CaseHashHelper(const WCHAR *buffer, COUNT_T count); PEImageLocator locator(pPath, isInBundle); -#ifdef FEATURE_CASE_SENSITIVE_FILESYSTEM - DWORD dwHash=path.Hash(); -#else DWORD dwHash = CaseHashHelper(pPath, (COUNT_T) u16_strlen(pPath)); -#endif return (PEImage *) s_Images->LookupValue(dwHash, &locator); } @@ -366,7 +363,7 @@ inline void PEImage::AddToHashMap() CONTRACTL_END; _ASSERTE(s_hashLock.OwnedByCurrentThread()); - s_Images->InsertValue(GetPathHash(),this); + s_Images->InsertValue(m_pathHash,this); m_bInHashMap=TRUE; } @@ -378,31 +375,6 @@ inline BOOL PEImage::Has32BitNTHeaders() return GetOrCreateLayout(PEImageLayout::LAYOUT_ANY)->Has32BitNTHeaders(); } -inline BOOL PEImage::HasPath() -{ - LIMITED_METHOD_CONTRACT; - - return !GetPath().IsEmpty(); -} - -inline ULONG PEImage::GetPathHash() -{ - CONTRACT(ULONG) - { - PRECONDITION(HasPath()); - MODE_ANY; - GC_NOTRIGGER; - THROWS; - } - CONTRACT_END; - -#ifdef FEATURE_CASE_SENSITIVE_FILESYSTEM - RETURN m_path.Hash(); -#else - RETURN m_path.HashCaseInsensitive(); -#endif -} - inline void PEImage::GetPEKindAndMachine(DWORD* pdwKind, DWORD* pdwMachine) { CONTRACTL diff --git a/src/coreclr/vm/peimagelayout.cpp b/src/coreclr/vm/peimagelayout.cpp index c756c451649b34..0f0dde27618196 100644 --- a/src/coreclr/vm/peimagelayout.cpp +++ b/src/coreclr/vm/peimagelayout.cpp @@ -534,7 +534,10 @@ LoadedImageLayout::LoadedImageLayout(PEImage* pOwner, HRESULT* loadFailure) IfFailThrow(Init(m_Module, true)); - LOG((LF_LOADER, LL_INFO1000, "PEImage: Opened HMODULE %s\n", pOwner->GetPath().GetUTF8())); +#ifdef LOGGING + SString ownerPath{ pOwner->GetPath() }; + LOG((LF_LOADER, LL_INFO1000, "PEImage: Opened HMODULE %s\n", ownerPath.GetUTF8())); +#endif // LOGGING #else HANDLE hFile = pOwner->GetFileHandle(); @@ -548,8 +551,11 @@ LoadedImageLayout::LoadedImageLayout(PEImage* pOwner, HRESULT* loadFailure) return; } +#ifdef LOGGING + SString ownerPath{ pOwner->GetPath() }; LOG((LF_LOADER, LL_INFO1000, "PEImage: image %s (hFile %p) mapped @ %p\n", - pOwner->GetPath().GetUTF8(), hFile, (void*)m_LoadedFile)); + ownerPath.GetUTF8(), hFile, (void*)m_LoadedFile)); +#endif // LOGGING IfFailThrow(Init((void*)m_LoadedFile)); @@ -616,7 +622,10 @@ FlatImageLayout::FlatImageLayout(PEImage* pOwner) INT64 offset = pOwner->GetOffset(); INT64 size = pOwner->GetSize(); - LOG((LF_LOADER, LL_INFO100, "PEImage: Opening flat %s\n", pOwner->GetPath().GetUTF8())); +#ifdef LOGGING + SString ownerPath{ pOwner->GetPath() }; + LOG((LF_LOADER, LL_INFO100, "PEImage: Opening flat %s\n", ownerPath.GetUTF8())); +#endif // LOGGING // If a size is not specified, load the whole file if (size == 0) diff --git a/src/coreclr/vm/perfinfo.cpp b/src/coreclr/vm/perfinfo.cpp index 0be2e519936fbe..98fc667661a504 100644 --- a/src/coreclr/vm/perfinfo.cpp +++ b/src/coreclr/vm/perfinfo.cpp @@ -32,8 +32,8 @@ void PerfInfo::LogImage(PEAssembly* pPEAssembly, CHAR* guid) PRECONDITION(guid != nullptr); } CONTRACTL_END; - SString value; - const SString& path = pPEAssembly->GetPath(); + // Nothing to log if the assembly path isn't present. + SString path{ pPEAssembly->GetPath() }; if (path.IsEmpty()) { return; @@ -49,12 +49,11 @@ void PerfInfo::LogImage(PEAssembly* pPEAssembly, CHAR* guid) } } + SString value; value.Printf("%s%c%s%c%p", path.GetUTF8(), sDelimiter, guid, sDelimiter, baseAddr); - SString command; - command.Printf("%s", "ImageLoad"); + SString command{ SString::Literal, "ImageLoad" }; WriteLine(command, value); - } // Writes a command line, with "type" being the type of command, with "value" as the command's corresponding instructions/values. This is to be used to log specific information, e.g. LogImage diff --git a/src/coreclr/vm/precode.h b/src/coreclr/vm/precode.h index d7f6b4cac1a74d..3f6a2f532c4e67 100644 --- a/src/coreclr/vm/precode.h +++ b/src/coreclr/vm/precode.h @@ -11,7 +11,7 @@ #define PRECODE_ALIGNMENT sizeof(void*) -#if defined(HOST_AMD64) +#if defined(TARGET_AMD64) #define OFFSETOF_PRECODE_TYPE 0 #define OFFSETOF_PRECODE_TYPE_CALL_OR_JMP 5 @@ -19,7 +19,7 @@ #define SIZEOF_PRECODE_BASE 16 -#elif defined(HOST_X86) +#elif defined(TARGET_X86) EXTERN_C VOID STDCALL PrecodeRemotingThunk(); @@ -29,27 +29,27 @@ EXTERN_C VOID STDCALL PrecodeRemotingThunk(); #define SIZEOF_PRECODE_BASE 8 -#elif defined(HOST_ARM64) +#elif defined(TARGET_ARM64) #define SIZEOF_PRECODE_BASE CODE_SIZE_ALIGN #define OFFSETOF_PRECODE_TYPE 0 -#elif defined(HOST_ARM) +#elif defined(TARGET_ARM) -#define SIZEOF_PRECODE_BASE CODE_SIZE_ALIGN -#define OFFSETOF_PRECODE_TYPE 3 +#define SIZEOF_PRECODE_BASE CODE_SIZE_ALIGN * 2 +#define OFFSETOF_PRECODE_TYPE 7 -#elif defined(HOST_LOONGARCH64) +#elif defined(TARGET_LOONGARCH64) #define SIZEOF_PRECODE_BASE CODE_SIZE_ALIGN #define OFFSETOF_PRECODE_TYPE 0 -#elif defined(HOST_RISCV64) +#elif defined(TARGET_RISCV64) #define SIZEOF_PRECODE_BASE CODE_SIZE_ALIGN #define OFFSETOF_PRECODE_TYPE 0 -#endif // HOST_AMD64 +#endif // TARGET_AMD64 #ifndef DACCESS_COMPILE // Given an address in a slot, figure out if the prestub will be called @@ -61,14 +61,14 @@ BOOL DoesSlotCallPrestub(PCODE pCode); // Invalid precode type struct InvalidPrecode { -#if defined(HOST_AMD64) || defined(HOST_X86) +#if defined(TARGET_AMD64) || defined(TARGET_X86) // int3 static const int Type = 0xCC; -#elif defined(HOST_ARM64) || defined(HOST_ARM) +#elif defined(TARGET_ARM64) || defined(TARGET_ARM) static const int Type = 0; -#elif defined(HOST_LOONGARCH64) +#elif defined(TARGET_LOONGARCH64) static const int Type = 0xff; -#elif defined(HOST_RISCV64) +#elif defined(TARGET_RISCV64) static const int Type = 0xff; #endif }; @@ -90,25 +90,25 @@ extern "C" void StubPrecodeCode_End(); // Regular precode struct StubPrecode { -#if defined(HOST_AMD64) +#if defined(TARGET_AMD64) static const BYTE Type = 0x4C; static const SIZE_T CodeSize = 24; -#elif defined(HOST_X86) +#elif defined(TARGET_X86) static const BYTE Type = 0xA1; static const SIZE_T CodeSize = 24; -#elif defined(HOST_ARM64) +#elif defined(TARGET_ARM64) static const int Type = 0x4A; static const SIZE_T CodeSize = 24; -#elif defined(HOST_ARM) - static const int Type = 0xCF; +#elif defined(TARGET_ARM) + static const int Type = 0xFF; static const SIZE_T CodeSize = 12; -#elif defined(HOST_LOONGARCH64) +#elif defined(TARGET_LOONGARCH64) static const int Type = 0x4; static const SIZE_T CodeSize = 24; -#elif defined(HOST_RISCV64) +#elif defined(TARGET_RISCV64) static const int Type = 0x17; static const SIZE_T CodeSize = 24; -#endif // HOST_AMD64 +#endif // TARGET_AMD64 BYTE m_code[CodeSize]; @@ -189,7 +189,7 @@ typedef DPTR(StubPrecode) PTR_StubPrecode; // (This is fake precode. VTable slot does not point to it.) struct NDirectImportPrecode : StubPrecode { - static const int Type = 0x01; + static const int Type = 0x05; void Init(NDirectImportPrecode* pPrecodeRX, MethodDesc* pMD, LoaderAllocator *pLoaderAllocator); @@ -224,31 +224,31 @@ extern "C" void FixupPrecodeCode_End(); // The fixup precode is simple jump once patched. It does not have the two instruction overhead of regular precode. struct FixupPrecode { -#if defined(HOST_AMD64) +#if defined(TARGET_AMD64) static const int Type = 0xFF; static const SIZE_T CodeSize = 24; static const int FixupCodeOffset = 6; -#elif defined(HOST_X86) +#elif defined(TARGET_X86) static const int Type = 0xFF; static const SIZE_T CodeSize = 24; static const int FixupCodeOffset = 6; -#elif defined(HOST_ARM64) +#elif defined(TARGET_ARM64) static const int Type = 0x0B; static const SIZE_T CodeSize = 24; static const int FixupCodeOffset = 8; -#elif defined(HOST_ARM) - static const int Type = 0xFF; +#elif defined(TARGET_ARM) + static const int Type = 0xCF; static const SIZE_T CodeSize = 12; static const int FixupCodeOffset = 4 + THUMB_CODE; -#elif defined(HOST_LOONGARCH64) +#elif defined(TARGET_LOONGARCH64) static const int Type = 0x3; static const SIZE_T CodeSize = 32; static const int FixupCodeOffset = 12; -#elif defined(HOST_RISCV64) +#elif defined(TARGET_RISCV64) static const int Type = 0x97; static const SIZE_T CodeSize = 32; static const int FixupCodeOffset = 10; -#endif // HOST_AMD64 +#endif // TARGET_AMD64 BYTE m_code[CodeSize]; @@ -614,4 +614,8 @@ static_assert_no_msg(FixupPrecode::Type != NDirectImportPrecode::Type); static_assert_no_msg(FixupPrecode::Type != ThisPtrRetBufPrecode::Type); static_assert_no_msg(NDirectImportPrecode::Type != ThisPtrRetBufPrecode::Type); +// Verify that the base type for each precode fits into each specific precode type +static_assert_no_msg(sizeof(Precode) <= sizeof(NDirectImportPrecode)); +static_assert_no_msg(sizeof(Precode) <= sizeof(FixupPrecode)); +static_assert_no_msg(sizeof(Precode) <= sizeof(ThisPtrRetBufPrecode)); #endif // __PRECODE_H__ diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 174e48565f31b8..32944952301648 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -709,6 +709,53 @@ PCODE MethodDesc::JitCompileCode(PrepareCodeConfig* pConfig) } } +namespace +{ + COR_ILMETHOD_DECODER* GetAndVerifyMetadataILHeader(MethodDesc* pMD, PrepareCodeConfig* pConfig, COR_ILMETHOD_DECODER* pDecoderMemory) + { + STANDARD_VM_CONTRACT; + _ASSERTE(pMD != NULL); + _ASSERTE(!pMD->IsNoMetadata()); + _ASSERTE(pConfig != NULL); + _ASSERTE(pDecoderMemory != NULL); + + COR_ILMETHOD_DECODER* pHeader = NULL; + COR_ILMETHOD* ilHeader = pConfig->GetILHeader(); + if (ilHeader == NULL) + return NULL; + + COR_ILMETHOD_DECODER::DecoderStatus status = COR_ILMETHOD_DECODER::FORMAT_ERROR; + { + // Decoder ctor can AV on a malformed method header + AVInRuntimeImplOkayHolder AVOkay; + pHeader = new (pDecoderMemory) COR_ILMETHOD_DECODER(ilHeader, pMD->GetMDImport(), &status); + } + + if (status == COR_ILMETHOD_DECODER::FORMAT_ERROR) + COMPlusThrowHR(COR_E_BADIMAGEFORMAT, BFA_BAD_IL); + + return pHeader; + } + + COR_ILMETHOD_DECODER* GetAndVerifyILHeader(MethodDesc* pMD, PrepareCodeConfig* pConfig, COR_ILMETHOD_DECODER* pIlDecoderMemory) + { + STANDARD_VM_CONTRACT; + _ASSERTE(pMD != NULL); + if (pMD->IsIL()) + { + return GetAndVerifyMetadataILHeader(pMD, pConfig, pIlDecoderMemory); + } + else if (pMD->IsILStub()) + { + ILStubResolver* pResolver = pMD->AsDynamicMethodDesc()->GetILStubResolver(); + return pResolver->GetILHeader(); + } + + _ASSERTE(pMD->IsNoMetadata()); + return NULL; + } +} + PCODE MethodDesc::JitCompileCodeLockedEventWrapper(PrepareCodeConfig* pConfig, JitListLockEntry* pEntry) { STANDARD_VM_CONTRACT; @@ -759,11 +806,18 @@ PCODE MethodDesc::JitCompileCodeLockedEventWrapper(PrepareCodeConfig* pConfig, J } #endif // PROFILING_SUPPORTED + // The profiler may have changed the code on the callback. Need to + // pick up the new code. + // + // (don't want this for OSR, need to see how it works) + COR_ILMETHOD_DECODER ilDecoderTemp; + COR_ILMETHOD_DECODER* pilHeader = GetAndVerifyILHeader(this, pConfig, &ilDecoderTemp); + if (!ETW_TRACING_CATEGORY_ENABLED(MICROSOFT_WINDOWS_DOTNETRUNTIME_PROVIDER_DOTNET_Context, TRACE_LEVEL_VERBOSE, CLR_JIT_KEYWORD)) { - pCode = JitCompileCodeLocked(pConfig, pEntry, &sizeOfCode); + pCode = JitCompileCodeLocked(pConfig, pilHeader, pEntry, &sizeOfCode); } else { @@ -778,12 +832,13 @@ PCODE MethodDesc::JitCompileCodeLockedEventWrapper(PrepareCodeConfig* pConfig, J // a small stub of native code but no native-IL mapping. #ifndef FEATURE_INTERPRETER ETW::MethodLog::MethodJitting(this, + pilHeader, &namespaceOrClassName, &methodName, &methodSignature); #endif - pCode = JitCompileCodeLocked(pConfig, pEntry, &sizeOfCode); + pCode = JitCompileCodeLocked(pConfig, pilHeader, pEntry, &sizeOfCode); // Interpretted methods skip this notification #ifdef FEATURE_INTERPRETER @@ -869,66 +924,11 @@ PCODE MethodDesc::JitCompileCodeLockedEventWrapper(PrepareCodeConfig* pConfig, J return pCode; } -namespace -{ - COR_ILMETHOD_DECODER* GetAndVerifyMetadataILHeader(MethodDesc* pMD, PrepareCodeConfig* pConfig, COR_ILMETHOD_DECODER* pDecoderMemory) - { - STANDARD_VM_CONTRACT; - _ASSERTE(pMD != NULL); - _ASSERTE(!pMD->IsNoMetadata()); - _ASSERTE(pConfig != NULL); - _ASSERTE(pDecoderMemory != NULL); - - COR_ILMETHOD_DECODER* pHeader = NULL; - COR_ILMETHOD* ilHeader = pConfig->GetILHeader(); - if (ilHeader == NULL) - return NULL; - - COR_ILMETHOD_DECODER::DecoderStatus status = COR_ILMETHOD_DECODER::FORMAT_ERROR; - { - // Decoder ctor can AV on a malformed method header - AVInRuntimeImplOkayHolder AVOkay; - pHeader = new (pDecoderMemory) COR_ILMETHOD_DECODER(ilHeader, pMD->GetMDImport(), &status); - } - - if (status == COR_ILMETHOD_DECODER::FORMAT_ERROR) - COMPlusThrowHR(COR_E_BADIMAGEFORMAT, BFA_BAD_IL); - - return pHeader; - } - - COR_ILMETHOD_DECODER* GetAndVerifyILHeader(MethodDesc* pMD, PrepareCodeConfig* pConfig, COR_ILMETHOD_DECODER* pIlDecoderMemory) - { - STANDARD_VM_CONTRACT; - _ASSERTE(pMD != NULL); - if (pMD->IsIL()) - { - return GetAndVerifyMetadataILHeader(pMD, pConfig, pIlDecoderMemory); - } - else if (pMD->IsILStub()) - { - ILStubResolver* pResolver = pMD->AsDynamicMethodDesc()->GetILStubResolver(); - return pResolver->GetILHeader(); - } - - _ASSERTE(pMD->IsNoMetadata()); - return NULL; - } -} - -PCODE MethodDesc::JitCompileCodeLocked(PrepareCodeConfig* pConfig, JitListLockEntry* pEntry, ULONG* pSizeOfCode) +PCODE MethodDesc::JitCompileCodeLocked(PrepareCodeConfig* pConfig, COR_ILMETHOD_DECODER* pilHeader, JitListLockEntry* pEntry, ULONG* pSizeOfCode) { STANDARD_VM_CONTRACT; PCODE pCode = NULL; - - // The profiler may have changed the code on the callback. Need to - // pick up the new code. - // - // (don't want this for OSR, need to see how it works) - COR_ILMETHOD_DECODER ilDecoderTemp; - COR_ILMETHOD_DECODER* pilHeader = GetAndVerifyILHeader(this, pConfig, &ilDecoderTemp); - CORJIT_FLAGS jitFlags; PCODE pOtherCode = NULL; diff --git a/src/coreclr/vm/profilingenumerators.cpp b/src/coreclr/vm/profilingenumerators.cpp index 1d19a87324be92..c55d40ffbeed9e 100644 --- a/src/coreclr/vm/profilingenumerators.cpp +++ b/src/coreclr/vm/profilingenumerators.cpp @@ -103,7 +103,12 @@ BOOL ProfilerObjectEnum::Init() } CONTRACTL_END; - FrozenObjectHeapManager* foh = SystemDomain::GetFrozenObjectHeapManager(); + FrozenObjectHeapManager* foh = SystemDomain::GetFrozenObjectHeapManagerNoThrow(); + if (foh == nullptr) + { + return TRUE; + } + CrstHolder ch(&foh->m_Crst); const unsigned segmentsCount = foh->m_FrozenSegments.GetCount(); @@ -113,7 +118,6 @@ BOOL ProfilerObjectEnum::Init() for (unsigned segmentIdx = 0; segmentIdx < segmentsCount; segmentIdx++) { const FrozenObjectSegment* segment = segments[segmentIdx]; - Object* currentObj = segment->GetFirstObject(); while (currentObj != nullptr) { diff --git a/src/coreclr/vm/proftoeeinterfaceimpl.cpp b/src/coreclr/vm/proftoeeinterfaceimpl.cpp index 1f2e4f26001859..11853acdcf86ee 100644 --- a/src/coreclr/vm/proftoeeinterfaceimpl.cpp +++ b/src/coreclr/vm/proftoeeinterfaceimpl.cpp @@ -7672,7 +7672,12 @@ HRESULT ProfToEEInterfaceImpl::GetNonGCHeapBounds(ULONG cObjectRanges, return E_INVALIDARG; } - FrozenObjectHeapManager* foh = SystemDomain::GetFrozenObjectHeapManager(); + FrozenObjectHeapManager* foh = SystemDomain::GetFrozenObjectHeapManagerNoThrow(); + if (foh == nullptr) + { + *pcObjectRanges = 0; + return S_OK; + } CrstHolder ch(&foh->m_Crst); const unsigned segmentsCount = foh->m_FrozenSegments.GetCount(); diff --git a/src/coreclr/vm/readytoruninfo.cpp b/src/coreclr/vm/readytoruninfo.cpp index 2c186cf7adeaff..e75373db8855aa 100644 --- a/src/coreclr/vm/readytoruninfo.cpp +++ b/src/coreclr/vm/readytoruninfo.cpp @@ -429,7 +429,8 @@ static void LogR2r(const char *msg, PEAssembly *pPEAssembly) if (r2rLogFile == NULL) return; - fprintf(r2rLogFile, "%s: \"%s\".\n", msg, pPEAssembly->GetPath().GetUTF8()); + SString assemblyPath{ pPEAssembly->GetPath() }; + fprintf(r2rLogFile, "%s: \"%s\".\n", msg, assemblyPath.GetUTF8()); fflush(r2rLogFile); } @@ -1904,7 +1905,7 @@ uint32_t ReadyToRun_TypeGenericInfoMap::GetGenericArgumentCount(mdTypeDef input, uint32_t count = ((uint8_t)typeGenericInfo & (uint8_t)ReadyToRunTypeGenericInfo::GenericCountMask); if (count > 2) foundResult = false; - + if (!foundResult) { HENUMInternalHolder hEnumTyPars(pImport); @@ -1922,7 +1923,7 @@ HRESULT ReadyToRun_TypeGenericInfoMap::GetGenericArgumentCountNoThrow(mdTypeDef uint32_t count = ((uint8_t)typeGenericInfo & (uint8_t)ReadyToRunTypeGenericInfo::GenericCountMask); if (count > 2) foundResult = false; - + if (!foundResult) { HENUMInternalHolder hEnumTyPars(pImport); diff --git a/src/coreclr/vm/siginfo.cpp b/src/coreclr/vm/siginfo.cpp index b9807b51ba4af8..eb35f3a372538f 100644 --- a/src/coreclr/vm/siginfo.cpp +++ b/src/coreclr/vm/siginfo.cpp @@ -4962,11 +4962,6 @@ void PromoteCarefully(promote_func fn, #if !defined(DACCESS_COMPILE) - // - // Sanity check the stack scan limit - // - assert(sc->stack_limit != 0); - // Note that the base is at a higher address than the limit, since the stack // grows downwards. // To check whether the object is in the stack or not, we also need to check the sc->stack_limit. diff --git a/src/coreclr/vm/typedesc.cpp b/src/coreclr/vm/typedesc.cpp index 6c3226882503c6..95e86ccc961691 100644 --- a/src/coreclr/vm/typedesc.cpp +++ b/src/coreclr/vm/typedesc.cpp @@ -104,6 +104,31 @@ BOOL TypeDesc::IsSharedByGenericInstantiations() return FALSE; } +BOOL TypeDesc::ContainsGenericVariables(BOOL methodOnly) +{ + if (IsGenericVariable()) + { + if (!methodOnly) + return TRUE; + + PTR_TypeVarTypeDesc pTyVar = dac_cast(this); + return TypeFromToken(pTyVar->GetTypeOrMethodDef()) == mdtMethodDef; + } + + if (HasTypeParam()) + { + return GetRootTypeParam().ContainsGenericVariables(methodOnly); + } + + if (IsFnPtr()) + { + return dac_cast(this)->ContainsGenericVariables(methodOnly); + } + + return FALSE; +} + + PTR_BaseDomain TypeDesc::GetDomain() { CONTRACTL @@ -1670,6 +1695,21 @@ FnPtrTypeDesc::IsSharedByGenericInstantiations() return FALSE; } // FnPtrTypeDesc::IsSharedByGenericInstantiations +BOOL +FnPtrTypeDesc::ContainsGenericVariables(BOOL methodOnly) +{ + LIMITED_METHOD_DAC_CONTRACT; + + for (DWORD i = 0; i <= m_NumArgs; i++) + { + if (m_RetAndArgTypes[i].ContainsGenericVariables(methodOnly)) + { + return TRUE; + } + } + return FALSE; +} // FnPtrTypeDesc::ContainsGenericVariables + #ifndef DACCESS_COMPILE // Returns TRUE if all return and argument types are externally visible. diff --git a/src/coreclr/vm/typedesc.h b/src/coreclr/vm/typedesc.h index 51614c3b110778..b86845c81e21c0 100644 --- a/src/coreclr/vm/typedesc.h +++ b/src/coreclr/vm/typedesc.h @@ -182,6 +182,8 @@ class TypeDesc BOOL IsSharedByGenericInstantiations(); + BOOL ContainsGenericVariables(BOOL methodOnly); + protected: // See methodtable.h for details of the flags with the same name there enum @@ -527,6 +529,8 @@ class FnPtrTypeDesc : public TypeDesc BOOL IsSharedByGenericInstantiations(); + BOOL ContainsGenericVariables(BOOL methodOnly); + #ifndef DACCESS_COMPILE // Returns TRUE if all return and argument types are externally visible. BOOL IsExternallyVisible() const; diff --git a/src/coreclr/vm/typehandle.cpp b/src/coreclr/vm/typehandle.cpp index 59eed8b5030d86..053cc759217c76 100644 --- a/src/coreclr/vm/typehandle.cpp +++ b/src/coreclr/vm/typehandle.cpp @@ -138,26 +138,10 @@ BOOL TypeHandle::ContainsGenericVariables(BOOL methodOnly /*=FALSE*/) const STATIC_CONTRACT_NOTHROW; SUPPORTS_DAC; - if (HasTypeParam()) - { - return GetTypeParam().ContainsGenericVariables(methodOnly); - } - - if (IsGenericVariable()) - { - if (!methodOnly) - return TRUE; - - PTR_TypeVarTypeDesc pTyVar = dac_cast(AsTypeDesc()); - return TypeFromToken(pTyVar->GetTypeOrMethodDef()) == mdtMethodDef; - } - else if (HasInstantiation()) - { - if (GetMethodTable()->ContainsGenericVariables(methodOnly)) - return TRUE; - } - - return FALSE; + if (IsTypeDesc()) + return AsTypeDesc()->ContainsGenericVariables(methodOnly); + else + return AsMethodTable()->ContainsGenericVariables(methodOnly); } //@GENERICS: diff --git a/src/coreclr/vm/versionresilienthashcode.cpp b/src/coreclr/vm/versionresilienthashcode.cpp index b3ba764baac595..85bd146d8dc463 100644 --- a/src/coreclr/vm/versionresilienthashcode.cpp +++ b/src/coreclr/vm/versionresilienthashcode.cpp @@ -286,7 +286,7 @@ bool AddVersionResilientHashCodeForInstruction(ILInstructionParser *parser, xxHa hash->Add(varValue); break; } - + case InlineVar: // 2 byte value which is token change resilient { uint16_t varValue; @@ -388,6 +388,12 @@ bool GetVersionResilientILCodeHashCode(MethodDesc *pMD, int* hashCode, unsigned* initLocals = (options & CORINFO_OPT_INIT_LOCALS) == CORINFO_OPT_INIT_LOCALS; } + else if (!pMD->HasILHeader()) + { + // Dynamically generated IL methods like UnsafeAccessors may not have + // an IL header. + return false; + } else { COR_ILMETHOD_DECODER header(pMD->GetILHeader(TRUE), pMD->GetMDImport(), NULL); diff --git a/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs b/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs index 0d0b33ed55569e..1bfc80fcfca492 100644 --- a/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs +++ b/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs @@ -1,8 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Buffers.Binary; using System.IO; using System.IO.MemoryMappedFiles; +using System.Reflection.PortableExecutable; namespace Microsoft.NET.HostModel.AppHost { @@ -15,29 +18,13 @@ public static class PEUtils /// true if the accessor represents a PE image, false otherwise. internal static unsafe bool IsPEImage(MemoryMappedViewAccessor accessor) { - byte* pointer = null; + if (accessor.Capacity < PEOffsets.DosStub.PESignatureOffset + sizeof(uint)) + return false; - try - { - accessor.SafeMemoryMappedViewHandle.AcquirePointer(ref pointer); - byte* bytes = pointer + accessor.PointerOffset; - - // https://en.wikipedia.org/wiki/Portable_Executable - // Validate that we're looking at Windows PE file - if (((ushort*)bytes)[0] != PEOffsets.DosImageSignature - || accessor.Capacity < PEOffsets.DosStub.PESignatureOffset + sizeof(uint)) - { - return false; - } - return true; - } - finally - { - if (pointer != null) - { - accessor.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } + // https://en.wikipedia.org/wiki/Portable_Executable + // Validate that we're looking at Windows PE file + ushort signature = AsLittleEndian(accessor.ReadUInt16(0)); + return signature == PEOffsets.DosImageSignature; } public static bool IsPEImage(string filePath) @@ -60,40 +47,15 @@ public static bool IsPEImage(string filePath) /// The memory accessor which has the apphost file opened. internal static unsafe void SetWindowsGraphicalUserInterfaceBit(MemoryMappedViewAccessor accessor) { - byte* pointer = null; - - try - { - accessor.SafeMemoryMappedViewHandle.AcquirePointer(ref pointer); - byte* bytes = pointer + accessor.PointerOffset; - - // https://en.wikipedia.org/wiki/Portable_Executable - uint peHeaderOffset = ((uint*)(bytes + PEOffsets.DosStub.PESignatureOffset))[0]; - - if (accessor.Capacity < peHeaderOffset + PEOffsets.PEHeader.Subsystem + sizeof(ushort)) - { - throw new AppHostNotPEFileException("Subsystem offset out of file range."); - } - - ushort* subsystem = ((ushort*)(bytes + peHeaderOffset + PEOffsets.PEHeader.Subsystem)); - - // https://docs.microsoft.com/en-us/windows/desktop/Debug/pe-format#windows-subsystem - // The subsystem of the prebuilt apphost should be set to CUI - if (subsystem[0] != (ushort)PEOffsets.Subsystem.WindowsCui) - { - throw new AppHostNotCUIException(subsystem[0]); - } - - // Set the subsystem to GUI - subsystem[0] = (ushort)PEOffsets.Subsystem.WindowsGui; - } - finally - { - if (pointer != null) - { - accessor.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } + // https://learn.microsoft.com/windows/win32/debug/pe-format#windows-subsystem + // The subsystem of the prebuilt apphost should be set to CUI + uint peHeaderOffset; + ushort subsystem = GetWindowsSubsystem(accessor, out peHeaderOffset); + if (subsystem != (ushort)Subsystem.WindowsCui) + throw new AppHostNotCUIException(subsystem); + + // Set the subsystem to GUI + accessor.Write(peHeaderOffset + PEOffsets.PEHeader.Subsystem, AsLittleEndian((ushort)Subsystem.WindowsGui)); } public static unsafe void SetWindowsGraphicalUserInterfaceBit(string filePath) @@ -113,32 +75,7 @@ public static unsafe void SetWindowsGraphicalUserInterfaceBit(string filePath) /// The memory accessor which has the apphost file opened. internal static unsafe ushort GetWindowsGraphicalUserInterfaceBit(MemoryMappedViewAccessor accessor) { - byte* pointer = null; - - try - { - accessor.SafeMemoryMappedViewHandle.AcquirePointer(ref pointer); - byte* bytes = pointer + accessor.PointerOffset; - - // https://en.wikipedia.org/wiki/Portable_Executable - uint peHeaderOffset = ((uint*)(bytes + PEOffsets.DosStub.PESignatureOffset))[0]; - - if (accessor.Capacity < peHeaderOffset + PEOffsets.PEHeader.Subsystem + sizeof(ushort)) - { - throw new AppHostNotPEFileException("Subsystem offset out of file range."); - } - - ushort* subsystem = ((ushort*)(bytes + peHeaderOffset + PEOffsets.PEHeader.Subsystem)); - - return subsystem[0]; - } - finally - { - if (pointer != null) - { - accessor.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } + return GetWindowsSubsystem(accessor, out _); } public static unsafe ushort GetWindowsGraphicalUserInterfaceBit(string filePath) @@ -151,5 +88,25 @@ public static unsafe ushort GetWindowsGraphicalUserInterfaceBit(string filePath) } } } + + private static ushort GetWindowsSubsystem(MemoryMappedViewAccessor accessor, out uint peHeaderOffset) + { + // https://en.wikipedia.org/wiki/Portable_Executable + if (accessor.Capacity < PEOffsets.DosStub.PESignatureOffset + sizeof(uint)) + throw new AppHostNotPEFileException("PESignature offset out of file range."); + + peHeaderOffset = AsLittleEndian(accessor.ReadUInt32(PEOffsets.DosStub.PESignatureOffset)); + if (accessor.Capacity < peHeaderOffset + PEOffsets.PEHeader.Subsystem + sizeof(ushort)) + throw new AppHostNotPEFileException("Subsystem offset out of file range."); + + // https://learn.microsoft.com/windows/win32/debug/pe-format#windows-subsystem + return AsLittleEndian(accessor.ReadUInt16(peHeaderOffset + PEOffsets.PEHeader.Subsystem)); + } + + private static ushort AsLittleEndian(ushort value) + => BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value); + + private static uint AsLittleEndian(uint value) + => BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value); } } diff --git a/src/installer/pkg/THIRD-PARTY-NOTICES.TXT b/src/installer/pkg/THIRD-PARTY-NOTICES.TXT index fcb519506d511e..237af34f098f15 100644 --- a/src/installer/pkg/THIRD-PARTY-NOTICES.TXT +++ b/src/installer/pkg/THIRD-PARTY-NOTICES.TXT @@ -449,8 +449,8 @@ Foundation, Inc., Hewlett-Packard Company, Microsoft, nor Digital Equipment Corporation makes any representations about the suitability of this software for any purpose." -License notice for The LLVM Compiler Infrastructure ---------------------------------------------------- +License notice for The LLVM Compiler Infrastructure (Legacy License) +-------------------------------------------------------------------- Developed by: @@ -1481,3 +1481,342 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- + +Copyright (c) 2015, Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +License for C# Implementation of Fast CRC Computation +----------------------------------------------------- + +https://github.com/SixLabors/ImageSharp/blob/f4f689ce67ecbcc35cebddba5aacb603e6d1068a/src/ImageSharp/Formats/Png/Zlib/Crc32.cs + +Copyright (c) Six Labors. +Licensed under the Apache License, Version 2.0. + +Available at +https://github.com/SixLabors/ImageSharp/blob/f4f689ce67ecbcc35cebddba5aacb603e6d1068a/LICENSE + +License for Fast CRC Computation +-------------------------------------- + +https://github.com/intel/isa-l/blob/33a2d9484595c2d6516c920ce39a694c144ddf69/crc/crc32_ieee_by4.asm +https://github.com/intel/isa-l/blob/33a2d9484595c2d6516c920ce39a694c144ddf69/crc/crc64_ecma_norm_by8.asm + +Copyright(c) 2011-2015 Intel Corporation All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +License for fastmod (https://github.com/lemire/fastmod), ibm-fpgen (https://github.com/nigeltao/parse-number-fxx-test-data) and fastrange (https://github.com/lemire/fastrange) +-------------------------------------- + + Copyright 2018 Daniel Lemire + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +License for Jb Evain +--------------------- + +Copyright (c) 2006 Jb Evain (jbevain@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +--- Optional exception to the license --- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into a machine-executable object form of such +source code, you may redistribute such embedded portions in such object form +without including the above copyright and permission notices. + +License for MurmurHash3 +-------------------------------------- + +https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp + +MurmurHash3 was written by Austin Appleby, and is placed in the public +domain. The author hereby disclaims copyright to this source + +License notice for The LLVM Project +----------------------------------- + +Copyright 2019 LLVM Project + +Licensed under the Apache License, Version 2.0 (the "License") with LLVM Exceptions; +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +https://llvm.org/LICENSE.txt + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +License for sse4-strstr (https://github.com/WojciechMula/sse4-strstr) +-------------------------------------- + + Copyright (c) 2008-2016, Wojciech Mula + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS + IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +License for the Teddy multi-substring searching implementation +-------------------------------------- + +https://github.com/BurntSushi/aho-corasick + +The MIT License (MIT) + +Copyright (c) 2015 Andrew Gallant + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +License notice for amd/aocl-libm-ose +------------------------------- + +Copyright (C) 2008-2020 Advanced Micro Devices, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +License notice for code from The Practice of Programming +------------------------------- + +Copyright (C) 1999 Lucent Technologies + +Excerpted from 'The Practice of Programming +by Brian W. Kernighan and Rob Pike + +You may use this code for any purpose, as long as you leave the copyright notice and book citation attached. + +License notice for fmtlib/fmt +------------------------------- + +Formatting library for C++ + +Copyright (c) 2012 - present, Victor Zverovich + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +License notice for gRPC +=================================== + +Copyright 2019 The gRPC Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +License notice for m-ou-se/floatconv +------------------------------- + +Copyright (c) 2020 Mara Bos +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +License notice for MsQuic +-------------------------------------- + +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. + +Available at +https://github.com/microsoft/msquic/blob/main/LICENSE + +License notice for vectorized hex parsing +-------------------------------------------------------- + +Copyright (c) 2022, Geoff Langdale +Copyright (c) 2022, Wojciech Mula +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +- Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Notice for Euclidean Affine Functions and Applications to Calendar +Algorithms +------------------------------- + +Aspects of Date/Time processing based on algorithm described in "Euclidean Affine Functions and Applications to Calendar +Algorithms", Cassio Neri and Lorenz Schneider. https://arxiv.org/pdf/2102.06959.pdf diff --git a/src/installer/pkg/sfx/Microsoft.NETCore.App/Directory.Build.props b/src/installer/pkg/sfx/Microsoft.NETCore.App/Directory.Build.props index 13cf1fc895c8ac..c44ffae5a7caff 100644 --- a/src/installer/pkg/sfx/Microsoft.NETCore.App/Directory.Build.props +++ b/src/installer/pkg/sfx/Microsoft.NETCore.App/Directory.Build.props @@ -155,9 +155,6 @@ - - - @@ -197,6 +194,8 @@ + + diff --git a/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj b/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj index ba597425bcfc1d..d0e9b6e16ff1a1 100644 --- a/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj +++ b/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj @@ -7,11 +7,14 @@ dotnet-runtime-internal dotnet-runtime dotnet-runtime-internal - $(SharedFrameworkName).PGO true + false dotnet-runtime-symbols NetCore.SharedFramework true + + true diff --git a/src/installer/tests/HostActivation.Tests/DotnetArgValidation.cs b/src/installer/tests/HostActivation.Tests/DotnetArgValidation.cs index 20e7e97a628e96..12501889471756 100644 --- a/src/installer/tests/HostActivation.Tests/DotnetArgValidation.cs +++ b/src/installer/tests/HostActivation.Tests/DotnetArgValidation.cs @@ -82,6 +82,35 @@ public void InvalidFileOrCommand_NoSDK_ListsPossibleIssues() .And.FindAnySdk(false); } + [Fact] + public void DotNetInfo_NoSDK() + { + sharedTestState.BuiltDotNet.Exec("--info") + .CaptureStdOut() + .CaptureStdErr() + .Execute() + .Should().Pass() + .And.HaveStdOutMatching($@"Architecture:\s*{RepoDirectoriesProvider.Default.BuildArchitecture}") + .And.HaveStdOutMatching($@"RID:\s*{RepoDirectoriesProvider.Default.BuildRID}"); + } + + [Fact] + public void DotNetInfo_WithSDK() + { + DotNetCli dotnet = new DotNetBuilder(sharedTestState.BaseDirectory.Location, RepoDirectoriesProvider.Default.BuiltDotnet, "withSdk") + .AddMicrosoftNETCoreAppFrameworkMockHostPolicy("1.0.0") + .AddMockSDK("1.0.0", "1.0.0") + .Build(); + + dotnet.Exec("--info") + .WorkingDirectory(sharedTestState.BaseDirectory.Location) + .CaptureStdOut() + .CaptureStdErr() + .Execute() + .Should().Pass() + .And.NotHaveStdOutMatching($@"RID:\s*{RepoDirectoriesProvider.Default.BuildRID}"); + } + // Return a non-existent path that contains a mix of / and \ private string GetNonexistentAndUnnormalizedPath() { diff --git a/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs b/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs index 232410be8f2596..b4d038b99b5ce3 100644 --- a/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs +++ b/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs @@ -11,6 +11,7 @@ using Microsoft.NET.HostModel.AppHost; using Microsoft.DotNet.CoreSetup.Test; using System.Diagnostics; +using System.Reflection.PortableExecutable; namespace Microsoft.NET.HostModel.Tests { @@ -111,7 +112,9 @@ public void ItCanSetWindowsGUISubsystem() BitConverter .ToUInt16(File.ReadAllBytes(destinationFilePath), SubsystemOffset) .Should() - .Be(2); + .Be((ushort)Subsystem.WindowsGui); + + Assert.Equal((ushort)Subsystem.WindowsGui, PEUtils.GetWindowsGraphicalUserInterfaceBit(destinationFilePath)); } } @@ -153,6 +156,7 @@ public void ItFailsToSetGUISubsystemWithWrongDefault() string destinationFilePath = Path.Combine(testDirectory.Path, "DestinationAppHost.exe.mock"); string appBinaryFilePath = "Test/App/Binary/Path.dll"; + Assert.Equal(42, PEUtils.GetWindowsGraphicalUserInterfaceBit(sourceAppHostMock)); Assert.Throws(() => HostWriter.CreateAppHost( sourceAppHostMock, diff --git a/src/installer/tests/TestUtils/Assertions/CommandResultAssertions.cs b/src/installer/tests/TestUtils/Assertions/CommandResultAssertions.cs index f939ba1388336b..f312913aff8a05 100644 --- a/src/installer/tests/TestUtils/Assertions/CommandResultAssertions.cs +++ b/src/installer/tests/TestUtils/Assertions/CommandResultAssertions.cs @@ -74,6 +74,13 @@ public AndConstraint HaveStdOutMatching(string pattern, return new AndConstraint(this); } + public AndConstraint NotHaveStdOutMatching(string pattern, RegexOptions options = RegexOptions.None) + { + Execute.Assertion.ForCondition(!Regex.IsMatch(Result.StdOut, pattern, options)) + .FailWith($"The command output matched a pattern is should not have matched. Pattern: '{pattern}'{GetDiagnosticsInfo()}"); + return new AndConstraint(this); + } + public AndConstraint HaveStdErr() { Execute.Assertion.ForCondition(!string.IsNullOrEmpty(Result.StdErr)) diff --git a/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs b/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs index 0f7b1763d58509..6725436b4f0945 100644 --- a/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs +++ b/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs @@ -24,5 +24,8 @@ internal static partial class Globalization [LibraryImport(Libraries.GlobalizationNative, EntryPoint = "GlobalizationNative_GetLocaleTimeFormatNative", StringMarshalling = StringMarshalling.Utf8)] internal static partial string GetLocaleTimeFormatNative(string localeName, [MarshalAs(UnmanagedType.Bool)] bool shortFormat); + + [LibraryImport(Libraries.GlobalizationNative, EntryPoint = "GlobalizationNative_GetLocalesNative", StringMarshalling = StringMarshalling.Utf16)] + internal static partial int GetLocalesNative([Out] char[]? value, int valueLength); } } diff --git a/src/libraries/Common/src/Roslyn/DiagnosticDescriptorHelper.cs b/src/libraries/Common/src/Roslyn/DiagnosticDescriptorHelper.cs index d11fdd8b5aaa03..8c5783ff9bd0b4 100644 --- a/src/libraries/Common/src/Roslyn/DiagnosticDescriptorHelper.cs +++ b/src/libraries/Common/src/Roslyn/DiagnosticDescriptorHelper.cs @@ -17,7 +17,7 @@ public static DiagnosticDescriptor Create( LocalizableString? description = null, params string[] customTags) { - string helpLink = $"https://learn.microsoft.com/dotnet/fundamentals/syslib-diagnostics/{id.ToLowerInvariant()}.md"; + string helpLink = $"https://learn.microsoft.com/dotnet/fundamentals/syslib-diagnostics/{id.ToLowerInvariant()}"; return new DiagnosticDescriptor(id, title, messageFormat, category, defaultSeverity, isEnabledByDefault, description, helpLink, customTags); } diff --git a/src/libraries/Common/src/SourceGenerators/DiagnosticInfo.cs b/src/libraries/Common/src/SourceGenerators/DiagnosticInfo.cs new file mode 100644 index 00000000000000..74f44f99c62baa --- /dev/null +++ b/src/libraries/Common/src/SourceGenerators/DiagnosticInfo.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Numerics.Hashing; +using Microsoft.CodeAnalysis; + +namespace SourceGenerators; + +/// +/// Descriptor for diagnostic instances using structural equality comparison. +/// Provides a work-around for https://github.com/dotnet/roslyn/issues/68291. +/// +internal readonly struct DiagnosticInfo : IEquatable +{ + public DiagnosticDescriptor Descriptor { get; private init; } + public object?[] MessageArgs { get; private init; } + public Location? Location { get; private init; } + + public static DiagnosticInfo Create(DiagnosticDescriptor descriptor, Location? location, object?[]? messageArgs) + { + Location? trimmedLocation = location is null ? null : GetTrimmedLocation(location); + + return new DiagnosticInfo + { + Descriptor = descriptor, + Location = trimmedLocation, + MessageArgs = messageArgs ?? Array.Empty() + }; + + // Creates a copy of the Location instance that does not capture a reference to Compilation. + static Location GetTrimmedLocation(Location location) + => Location.Create(location.SourceTree?.FilePath ?? "", location.SourceSpan, location.GetLineSpan().Span); + } + + public Diagnostic CreateDiagnostic() + => Diagnostic.Create(Descriptor, Location, MessageArgs); + + public override readonly bool Equals(object? obj) => obj is DiagnosticInfo info && Equals(info); + + public readonly bool Equals(DiagnosticInfo other) + { + return Descriptor.Equals(other.Descriptor) && + MessageArgs.SequenceEqual(other.MessageArgs) && + Location == other.Location; + } + + public override readonly int GetHashCode() + { + int hashCode = Descriptor.GetHashCode(); + foreach (object? messageArg in MessageArgs) + { + hashCode = HashHelpers.Combine(hashCode, messageArg?.GetHashCode() ?? 0); + } + + hashCode = HashHelpers.Combine(hashCode, Location?.GetHashCode() ?? 0); + return hashCode; + } +} diff --git a/src/libraries/System.Text.Json/gen/Helpers/ImmutableEquatableArray.cs b/src/libraries/Common/src/SourceGenerators/ImmutableEquatableArray.cs similarity index 85% rename from src/libraries/System.Text.Json/gen/Helpers/ImmutableEquatableArray.cs rename to src/libraries/Common/src/SourceGenerators/ImmutableEquatableArray.cs index ac3aa804fdd9dc..47fdde1751882a 100644 --- a/src/libraries/System.Text.Json/gen/Helpers/ImmutableEquatableArray.cs +++ b/src/libraries/Common/src/SourceGenerators/ImmutableEquatableArray.cs @@ -1,12 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections; using System.Collections.Generic; using System.Linq; using System.Numerics.Hashing; -namespace System.Text.Json.SourceGeneration +namespace SourceGenerators { /// /// Provides an immutable list implementation which implements sequence equality. @@ -72,15 +73,9 @@ public bool MoveNext() } } - public static class ImmutableEquatableArray + internal static class ImmutableEquatableArray { - public static ImmutableEquatableArray Empty() where T : IEquatable - => ImmutableEquatableArray.Empty; - public static ImmutableEquatableArray ToImmutableEquatableArray(this IEnumerable values) where T : IEquatable => new(values); - - public static ImmutableEquatableArray Create(params T[] values) where T : IEquatable - => values is { Length: > 0 } ? new(values) : ImmutableEquatableArray.Empty; } } diff --git a/src/libraries/Common/src/SourceGenerators/TypeModelHelper.cs b/src/libraries/Common/src/SourceGenerators/TypeModelHelper.cs new file mode 100644 index 00000000000000..7a3a3e98fd7fde --- /dev/null +++ b/src/libraries/Common/src/SourceGenerators/TypeModelHelper.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; + +namespace SourceGenerators +{ + internal static class TypeModelHelper + { + public static List? GetAllTypeArgumentsInScope(this INamedTypeSymbol type) + { + if (!type.IsGenericType) + { + return null; + } + + List? args = null; + TraverseContainingTypes(type); + return args; + + void TraverseContainingTypes(INamedTypeSymbol current) + { + if (current.ContainingType is INamedTypeSymbol parent) + { + TraverseContainingTypes(parent); + } + + if (!current.TypeArguments.IsEmpty) + { + (args ??= new()).AddRange(current.TypeArguments); + } + } + } + + public static string GetFullyQualifiedName(this ITypeSymbol type) => type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + } +} diff --git a/src/libraries/System.Text.Json/gen/Model/TypeRef.cs b/src/libraries/Common/src/SourceGenerators/TypeRef.cs similarity index 96% rename from src/libraries/System.Text.Json/gen/Model/TypeRef.cs rename to src/libraries/Common/src/SourceGenerators/TypeRef.cs index 050aba0cda658c..cfbf33ed741366 100644 --- a/src/libraries/System.Text.Json/gen/Model/TypeRef.cs +++ b/src/libraries/Common/src/SourceGenerators/TypeRef.cs @@ -1,10 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Diagnostics; using Microsoft.CodeAnalysis; -namespace System.Text.Json.SourceGeneration +namespace SourceGenerators { /// /// An equatable value representing type identity. diff --git a/src/libraries/Common/src/System/Net/Http/X509ResourceClient.cs b/src/libraries/Common/src/System/Net/Http/X509ResourceClient.cs index 57538487800e1e..74e84bb750386a 100644 --- a/src/libraries/Common/src/System/Net/Http/X509ResourceClient.cs +++ b/src/libraries/Common/src/System/Net/Http/X509ResourceClient.cs @@ -12,6 +12,9 @@ namespace System.Net.Http { internal static partial class X509ResourceClient { + private const long DefaultAiaDownloadLimit = 100 * 1024 * 1024; + private static long AiaDownloadLimit { get; } = GetValue("System.Security.Cryptography.AiaDownloadLimit", DefaultAiaDownloadLimit); + private static readonly Func>? s_downloadBytes = CreateDownloadBytesFunc(); static partial void ReportNoClient(); @@ -111,6 +114,7 @@ internal static partial class X509ResourceClient ConstructorInfo? httpRequestMessageCtor = httpRequestMessageType.GetConstructor(Type.EmptyTypes); MethodInfo? sendMethod = httpClientType.GetMethod("Send", new Type[] { httpRequestMessageType, typeof(CancellationToken) }); MethodInfo? sendAsyncMethod = httpClientType.GetMethod("SendAsync", new Type[] { httpRequestMessageType, typeof(CancellationToken) }); + PropertyInfo? maxResponseContentBufferSizeProp = httpClientType.GetProperty("MaxResponseContentBufferSize"); PropertyInfo? responseContentProp = httpResponseMessageType.GetProperty("Content"); PropertyInfo? responseStatusCodeProp = httpResponseMessageType.GetProperty("StatusCode"); PropertyInfo? responseHeadersProp = httpResponseMessageType.GetProperty("Headers"); @@ -121,7 +125,7 @@ internal static partial class X509ResourceClient if (socketsHttpHandlerCtor == null || pooledConnectionIdleTimeoutProp == null || allowAutoRedirectProp == null || httpClientCtor == null || requestUriProp == null || httpRequestMessageCtor == null || - sendMethod == null || sendAsyncMethod == null || + sendMethod == null || sendAsyncMethod == null || maxResponseContentBufferSizeProp == null || responseContentProp == null || responseStatusCodeProp == null || responseHeadersProp == null || responseHeadersLocationProp == null || readAsStreamMethod == null || taskOfHttpResponseMessageResultProp == null) @@ -145,6 +149,7 @@ internal static partial class X509ResourceClient pooledConnectionIdleTimeoutProp.SetValue(socketsHttpHandler, TimeSpan.FromSeconds(PooledConnectionIdleTimeoutSeconds)); allowAutoRedirectProp.SetValue(socketsHttpHandler, false); object? httpClient = httpClientCtor.Invoke(new object?[] { socketsHttpHandler }); + maxResponseContentBufferSizeProp.SetValue(httpClient, AiaDownloadLimit); return async (string uriString, CancellationToken cancellationToken, bool async) => { @@ -302,5 +307,24 @@ private static bool IsAllowedScheme(string scheme) { return string.Equals(scheme, "http", StringComparison.OrdinalIgnoreCase); } + + private static long GetValue(string name, long defaultValue) + { + object? data = AppContext.GetData(name); + + if (data is null) + { + return defaultValue; + } + + try + { + return Convert.ToInt64(data); + } + catch + { + return defaultValue; + } + } } } diff --git a/src/libraries/Common/src/System/Net/TlsStream.cs b/src/libraries/Common/src/System/Net/TlsStream.cs index 047253159127ea..503253099aac17 100644 --- a/src/libraries/Common/src/System/Net/TlsStream.cs +++ b/src/libraries/Common/src/System/Net/TlsStream.cs @@ -5,6 +5,8 @@ using System.Net.Sockets; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; namespace System.Net { @@ -46,6 +48,11 @@ public void EndAuthenticateAsClient(IAsyncResult asyncResult) _sslStream.EndAuthenticateAsClient(asyncResult); } + public override void Write(byte[] buffer, int offset, int size) + { + _sslStream.Write(buffer, offset, size); + } + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback? callback, object? state) { return _sslStream.BeginWrite(buffer, offset, size, callback, state); @@ -56,9 +63,9 @@ public override void EndWrite(IAsyncResult result) _sslStream.EndWrite(result); } - public override void Write(byte[] buffer, int offset, int size) + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - _sslStream.Write(buffer, offset, size); + return _sslStream.WriteAsync(buffer, offset, count, cancellationToken); } public override int Read(byte[] buffer, int offset, int size) @@ -66,6 +73,11 @@ public override int Read(byte[] buffer, int offset, int size) return _sslStream.Read(buffer, offset, size); } + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _sslStream.ReadAsync(buffer, offset, count, cancellationToken); + } + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) { return _sslStream.BeginRead(buffer, offset, count, callback, state); diff --git a/src/libraries/Common/tests/SourceGenerators/GeneratorTestHelpers.cs b/src/libraries/Common/tests/SourceGenerators/GeneratorTestHelpers.cs new file mode 100644 index 00000000000000..d62a3c788e73dc --- /dev/null +++ b/src/libraries/Common/tests/SourceGenerators/GeneratorTestHelpers.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Xunit; + +namespace SourceGenerators.Tests +{ + internal static class GeneratorTestHelpers + { + /// + /// Asserts for structural equality, returning a path to the mismatching data when not equal. + /// + public static void AssertStructurallyEqual(T expected, T actual) + { + CheckAreEqualCore(expected, actual, new()); + static void CheckAreEqualCore(object expected, object actual, Stack path) + { + if (expected is null || actual is null) + { + if (expected is not null || actual is not null) + { + FailNotEqual(); + } + + return; + } + + Type type = expected.GetType(); + if (type != actual.GetType()) + { + FailNotEqual(); + return; + } + + if (expected is IEnumerable leftCollection) + { + if (actual is not IEnumerable rightCollection) + { + FailNotEqual(); + return; + } + + object?[] expectedValues = leftCollection.Cast().ToArray(); + object?[] actualValues = rightCollection.Cast().ToArray(); + + for (int i = 0; i < Math.Max(expectedValues.Length, actualValues.Length); i++) + { + object? expectedElement = i < expectedValues.Length ? expectedValues[i] : ""; + object? actualElement = i < actualValues.Length ? actualValues[i] : ""; + + path.Push($"[{i}]"); + CheckAreEqualCore(expectedElement, actualElement, path); + path.Pop(); + } + } + + if (type.GetProperty("EqualityContract", BindingFlags.Instance | BindingFlags.NonPublic, null, returnType: typeof(Type), types: Array.Empty(), null) != null) + { + // Type is a C# record, run pointwise equality comparison. + foreach (PropertyInfo property in type.GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + path.Push("." + property.Name); + CheckAreEqualCore(property.GetValue(expected), property.GetValue(actual), path); + path.Pop(); + } + + return; + } + + if (!expected.Equals(actual)) + { + FailNotEqual(); + } + + void FailNotEqual() => Assert.Fail($"Value not equal in ${string.Join("", path.Reverse())}: expected {expected}, but was {actual}."); + } + } + } +} diff --git a/src/libraries/Common/tests/SourceGenerators/RoslynTestUtils.cs b/src/libraries/Common/tests/SourceGenerators/RoslynTestUtils.cs index 2b4568a79eaba0..e67289721d8af9 100644 --- a/src/libraries/Common/tests/SourceGenerators/RoslynTestUtils.cs +++ b/src/libraries/Common/tests/SourceGenerators/RoslynTestUtils.cs @@ -91,7 +91,7 @@ public static async Task AssertNoDiagnostic(this Project proj, params string[] i } } - private static Project WithDocuments(this Project project, IEnumerable sources, IEnumerable? sourceNames = null) + public static Project WithDocuments(this Project project, IEnumerable sources, IEnumerable? sourceNames = null) { int count = 0; Project result = project; diff --git a/src/libraries/Common/tests/System/FunctionPointerTests.cs b/src/libraries/Common/tests/System/FunctionPointerTests.cs index 925c8683f8f7f4..8a344e01e78955 100644 --- a/src/libraries/Common/tests/System/FunctionPointerTests.cs +++ b/src/libraries/Common/tests/System/FunctionPointerTests.cs @@ -173,6 +173,22 @@ public static unsafe void RequiredModifiers() Assert.Equal(typeof(Runtime.InteropServices.OutAttribute).Project(), parameters[1].GetRequiredCustomModifiers()[0]); } + [Fact] + public static unsafe void GenericFunctionPointer() + { + Type t = typeof(FunctionPointerHolder).Project(); + + MethodInfo m1 = t.GetMethod(nameof(FunctionPointerHolder.GenericReturnValue), Bindings); + Type fcnPtr1 = m1.ReturnType; + Assert.True(fcnPtr1.IsFunctionPointer); + Assert.True(fcnPtr1.ContainsGenericParameters); + + MethodInfo m2 = t.GetMethod(nameof(FunctionPointerHolder.GenericArgument), Bindings); + Type fcnPtr2 = m2.GetParameters()[1].ParameterType; + Assert.True(fcnPtr2.IsFunctionPointer); + Assert.True(fcnPtr2.ContainsGenericParameters); + } + [Theory] [InlineData(nameof(FunctionPointerHolder.MethodReturnValue1), "MethodReturnValue1()", @@ -278,6 +294,9 @@ private unsafe class FunctionPointerHolder public delegate* unmanaged[Stdcall, MemberFunction] SeveralArguments() => default; public delegate* RequiredModifiers() => default; + public delegate* GenericReturnValue() => default; + public bool GenericArgument(int x, delegate* fptr) => default; + public class MyClass { } public struct MyStruct { } } diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs index ec5c7a022c65a6..9d6fef5fb3a720 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs @@ -108,11 +108,6 @@ public static int GetRequestId(QuicStream stream) return checked((int)stream.Id + 1); } - public Http3LoopbackStream GetOpenRequest(int requestId = 0) - { - return requestId == 0 ? _currentStream : _openStreams[requestId - 1]; - } - public override Task InitializeConnectionAsync() { throw new NotImplementedException(); @@ -195,6 +190,17 @@ public async Task EstablishControlStreamAsync(SettingsEntry[] settingsEntries) await _outboundControlStream.SendSettingsFrameAsync(settingsEntries); } + public async Task DisposeCurrentStream() + { + Assert.NotNull(_currentStream); + Assert.True(_currentStreamId >= 0); + + await _currentStream.DisposeAsync().ConfigureAwait(false); + _openStreams.Remove((int)_currentStreamId); + _currentStream = null; + _currentStreamId = -4; + } + public override async Task ReadRequestBodyAsync() { return await _currentStream.ReadRequestBodyAsync().ConfigureAwait(false); @@ -206,24 +212,32 @@ public override async Task ReadRequestDataAsync(bool readBody = return await stream.ReadRequestDataAsync(readBody).ConfigureAwait(false); } - public override Task SendResponseAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "", bool isFinal = true) + public override async Task SendResponseAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "", bool isFinal = true) { - return GetOpenRequest().SendResponseAsync(statusCode, headers, content, isFinal); + await _currentStream.SendResponseAsync(statusCode, headers, content, isFinal); + if (isFinal) + { + await DisposeCurrentStream().ConfigureAwait(false); + } } - public override Task SendResponseBodyAsync(byte[] content, bool isFinal = true) + public override async Task SendResponseBodyAsync(byte[] content, bool isFinal = true) { - return GetOpenRequest().SendResponseBodyAsync(content, isFinal); + await _currentStream.SendResponseBodyAsync(content, isFinal); + if (isFinal) + { + await DisposeCurrentStream().ConfigureAwait(false); + } } public override Task SendResponseHeadersAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null) { - return GetOpenRequest().SendResponseHeadersAsync(statusCode, headers); + return _currentStream.SendResponseHeadersAsync(statusCode, headers); } public override Task SendPartialResponseHeadersAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null) { - return GetOpenRequest().SendPartialResponseHeadersAsync(statusCode, headers); + return _currentStream.SendPartialResponseHeadersAsync(statusCode, headers); } public override async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") @@ -310,7 +324,7 @@ public async Task WaitForClientDisconnectAsync(bool refuseNewRequests = true) public override async Task WaitForCancellationAsync(bool ignoreIncomingData = true) { - await GetOpenRequest().WaitForCancellationAsync(ignoreIncomingData).ConfigureAwait(false); + await _currentStream.WaitForCancellationAsync(ignoreIncomingData).ConfigureAwait(false); } public override Task WaitForCloseAsync(CancellationToken cancellationToken) diff --git a/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs b/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs index bf676b7b6f26be..230e4cb4276a42 100644 --- a/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs +++ b/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -15,6 +16,26 @@ public static class AssertExtensions { private static bool IsNetFramework => RuntimeInformation.FrameworkDescription.StartsWith(".NET Framework"); + + /// + /// Helper for AOT tests that verifies that the compile succeeds, or throws PlatformNotSupported + /// when AOT is enabled. + /// + public static void ThrowsOnAot(Action action) + where T : Exception + { +#if NETCOREAPP // Dynamic code is always supported on .NET Framework + if (!RuntimeFeature.IsDynamicCodeSupported) + { + Assert.Throws(action); + } + else +#endif + { + action(); + } + } + public static void Throws(Action action, string expectedMessage) where T : Exception { diff --git a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs index b127a2a72491e9..c8197b0055092f 100644 --- a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs +++ b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Unix.cs @@ -30,6 +30,8 @@ public UnixImplementation(int elementCount) public override bool IsReadonly => false; + public override int Length => _elementCount; + public override Memory Memory => _memoryManager.Memory; public override Span Span @@ -83,10 +85,7 @@ protected override void Dispose(bool disposing) // no-op; the handle will be disposed separately } - public override Span GetSpan() - { - throw new NotImplementedException(); - } + public override Span GetSpan() => _impl.Span; public override MemoryHandle Pin(int elementIndex) { diff --git a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs index 7b1cbbafc72dbc..96f40d61492e52 100644 --- a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs +++ b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.Windows.cs @@ -81,6 +81,8 @@ internal WindowsImplementation(VirtualAllocHandle handle, int byteOffsetIntoHand public override bool IsReadonly => (Protection != VirtualAllocProtection.PAGE_READWRITE); + public override int Length => _elementCount; + internal VirtualAllocProtection Protection { get @@ -189,10 +191,7 @@ protected override void Dispose(bool disposing) // no-op; the handle will be disposed separately } - public override Span GetSpan() - { - throw new NotImplementedException(); - } + public override Span GetSpan() => _impl.Span; public override MemoryHandle Pin(int elementIndex) { diff --git a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.cs b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.cs index d70b30721fdbd9..81d359cc80e926 100644 --- a/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.cs +++ b/src/libraries/Common/tests/TestUtilities/System/Buffers/BoundedMemory.cs @@ -14,6 +14,9 @@ public abstract class BoundedMemory : IDisposable where T : unmanaged /// public abstract bool IsReadonly { get; } + /// Gets the length of the instance. + public abstract int Length { get; } + /// /// Gets the which represents this native memory. /// This instance must be kept alive while working with the . @@ -44,5 +47,23 @@ public abstract class BoundedMemory : IDisposable where T : unmanaged /// OS does not support marking the memory block as read+write. /// public abstract void MakeWriteable(); + + /// + /// Gets the which represents this native memory. + /// This instance must be kept alive while working with the . + /// + public static implicit operator Span(BoundedMemory boundedMemory) => boundedMemory.Span; + + /// + /// Gets the which represents this native memory. + /// This instance must be kept alive while working with the . + /// + public static implicit operator ReadOnlySpan(BoundedMemory boundedMemory) => boundedMemory.Span; + + /// + /// Gets a reference to the element at the specified index. + /// This instance must be kept alive while working with the reference. + /// + public ref T this[int index] => ref Span[index]; } } diff --git a/src/libraries/Microsoft.Bcl.AsyncInterfaces/src/PACKAGE.md b/src/libraries/Microsoft.Bcl.AsyncInterfaces/src/PACKAGE.md new file mode 100644 index 00000000000000..e0c6e8ae9adaae --- /dev/null +++ b/src/libraries/Microsoft.Bcl.AsyncInterfaces/src/PACKAGE.md @@ -0,0 +1,64 @@ +## About + +As of C# 8, the C# language has support for producing and consuming asynchronous iterators. The library types in support of those features are available in .NET Core 3.0 and newer as well as in .NET Standard 2.1. This library provides the necessary definitions of those types to support these language features on .NET Framework and on .NET Standard 2.0. This library is not necessary nor recommended when targeting versions of .NET that include the relevant support. + +## Key Features + + + +* Enables the use of C# async iterators on older .NET platforms + +## How to Use + + + +```C# +using System; +using System.Collections.Generic; +using System.Threading.Tasks; + +internal static class Program +{ + private static async Task Main() + { + Console.WriteLine("Starting..."); + await foreach (var value in GetValuesAsync()) + { + Console.WriteLine(value); + } + Console.WriteLine("Finished!"); + + static async IAsyncEnumerable GetValuesAsync() + { + for (int i = 0; i < 10; i++) + { + await Task.Delay(TimeSpan.FromSeconds(1)); + yield return i; + } + } + } +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `IAsyncEnumerable` +* `IAsyncEnumerator` +* `IAsyncDisposable` + +## Additional Documentation + + + +* [C# Feature Specification](https://learn.microsoft.com/dotnet/csharp/language-reference/proposals/csharp-8.0/async-streams) +* [Walkthrough article](https://learn.microsoft.com/archive/msdn-magazine/2019/november/csharp-iterating-with-async-enumerables-in-csharp-8) + +## Feedback & Contributing + + + +Microsoft.Bcl.AsyncInterfaces is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Bcl.Cryptography/src/PACKAGE.md b/src/libraries/Microsoft.Bcl.Cryptography/src/PACKAGE.md new file mode 100644 index 00000000000000..215d29e162c4b6 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Cryptography/src/PACKAGE.md @@ -0,0 +1,39 @@ +## About + +This library provides some cryptographic types and functionality for .NET Standard and .NET Framework. This library is not necessary nor recommended when targeting versions of .NET that include the relevant support. + +## Key Features + +* Enables the use of some cryptographic functionality on older .NET platforms. + +## How to Use + +This package should only be used by platforms where the desired functionality is not built-in. + +```C# +using System.Security.Cryptography; + +internal static class Program +{ + private static void Main() + { + byte[] key = LoadKey(); + SP800108HmacCounterKdf kbkdf = new(key, HashAlgorithmName.SHA256); + byte[] derivedKey = kbkdf.DeriveKey("label"u8, "context"u8, derivedKeyLengthInBytes: 32); + } +} +``` + +## Main Types + +The main types provided by this library are: + +* `System.Security.Cryptography.SP800108HmacCounterKdf` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/System.Security.Cryptography) + +## Feedback & Contributing + +Microsoft.Bcl.Cryptography is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Bcl.Numerics/Microsoft.Bcl.Numerics.sln b/src/libraries/Microsoft.Bcl.Numerics/Microsoft.Bcl.Numerics.sln new file mode 100644 index 00000000000000..1192785b14a9fc --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/Microsoft.Bcl.Numerics.sln @@ -0,0 +1,74 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{CAEE0409-CCC3-4EA6-AB54-177FD305D42D}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "ref\Microsoft.Bcl.Numerics.csproj", "{73E7C25C-AEBC-4F4F-B8D1-0CC49D5B92DE}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "src\Microsoft.Bcl.Numerics.csproj", "{4D4BED71-8904-4A74-88CD-63D002CCACD0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics.Tests", "tests\Microsoft.Bcl.Numerics.Tests.csproj", "{51D9518A-464D-4257-9567-3BDCFF24F3EE}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ComInterfaceGenerator", "..\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj", "{E30F71EB-6C3B-4052-84F7-36EAA178A45E}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LibraryImportGenerator", "..\System.Runtime.InteropServices\gen\LibraryImportGenerator\LibraryImportGenerator.csproj", "{0AE44453-273B-4F0E-9901-A87891A73C1B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Interop.SourceGeneration", "..\System.Runtime.InteropServices\gen\Microsoft.Interop.SourceGeneration\Microsoft.Interop.SourceGeneration.csproj", "{D0F1936C-CF7C-4448-9F90-B9DEABE89EBB}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{6614EF7F-23FC-4809-AFF5-1ADBF1B6422C}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{111B1B5B-A004-4C05-9A8C-E0931DADA5FB}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{85204CF5-0C88-4BBB-9E70-D8CCED82ED3D}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{D6A9108E-553B-445E-A037-FA4F3140A279}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {CAEE0409-CCC3-4EA6-AB54-177FD305D42D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CAEE0409-CCC3-4EA6-AB54-177FD305D42D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CAEE0409-CCC3-4EA6-AB54-177FD305D42D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CAEE0409-CCC3-4EA6-AB54-177FD305D42D}.Release|Any CPU.Build.0 = Release|Any CPU + {73E7C25C-AEBC-4F4F-B8D1-0CC49D5B92DE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {73E7C25C-AEBC-4F4F-B8D1-0CC49D5B92DE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {73E7C25C-AEBC-4F4F-B8D1-0CC49D5B92DE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {73E7C25C-AEBC-4F4F-B8D1-0CC49D5B92DE}.Release|Any CPU.Build.0 = Release|Any CPU + {4D4BED71-8904-4A74-88CD-63D002CCACD0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4D4BED71-8904-4A74-88CD-63D002CCACD0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4D4BED71-8904-4A74-88CD-63D002CCACD0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4D4BED71-8904-4A74-88CD-63D002CCACD0}.Release|Any CPU.Build.0 = Release|Any CPU + {51D9518A-464D-4257-9567-3BDCFF24F3EE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {51D9518A-464D-4257-9567-3BDCFF24F3EE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {51D9518A-464D-4257-9567-3BDCFF24F3EE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {51D9518A-464D-4257-9567-3BDCFF24F3EE}.Release|Any CPU.Build.0 = Release|Any CPU + {E30F71EB-6C3B-4052-84F7-36EAA178A45E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E30F71EB-6C3B-4052-84F7-36EAA178A45E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E30F71EB-6C3B-4052-84F7-36EAA178A45E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E30F71EB-6C3B-4052-84F7-36EAA178A45E}.Release|Any CPU.Build.0 = Release|Any CPU + {0AE44453-273B-4F0E-9901-A87891A73C1B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0AE44453-273B-4F0E-9901-A87891A73C1B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0AE44453-273B-4F0E-9901-A87891A73C1B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0AE44453-273B-4F0E-9901-A87891A73C1B}.Release|Any CPU.Build.0 = Release|Any CPU + {D0F1936C-CF7C-4448-9F90-B9DEABE89EBB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D0F1936C-CF7C-4448-9F90-B9DEABE89EBB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D0F1936C-CF7C-4448-9F90-B9DEABE89EBB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D0F1936C-CF7C-4448-9F90-B9DEABE89EBB}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {CAEE0409-CCC3-4EA6-AB54-177FD305D42D} = {6614EF7F-23FC-4809-AFF5-1ADBF1B6422C} + {51D9518A-464D-4257-9567-3BDCFF24F3EE} = {6614EF7F-23FC-4809-AFF5-1ADBF1B6422C} + {73E7C25C-AEBC-4F4F-B8D1-0CC49D5B92DE} = {111B1B5B-A004-4C05-9A8C-E0931DADA5FB} + {4D4BED71-8904-4A74-88CD-63D002CCACD0} = {85204CF5-0C88-4BBB-9E70-D8CCED82ED3D} + {E30F71EB-6C3B-4052-84F7-36EAA178A45E} = {D6A9108E-553B-445E-A037-FA4F3140A279} + {0AE44453-273B-4F0E-9901-A87891A73C1B} = {D6A9108E-553B-445E-A037-FA4F3140A279} + {D0F1936C-CF7C-4448-9F90-B9DEABE89EBB} = {D6A9108E-553B-445E-A037-FA4F3140A279} + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {A835CEDB-E9E2-49EE-8499-BD7FDD984E53} + EndGlobalSection +EndGlobal diff --git a/src/coreclr/nativeaot/Runtime/GCMemoryHelpers.h b/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.Forwards.cs similarity index 52% rename from src/coreclr/nativeaot/Runtime/GCMemoryHelpers.h rename to src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.Forwards.cs index 127b4d772040ab..641ce5525675f4 100644 --- a/src/coreclr/nativeaot/Runtime/GCMemoryHelpers.h +++ b/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.Forwards.cs @@ -1,8 +1,4 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -// -// Unmanaged GC memory helpers -// - -EXTERN_C void REDHAWK_CALLCONV RhpBulkWriteBarrier(void* pMemStart, uint32_t cbMemSize); +[assembly: System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.MathF))] diff --git a/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.cs b/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.cs new file mode 100644 index 00000000000000..d952814464f676 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// ------------------------------------------------------------------------------ +// Changes to this file must follow the https://aka.ms/api-review process. +// ------------------------------------------------------------------------------ + +namespace System +{ + public static partial class MathF + { + public const float E = 2.7182817f; + public const float PI = 3.1415927f; + public static float Abs(float x) { throw null; } + public static float Acos(float x) { throw null; } + public static float Asin(float x) { throw null; } + public static float Atan(float x) { throw null; } + public static float Atan2(float y, float x) { throw null; } + public static float Ceiling(float x) { throw null; } + public static float Cos(float x) { throw null; } + public static float Cosh(float x) { throw null; } + public static float Exp(float x) { throw null; } + public static float Floor(float x) { throw null; } + public static float IEEERemainder(float x, float y) { throw null; } + public static float Log(float x) { throw null; } + public static float Log(float x, float y) { throw null; } + public static float Log10(float x) { throw null; } + public static float Max(float x, float y) { throw null; } + public static float Min(float x, float y) { throw null; } + public static float Pow(float x, float y) { throw null; } + public static float Round(float x) { throw null; } + public static float Round(float x, int digits) { throw null; } + public static float Round(float x, int digits, System.MidpointRounding mode) { throw null; } + public static float Round(float x, System.MidpointRounding mode) { throw null; } + public static int Sign(float x) { throw null; } + public static float Sin(float x) { throw null; } + public static float Sinh(float x) { throw null; } + public static float Sqrt(float x) { throw null; } + public static float Tan(float x) { throw null; } + public static float Tanh(float x) { throw null; } + public static float Truncate(float x) { throw null; } + } +} diff --git a/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.csproj b/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.csproj new file mode 100644 index 00000000000000..36d3ac9605c195 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/ref/Microsoft.Bcl.Numerics.csproj @@ -0,0 +1,12 @@ + + + + netstandard2.0;$(NetFrameworkMinimum);netstandard2.1 + + + + + + + + diff --git a/src/libraries/Microsoft.Bcl.Numerics/src/Microsoft.Bcl.Numerics.csproj b/src/libraries/Microsoft.Bcl.Numerics/src/Microsoft.Bcl.Numerics.csproj new file mode 100644 index 00000000000000..5d6b96dc6a29a8 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/src/Microsoft.Bcl.Numerics.csproj @@ -0,0 +1,24 @@ + + + + netstandard2.0;$(NetFrameworkMinimum);netstandard2.1 + true + true + + + Provides the System.MathF for .NET Standard 2.0 + + true + + + + + + + + + + + diff --git a/src/libraries/Microsoft.Bcl.Numerics/src/PACKAGE.md b/src/libraries/Microsoft.Bcl.Numerics/src/PACKAGE.md new file mode 100644 index 00000000000000..b19975c09f41f4 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/src/PACKAGE.md @@ -0,0 +1,43 @@ +## About + +As of .NET Core 2.0 and .NET Standard 2.1, the C# language has support for math (System.MathF) functions with floats. This library provides the necessary definitions of those types to support these language features on .NET Framework and on .NET Standard 2.0. This library is not necessary nor recommended when targeting versions of .NET that include the relevant support. + +## Key Features + + + +* Enables the use of MathF on older .NET platforms + +## How to Use + +```C# +using System; + +internal static class Program +{ + private static async Task Main() + { + Console.WriteLine(MathF.Max(1f, 5f)); // returns 5f + } +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.MathF` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.mathf) + +## Feedback & Contributing + + + +Microsoft.Bcl.Numerics is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Bcl.Numerics/src/System/MathF.cs b/src/libraries/Microsoft.Bcl.Numerics/src/System/MathF.cs new file mode 100644 index 00000000000000..128a3174ad0c4e --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/src/System/MathF.cs @@ -0,0 +1,332 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +/*============================================================ +** +** Purpose: Some single-precision floating-point math operations +** +===========================================================*/ + +//This class contains only static members and doesn't require serialization. + +//For most of this implementation for .NET Framework we just defer to System.Math and do a cast internally from single to double. +//We do this because it safer and less likely to break people since that is what they are alrady doing. Also, adding in the +//extra pinvokes needed to not do this route would probably incur an extra overhead that would be undersired. + +//For any version of .NET Core this just forwards directly to the MathF implementation inside the runtime. + +//There are a few cases where .NET Framework handles things differently than .NET Core does. For example, it returns -0 and +0 +//when using things like Min/Max, and they count as different values from each other. This is fixed in .NET Core, but since its +//inherent in .NET Framework we decided to leave that behavior as is for this BCL. + +using System.Diagnostics.Contracts; + +namespace System +{ + /// + /// Provides constants and static methods for trigonometric, logarithmic, and other common mathematical functions. + /// + public static class MathF + { + /// + /// Represents the ratio of the circumference of a circle to its diameter, specified by the constant, p. + /// + public const float PI = 3.14159265f; + + /// + /// Represents the natural logarithmic base, specified by the constant, e. + /// + public const float E = 2.71828183f; + + private static float NegativeZero = Int32BitsToSingle(unchecked((int)0x80000000)); + + private static unsafe float Int32BitsToSingle(int value) + { + return *((float*)&value); + } + + [Pure] + private static unsafe bool IsNegative(float f) + { + return (*(uint*)(&f) & 0x80000000) == 0x80000000; + } + + /// + /// Returns the absolute value of a single-precision floating-point number. + /// + /// The number to take the absolute value of. + /// The absolute value of + public static float Abs(float x) => Math.Abs(x); + + /// + /// Returns the angle whose cosine is the specified number. + /// + /// The number to take the acos of. + /// The acos of + public static float Acos(float x) => (float)Math.Acos(x); + + /// + /// Returns the angle whose sine is the specified number. + /// + /// The number to take the asin of. + /// The asin of + public static float Asin(float x) => (float)Math.Asin(x); + + /// + /// Returns the angle whose tangent is the specified number. + /// + /// The number to take the atan of. + /// The atan of + public static float Atan(float x) => (float)Math.Atan(x); + + /// + /// Returns the angle whose tangent is the quotient of two specified numbers. + /// + /// The first number. + /// The second number. + /// The angle whose tangent is the quotient of and + public static float Atan2(float y, float x) => (float)Math.Atan2(y, x); + + /// + /// Returns the smallest integral value that is greater than or equal to the specified single-precision floating-point number. + /// + /// The number to take the ceiling of. + /// The ceiling of + public static float Ceiling(float x) => (float)Math.Ceiling(x); + + /// + /// Returns the cosine of the specified angle. + /// + /// The angle to take the cosine of. + /// The cosine of + public static float Cos(float x) => (float)Math.Cos(x); + + /// + /// Returns the hyperbolic cosine of the specified angle. + /// + /// The angle to take the hyperbolic cosine of. + /// The hyperbolic cosine of + public static float Cosh(float x) => (float)Math.Cosh(x); + + /// + /// Returns e raised to the specified power. + /// + /// The number to raise e to. + /// e raised to the power of + public static float Exp(float x) => (float)Math.Exp(x); + + /// + /// Returns the largest integral value less than or equal to the specified single-precision floating-point number. + /// + /// The number to take the floor of. + /// The floor of + public static float Floor(float x) => (float)Math.Floor(x); + + /// + /// Returns the remainder resulting from the division of a specified number by another specified number. + /// + /// The numerator + /// The denominator + /// The result of dividing by + public static float IEEERemainder(float x, float y) + { + if (float.IsNaN(x)) + { + return x; // IEEE 754-2008: NaN payload must be preserved + } + + if (float.IsNaN(y)) + { + return y; // IEEE 754-2008: NaN payload must be preserved + } + + var regularMod = x % y; + + if (float.IsNaN(regularMod)) + { + return float.NaN; + } + + if ((regularMod == 0) && IsNegative(x)) + { + return NegativeZero; + } + + var alternativeResult = (regularMod - (Abs(y) * Sign(x))); + + if (Abs(alternativeResult) == Abs(regularMod)) + { + var divisionResult = x / y; + var roundedResult = Round(divisionResult); + + if (Abs(roundedResult) > Abs(divisionResult)) + { + return alternativeResult; + } + else + { + return regularMod; + } + } + + if (Abs(alternativeResult) < Abs(regularMod)) + { + return alternativeResult; + } + else + { + return regularMod; + } + } + + /// + /// Returns the natural (base e) logarithm of a specified number. + /// + /// The number to take the natural log of. + /// The natural log of + public static float Log(float x) => (float)Math.Log(x); + + /// + /// Returns the logarithm of a specified number in a specified base. + /// + /// The number to take the log of. + /// The base of the log + /// The log of with base + public static float Log(float x, float y) + { + if (float.IsNaN(x)) + { + return x; // IEEE 754-2008: NaN payload must be preserved + } + + if (float.IsNaN(y)) + { + return y; // IEEE 754-2008: NaN payload must be preserved + } + + if (y == 1) + { + return float.NaN; + } + + if ((x != 1) && ((y == 0) || float.IsPositiveInfinity(y))) + { + return float.NaN; + } + + return Log(x) / Log(y); + } + + /// + /// Returns the base 10 logarithm of a specified number. + /// + /// The number to take the base 10 log of. + /// The base 10 log of + public static float Log10(float x) => (float)Math.Log10(x); + + /// + /// Returns the larger of two single-precision floating-point numbers. + /// + /// The first number to compare. + /// The second number to compare. + /// The larger of and + public static float Max(float x, float y) => Math.Max(x, y); + + /// + /// Returns the smaller of two single-precision floating-point numbers. + /// + /// The first number to compare. + /// The second number to compare. + /// The smaller of and + public static float Min(float x, float y) => Math.Min(x, y); + + /// + /// Returns a specified number raised to the specified power. + /// + /// The base number. + /// The specified power. + /// raised to the power of + public static float Pow(float x, float y) => (float)Math.Pow(x, y); + + /// + /// Rounds a single-precision floating-point value to the nearest integral value, and rounds midpoint values to the nearest even number. + /// + /// The number to round. + /// The rounded representation of + public static float Round(float x) => (float)Math.Round(x); + + /// + /// Rounds a single-precision floating-point value to a specified number of fractional digits, and rounds midpoint values to the nearest even number. + /// + /// The number to round. + /// How many fractional digits to keep. + /// The rounded representation of with fractional digits + public static float Round(float x, int digits) => (float)Math.Round(x, digits); + + /// + /// Rounds a single-precision floating-point value to a specified number of fractional digits using the specified rounding convention. + /// + /// The number to round. + /// How many fractional digits to keep. + /// The rounding convention to use. + /// The rounded representation of with fractional digits using rounding convention + public static float Round(float x, int digits, MidpointRounding mode) => (float)Math.Round(x, digits, mode); + + /// + /// Rounds a single-precision floating-point value to an integer using the specified rounding convention. + /// + /// The number to round. + /// The rounding convention to use. + /// The rounded representation of using rounding convention + public static float Round(float x, MidpointRounding mode) => (float)Math.Round(x, mode); + + /// + /// Returns an integer that indicates the sign of a single-precision floating-point number. + /// + /// The number check the sign of. + /// The sign of + public static int Sign(float x) => Math.Sign(x); + + /// + /// Returns the sine of the specified angle. + /// + /// The angle to take the sine of. + /// The sine of + public static float Sin(float x) => (float)Math.Sin(x); + + /// + /// Returns the hyperbolic sine of the specified angle. + /// + /// The angle to take the hyperbolic sine of. + /// The hyperbolic sine of + public static float Sinh(float x) => (float)Math.Sinh(x); + + /// + /// Returns the square root of a specified number. + /// + /// The number to take the square root of. + /// The square root of + public static float Sqrt(float x) => (float)Math.Sqrt(x); + + /// + /// Returns the tangent of the specified angle. + /// + /// The angle to take the tangent of. + /// The tangent of + public static float Tan(float x) => (float)Math.Tan(x); + + /// + /// Returns the hyperbolic tangent of the specified angle. + /// + /// The angle to take the hyperbolic tangent of. + /// The hyperbolic tangent of + public static float Tanh(float x) => (float)Math.Tanh(x); + + /// + /// Calculates the integral part of a specified single-precision floating-point number. + /// + /// The number to truncate. + /// The truncated representation of + public static float Truncate(float x) => (float)Math.Truncate(x); + } +} diff --git a/src/libraries/Microsoft.Bcl.Numerics/src/System/Microsoft.Bcl.Numerics.Forwards.cs b/src/libraries/Microsoft.Bcl.Numerics/src/System/Microsoft.Bcl.Numerics.Forwards.cs new file mode 100644 index 00000000000000..641ce5525675f4 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/src/System/Microsoft.Bcl.Numerics.Forwards.cs @@ -0,0 +1,4 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +[assembly: System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.MathF))] diff --git a/src/libraries/Microsoft.Bcl.Numerics/tests/MathF.cs b/src/libraries/Microsoft.Bcl.Numerics/tests/MathF.cs new file mode 100644 index 00000000000000..3aa148a2161564 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/tests/MathF.cs @@ -0,0 +1,1169 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +#pragma warning disable xUnit1025 // reporting duplicate test cases due to not distinguishing 0.0 from -0.0 + +namespace System.Tests +{ + public static class MathFTests + { + // binary32 (float) has a machine epsilon of 2^-23 (approx. 1.19e-07). However, this + // is slightly too accurate when writing tests meant to run against libm implementations + // for various platforms. 2^-21 (approx. 4.76e-07) seems to be as accurate as we can get. + // + // The tests themselves will take CrossPlatformMachineEpsilon and adjust it according to the expected result + // so that the delta used for comparison will compare the most significant digits and ignore + // any digits that are outside the single precision range (6-9 digits). + + // For example, a test with an expect result in the format of 0.xxxxxxxxx will use + // CrossPlatformMachineEpsilon for the variance, while an expected result in the format of 0.0xxxxxxxxx + // will use CrossPlatformMachineEpsilon / 10 and expected result in the format of x.xxxxxx will + // use CrossPlatformMachineEpsilon * 10. + private const float CrossPlatformMachineEpsilon = 4.76837158e-07f; + + // The existing estimate functions either have an error of no more than 1.5 * 2^-12 (approx. 3.66e-04) + // or perform one Newton-Raphson iteration which, for the currently tested values, gives an error of + // no more than approx. 1.5 * 2^-7 (approx 1.17e-02). + private const double CrossPlatformMachineEpsilonForEstimates = 1.171875e-02f; + + [Fact] + public static void E() + { + Assert.Equal(2.71828183f, MathF.E); + } + + [Fact] + public static void Pi() + { + Assert.Equal(3.14159265f, MathF.PI); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(-3.14159265f, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // value: -(pi) expected: (pi) + [InlineData(-2.71828183f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // value: -(e) expected: (e) + [InlineData(-2.30258509f, 2.30258509f, CrossPlatformMachineEpsilon * 10)] // value: -(ln(10)) expected: (ln(10)) + [InlineData(-1.57079633f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // value: -(pi / 2) expected: (pi / 2) + [InlineData(-1.44269504f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // value: -(log2(e)) expected: (log2(e)) + [InlineData(-1.41421356f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // value: -(sqrt(2)) expected: (sqrt(2)) + [InlineData(-1.12837917f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // value: -(2 / sqrt(pi)) expected: (2 / sqrt(pi)) + [InlineData(-1.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.785398163f, 0.785398163f, CrossPlatformMachineEpsilon)] // value: -(pi / 4) expected: (pi / 4) + [InlineData(-0.707106781f, 0.707106781f, CrossPlatformMachineEpsilon)] // value: -(1 / sqrt(2)) expected: (1 / sqrt(2)) + [InlineData(-0.693147181f, 0.693147181f, CrossPlatformMachineEpsilon)] // value: -(ln(2)) expected: (ln(2)) + [InlineData(-0.636619772f, 0.636619772f, CrossPlatformMachineEpsilon)] // value: -(2 / pi) expected: (2 / pi) + [InlineData(-0.434294482f, 0.434294482f, CrossPlatformMachineEpsilon)] // value: -(log10(e)) expected: (log10(e)) + [InlineData(-0.318309886f, 0.318309886f, CrossPlatformMachineEpsilon)] // value: -(1 / pi) expected: (1 / pi) + [InlineData(-0.0f, 0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 0.318309886f, CrossPlatformMachineEpsilon)] // value: (1 / pi) expected: (1 / pi) + [InlineData(0.434294482f, 0.434294482f, CrossPlatformMachineEpsilon)] // value: (log10(e)) expected: (log10(e)) + [InlineData(0.636619772f, 0.636619772f, CrossPlatformMachineEpsilon)] // value: (2 / pi) expected: (2 / pi) + [InlineData(0.693147181f, 0.693147181f, CrossPlatformMachineEpsilon)] // value: (ln(2)) expected: (ln(2)) + [InlineData(0.707106781f, 0.707106781f, CrossPlatformMachineEpsilon)] // value: (1 / sqrt(2)) expected: (1 / sqrt(2)) + [InlineData(0.785398163f, 0.785398163f, CrossPlatformMachineEpsilon)] // value: (pi / 4) expected: (pi / 4) + [InlineData(1.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.12837917f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // value: (2 / sqrt(pi)) expected: (2 / sqrt(pi)) + [InlineData(1.41421356f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // value: (sqrt(2)) expected: (sqrt(2)) + [InlineData(1.44269504f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // value: (log2(e)) expected: (log2(e)) + [InlineData(1.57079633f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 2) expected: (pi / 2) + [InlineData(2.30258509f, 2.30258509f, CrossPlatformMachineEpsilon * 10)] // value: (ln(10)) expected: (ln(10)) + [InlineData(2.71828183f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // value: (e) expected: (e) + [InlineData(3.14159265f, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // value: (pi) expected: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Abs(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Abs(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, float.NaN, 0.0f)] // value: -(pi) + [InlineData(-2.71828183f, float.NaN, 0.0f)] // value: -(e) + [InlineData(-1.41421356f, float.NaN, 0.0f)] // value: -(sqrt(2)) + [InlineData(-1.0f, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: (pi) + [InlineData(-0.911733915f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: (e) + [InlineData(-0.668201510f, 2.30258509f, CrossPlatformMachineEpsilon * 10)] // expected: (ln(10)) + [InlineData(-0.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(0.127751218f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: (log2(e)) + [InlineData(0.155943695f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: (sqrt(2)) + [InlineData(0.428125148f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: (2 / sqrt(pi)) + [InlineData(0.540302306f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.707106781f, 0.785398163f, CrossPlatformMachineEpsilon)] // expected: (pi / 4), value: (1 / sqrt(2)) + [InlineData(0.760244597f, 0.707106781f, CrossPlatformMachineEpsilon)] // expected: (1 / sqrt(2)) + [InlineData(0.769238901f, 0.693147181f, CrossPlatformMachineEpsilon)] // expected: (ln(2)) + [InlineData(0.804109828f, 0.636619772f, CrossPlatformMachineEpsilon)] // expected: (2 / pi) + [InlineData(0.907167129f, 0.434294482f, CrossPlatformMachineEpsilon)] // expected: (log10(e)) + [InlineData(0.949765715f, 0.318309886f, CrossPlatformMachineEpsilon)] // expected: (1 / pi) + [InlineData(1.0f, 0.0f, 0.0f)] + [InlineData(1.41421356f, float.NaN, 0.0f)] // value: (sqrt(2)) + [InlineData(2.71828183f, float.NaN, 0.0f)] // value: (e) + [InlineData(3.14159265f, float.NaN, 0.0f)] // value: (pi) + [InlineData(float.PositiveInfinity, float.NaN, 0.0f)] + public static void Acos(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Acos(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, float.NaN, 0.0f)] // value: -(pi) + [InlineData(-2.71828183f, float.NaN, 0.0f)] // value: -(e) + [InlineData(-1.41421356f, float.NaN, 0.0f)] // value: -(sqrt(2)) + [InlineData(-1.0f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(-0.991806244f, -1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: -(log2(e)) + [InlineData(-0.987765946f, -1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: -(sqrt(2)) + [InlineData(-0.903719457f, -1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: -(2 / sqrt(pi)) + [InlineData(-0.841470985f, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.743980337f, -0.839007561f, CrossPlatformMachineEpsilon)] // expected: -(pi - ln(10)) + [InlineData(-0.707106781f, -0.785398163f, CrossPlatformMachineEpsilon)] // expected: -(pi / 4), value: (1 / sqrt(2)) + [InlineData(-0.649636939f, -0.707106781f, CrossPlatformMachineEpsilon)] // expected: -(1 / sqrt(2)) + [InlineData(-0.638961276f, -0.693147181f, CrossPlatformMachineEpsilon)] // expected: -(ln(2)) + [InlineData(-0.594480769f, -0.636619772f, CrossPlatformMachineEpsilon)] // expected: -(2 / pi) + [InlineData(-0.420770483f, -0.434294482f, CrossPlatformMachineEpsilon)] // expected: -(log10(e)) + [InlineData(-0.410781291f, -0.423310825f, CrossPlatformMachineEpsilon)] // expected: -(pi - e) + [InlineData(-0.312961796f, -0.318309886f, CrossPlatformMachineEpsilon)] // expected: -(1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.312961796f, 0.318309886f, CrossPlatformMachineEpsilon)] // expected: (1 / pi) + [InlineData(0.410781291f, 0.423310825f, CrossPlatformMachineEpsilon)] // expected: (pi - e) + [InlineData(0.420770483f, 0.434294482f, CrossPlatformMachineEpsilon)] // expected: (log10(e)) + [InlineData(0.594480769f, 0.636619772f, CrossPlatformMachineEpsilon)] // expected: (2 / pi) + [InlineData(0.638961276f, 0.693147181f, CrossPlatformMachineEpsilon)] // expected: (ln(2)) + [InlineData(0.649636939f, 0.707106781f, CrossPlatformMachineEpsilon)] // expected: (1 / sqrt(2)) + [InlineData(0.707106781f, 0.785398163f, CrossPlatformMachineEpsilon)] // expected: (pi / 4), value: (1 / sqrt(2)) + [InlineData(0.743980337f, 0.839007561f, CrossPlatformMachineEpsilon)] // expected: (pi - ln(10)) + [InlineData(0.841470985f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.903719457f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: (2 / sqrt(pi)) + [InlineData(0.987765946f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: (sqrt(2)) + [InlineData(0.991806244f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: (log2(e)) + [InlineData(1.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(1.41421356f, float.NaN, 0.0f)] // value: (sqrt(2)) + [InlineData(2.71828183f, float.NaN, 0.0f)] // value: (e) + [InlineData(3.14159265f, float.NaN, 0.0f)] // value: (pi) + [InlineData(float.PositiveInfinity, float.NaN, 0.0f)] + public static void Asin(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Asin(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(-7.76357567f, -1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: -(log2(e)) + [InlineData(-6.33411917f, -1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: -(sqrt(2)) + [InlineData(-2.11087684f, -1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: -(2 / sqrt(pi)) + [InlineData(-1.55740772f, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-1.11340715f, -0.839007561f, CrossPlatformMachineEpsilon)] // expected: -(pi - ln(10)) + [InlineData(-1.0f, -0.785398163f, CrossPlatformMachineEpsilon)] // expected: -(pi / 4) + [InlineData(-0.854510432f, -0.707106781f, CrossPlatformMachineEpsilon)] // expected: -(1 / sqrt(2)) + [InlineData(-0.830640878f, -0.693147181f, CrossPlatformMachineEpsilon)] // expected: -(ln(2)) + [InlineData(-0.739302950f, -0.636619772f, CrossPlatformMachineEpsilon)] // expected: -(2 / pi) + [InlineData(-0.463829067f, -0.434294482f, CrossPlatformMachineEpsilon)] // expected: -(log10(e)) + [InlineData(-0.450549534f, -0.423310825f, CrossPlatformMachineEpsilon)] // expected: -(pi - e) + [InlineData(-0.329514733f, -0.318309886f, CrossPlatformMachineEpsilon)] // expected: -(1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.329514733f, 0.318309886f, CrossPlatformMachineEpsilon)] // expected: (1 / pi) + [InlineData(0.450549534f, 0.423310825f, CrossPlatformMachineEpsilon)] // expected: (pi - e) + [InlineData(0.463829067f, 0.434294482f, CrossPlatformMachineEpsilon)] // expected: (log10(e)) + [InlineData(0.739302950f, 0.636619772f, CrossPlatformMachineEpsilon)] // expected: (2 / pi) + [InlineData(0.830640878f, 0.693147181f, CrossPlatformMachineEpsilon)] // expected: (ln(2)) + [InlineData(0.854510432f, 0.707106781f, CrossPlatformMachineEpsilon)] // expected: (1 / sqrt(2)) + [InlineData(1.0f, 0.785398163f, CrossPlatformMachineEpsilon)] // expected: (pi / 4) + [InlineData(1.11340715f, 0.839007561f, CrossPlatformMachineEpsilon)] // expected: (pi - ln(10)) + [InlineData(1.55740772f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(2.11087684f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: (2 / sqrt(pi)) + [InlineData(6.33411917f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: (sqrt(2)) + [InlineData(7.76357567f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: (log2(e)) + [InlineData(float.PositiveInfinity, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + public static void Atan(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Atan(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, -1.0f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(float.NegativeInfinity, -0.0f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(float.NegativeInfinity, float.NaN, float.NaN, 0.0f)] + [InlineData(float.NegativeInfinity, 0.0f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(float.NegativeInfinity, 1.0f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(-1.0f, -1.0f, -2.35619449f, CrossPlatformMachineEpsilon * 10)] // expected: -(3 * pi / 4) + [InlineData(-1.0f, -0.0f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(-1.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(-1.0f, 0.0f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(-1.0f, 1.0f, -0.785398163f, CrossPlatformMachineEpsilon)] // expected: -(pi / 4) + [InlineData(-1.0f, float.PositiveInfinity, -0.0f, 0.0f)] + [InlineData(-0.991806244f, -0.127751218f, -1.69889761f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - log2(e)) + [InlineData(-0.991806244f, 0.127751218f, -1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: -(log2(e)) + [InlineData(-0.987765946f, -0.155943695f, -1.72737909f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - sqrt(2)) + [InlineData(-0.987765946f, 0.155943695f, -1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: -(sqrt(2)) + [InlineData(-0.903719457f, -0.428125148f, -2.01321349f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - (2 / sqrt(pi)) + [InlineData(-0.903719457f, 0.428125148f, -1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: -(2 / sqrt(pi) + [InlineData(-0.841470985f, -0.540302306f, -2.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - 1) + [InlineData(-0.841470985f, 0.540302306f, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.743980337f, -0.668201510f, -2.30258509f, CrossPlatformMachineEpsilon * 10)] // expected: -(ln(10)) + [InlineData(-0.743980337f, 0.668201510f, -0.839007561f, CrossPlatformMachineEpsilon)] // expected: -(pi - ln(10)) + [InlineData(-0.707106781f, -0.707106781f, -2.35619449f, CrossPlatformMachineEpsilon * 10)] // expected: -(3 * pi / 4), y: -(1 / sqrt(2)) x: -(1 / sqrt(2)) + [InlineData(-0.707106781f, 0.707106781f, -0.785398163f, CrossPlatformMachineEpsilon)] // expected: -(pi / 4), y: -(1 / sqrt(2)) x: (1 / sqrt(2)) + [InlineData(-0.649636939f, -0.760244597f, -2.43448587f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - (1 / sqrt(2)) + [InlineData(-0.649636939f, 0.760244597f, -0.707106781f, CrossPlatformMachineEpsilon)] // expected: -(1 / sqrt(2)) + [InlineData(-0.638961276f, -0.769238901f, -2.44844547f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - ln(2)) + [InlineData(-0.638961276f, 0.769238901f, -0.693147181f, CrossPlatformMachineEpsilon)] // expected: -(ln(2)) + [InlineData(-0.594480769f, -0.804109828f, -2.50497288f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - (2 / pi)) + [InlineData(-0.594480769f, 0.804109828f, -0.636619772f, CrossPlatformMachineEpsilon)] // expected: -(2 / pi) + [InlineData(-0.420770483f, -0.907167129f, -2.70729817f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - log10(e)) + [InlineData(-0.420770483f, 0.907167129f, -0.434294482f, CrossPlatformMachineEpsilon)] // expected: -(log10(e)) + [InlineData(-0.410781291f, -0.911733915f, -2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: -(e) + [InlineData(-0.410781291f, 0.911733915f, -0.423310825f, CrossPlatformMachineEpsilon)] // expected: -(pi - e) + [InlineData(-0.312961796f, -0.949765715f, -2.82328277f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi - (1 / pi)) + [InlineData(-0.312961796f, 0.949765715f, -0.318309886f, CrossPlatformMachineEpsilon)] // expected: -(1 / pi) + [InlineData(-0.0f, float.NegativeInfinity, -3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi) + [InlineData(-0.0f, -1.0f, -3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi) + [InlineData(-0.0f, -0.0f, -3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi) + [InlineData(-0.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(-0.0f, 0.0f, -0.0f, 0.0f)] + [InlineData(-0.0f, 1.0f, -0.0f, 0.0f)] + [InlineData(-0.0f, float.PositiveInfinity, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(float.NaN, -1.0f, float.NaN, 0.0f)] + [InlineData(float.NaN, -0.0f, float.NaN, 0.0f)] + [InlineData(float.NaN, float.NaN, float.NaN, 0.0f)] + [InlineData(float.NaN, 0.0f, float.NaN, 0.0f)] + [InlineData(float.NaN, 1.0f, float.NaN, 0.0f)] + [InlineData(float.NaN, float.PositiveInfinity, float.NaN, 0.0f)] + [InlineData(0.0f, float.NegativeInfinity, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: (pi) + [InlineData(0.0f, -1.0f, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: (pi) + [InlineData(0.0f, -0.0f, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: (pi) + [InlineData(0.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f, 0.0f)] + [InlineData(0.0f, 1.0f, 0.0f, 0.0f)] + [InlineData(0.0f, float.PositiveInfinity, 0.0f, 0.0f)] + [InlineData(0.312961796f, -0.949765715f, 2.82328277f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - (1 / pi)) + [InlineData(0.312961796f, 0.949765715f, 0.318309886f, CrossPlatformMachineEpsilon)] // expected: (1 / pi) + [InlineData(0.410781291f, -0.911733915f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: (e) + [InlineData(0.410781291f, 0.911733915f, 0.423310825f, CrossPlatformMachineEpsilon)] // expected: (pi - e) + [InlineData(0.420770483f, -0.907167129f, 2.70729817f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - log10(e)) + [InlineData(0.420770483f, 0.907167129f, 0.434294482f, CrossPlatformMachineEpsilon)] // expected: (log10(e)) + [InlineData(0.594480769f, -0.804109828f, 2.50497288f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - (2 / pi)) + [InlineData(0.594480769f, 0.804109828f, 0.636619772f, CrossPlatformMachineEpsilon)] // expected: (2 / pi) + [InlineData(0.638961276f, -0.769238901f, 2.44844547f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - ln(2)) + [InlineData(0.638961276f, 0.769238901f, 0.693147181f, CrossPlatformMachineEpsilon)] // expected: (ln(2)) + [InlineData(0.649636939f, -0.760244597f, 2.43448587f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - (1 / sqrt(2)) + [InlineData(0.649636939f, 0.760244597f, 0.707106781f, CrossPlatformMachineEpsilon)] // expected: (1 / sqrt(2)) + [InlineData(0.707106781f, -0.707106781f, 2.35619449f, CrossPlatformMachineEpsilon * 10)] // expected: (3 * pi / 4), y: (1 / sqrt(2)) x: -(1 / sqrt(2)) + [InlineData(0.707106781f, 0.707106781f, 0.785398163f, CrossPlatformMachineEpsilon)] // expected: (pi / 4), y: (1 / sqrt(2)) x: (1 / sqrt(2)) + [InlineData(0.743980337f, -0.668201510f, 2.30258509f, CrossPlatformMachineEpsilon * 10)] // expected: (ln(10)) + [InlineData(0.743980337f, 0.668201510f, 0.839007561f, CrossPlatformMachineEpsilon)] // expected: (pi - ln(10)) + [InlineData(0.841470985f, -0.540302306f, 2.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - 1) + [InlineData(0.841470985f, 0.540302306f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.903719457f, -0.428125148f, 2.01321349f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - (2 / sqrt(pi)) + [InlineData(0.903719457f, 0.428125148f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: (2 / sqrt(pi)) + [InlineData(0.987765946f, -0.155943695f, 1.72737909f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - sqrt(2)) + [InlineData(0.987765946f, 0.155943695f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: (sqrt(2)) + [InlineData(0.991806244f, -0.127751218f, 1.69889761f, CrossPlatformMachineEpsilon * 10)] // expected: (pi - log2(e)) + [InlineData(0.991806244f, 0.127751218f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: (log2(e)) + [InlineData(1.0f, -1.0f, 2.35619449f, CrossPlatformMachineEpsilon * 10)] // expected: (3 * pi / 4) + [InlineData(1.0f, -0.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(1.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(1.0f, 0.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(1.0f, 1.0f, 0.785398163f, CrossPlatformMachineEpsilon)] // expected: (pi / 4) + [InlineData(1.0f, float.PositiveInfinity, 0.0f, 0.0f)] + [InlineData(float.PositiveInfinity, -1.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(float.PositiveInfinity, -0.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(float.PositiveInfinity, float.NaN, float.NaN, 0.0f)] + [InlineData(float.PositiveInfinity, 0.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(float.PositiveInfinity, 1.0f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + public static void Atan2(float y, float x, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Atan2(y, x), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NegativeInfinity, 0.0f)] + [InlineData(-3.14159265f, -3.0f, 0.0f)] // value: -(pi) + [InlineData(-2.71828183f, -2.0f, 0.0f)] // value: -(e) + [InlineData(-2.30258509f, -2.0f, 0.0f)] // value: -(ln(10)) + [InlineData(-1.57079633f, -1.0f, 0.0f)] // value: -(pi / 2) + [InlineData(-1.44269504f, -1.0f, 0.0f)] // value: -(log2(e)) + [InlineData(-1.41421356f, -1.0f, 0.0f)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, -1.0f, 0.0f)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, -1.0f, 0.0f)] +#if NETFRAMEWORK + [InlineData(-0.785398163f, 0.0f, 0.0f)] // value: (pi / 4) + [InlineData(-0.707106781f, 0.0f, 0.0f)] // value: (1 / sqrt(2)) + [InlineData(-0.693147181f, 0.0f, 0.0f)] // value: (ln(2)) + [InlineData(-0.636619772f, 0.0f, 0.0f)] // value: (2 / pi) + [InlineData(-0.434294482f, 0.0f, 0.0f)] // value: (log10(e)) + [InlineData(-0.318309886f, 0.0f, 0.0f)] // value: (1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] +#else + [InlineData(-0.785398163f, -0.0f, 0.0f)] // value: (pi / 4) + [InlineData(-0.707106781f, -0.0f, 0.0f)] // value: (1 / sqrt(2)) + [InlineData(-0.693147181f, -0.0f, 0.0f)] // value: (ln(2)) + [InlineData(-0.636619772f, -0.0f, 0.0f)] // value: (2 / pi) + [InlineData(-0.434294482f, -0.0f, 0.0f)] // value: (log10(e)) + [InlineData(-0.318309886f, -0.0f, 0.0f)] // value: (1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] +#endif + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 1.0f, 0.0f)] // value: (1 / pi) + [InlineData(0.434294482f, 1.0f, 0.0f)] // value: (log10(e)) + [InlineData(0.636619772f, 1.0f, 0.0f)] // value: (2 / pi) + [InlineData(0.693147181f, 1.0f, 0.0f)] // value: (ln(2)) + [InlineData(0.707106781f, 1.0f, 0.0f)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 1.0f, 0.0f)] // value: (pi / 4) + [InlineData(1.0f, 1.0f, 0.0f)] + [InlineData(1.12837917f, 2.0f, 0.0f)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 2.0f, 0.0f)] // value: (sqrt(2)) + [InlineData(1.44269504f, 2.0f, 0.0f)] // value: (log2(e)) + [InlineData(1.57079633f, 2.0f, 0.0f)] // value: (pi / 2) + [InlineData(2.30258509f, 3.0f, 0.0f)] // value: (ln(10)) + [InlineData(2.71828183f, 3.0f, 0.0f)] // value: (e) + [InlineData(3.14159265f, 4.0f, 0.0f)] // value: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Ceiling(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Ceiling(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, -1.0f, CrossPlatformMachineEpsilon * 10)] // value: -(pi) + [InlineData(-2.71828183f, -0.911733918f, CrossPlatformMachineEpsilon)] // value: -(e) + [InlineData(-2.30258509f, -0.668201510f, CrossPlatformMachineEpsilon)] // value: -(ln(10)) + [InlineData(-1.57079633f, 0.0f, CrossPlatformMachineEpsilon)] // value: -(pi / 2) + [InlineData(-1.44269504f, 0.127751218f, CrossPlatformMachineEpsilon)] // value: -(log2(e)) + [InlineData(-1.41421356f, 0.155943695f, CrossPlatformMachineEpsilon)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, 0.428125148f, CrossPlatformMachineEpsilon)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, 0.540302306f, CrossPlatformMachineEpsilon)] + [InlineData(-0.785398163f, 0.707106781f, CrossPlatformMachineEpsilon)] // value: -(pi / 4), expected: (1 / sqrt(2)) + [InlineData(-0.707106781f, 0.760244597f, CrossPlatformMachineEpsilon)] // value: -(1 / sqrt(2)) + [InlineData(-0.693147181f, 0.769238901f, CrossPlatformMachineEpsilon)] // value: -(ln(2)) + [InlineData(-0.636619772f, 0.804109828f, CrossPlatformMachineEpsilon)] // value: -(2 / pi) + [InlineData(-0.434294482f, 0.907167129f, CrossPlatformMachineEpsilon)] // value: -(log10(e)) + [InlineData(-0.318309886f, 0.949765715f, CrossPlatformMachineEpsilon)] // value: -(1 / pi) + [InlineData(-0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.318309886f, 0.949765715f, CrossPlatformMachineEpsilon)] // value: (1 / pi) + [InlineData(0.434294482f, 0.907167129f, CrossPlatformMachineEpsilon)] // value: (log10(e)) + [InlineData(0.636619772f, 0.804109828f, CrossPlatformMachineEpsilon)] // value: (2 / pi) + [InlineData(0.693147181f, 0.769238901f, CrossPlatformMachineEpsilon)] // value: (ln(2)) + [InlineData(0.707106781f, 0.760244597f, CrossPlatformMachineEpsilon)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 0.707106781f, CrossPlatformMachineEpsilon)] // value: (pi / 4), expected: (1 / sqrt(2)) + [InlineData(1.0f, 0.540302306f, CrossPlatformMachineEpsilon)] + [InlineData(1.12837917f, 0.428125148f, CrossPlatformMachineEpsilon)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 0.155943695f, CrossPlatformMachineEpsilon)] // value: (sqrt(2)) + [InlineData(1.44269504f, 0.127751218f, CrossPlatformMachineEpsilon)] // value: (log2(e)) + [InlineData(1.57079633f, 0.0f, CrossPlatformMachineEpsilon)] // value: (pi / 2) + [InlineData(2.30258509f, -0.668201510f, CrossPlatformMachineEpsilon)] // value: (ln(10)) + [InlineData(2.71828183f, -0.911733918f, CrossPlatformMachineEpsilon)] // value: (e) + [InlineData(3.14159265f, -1.0f, CrossPlatformMachineEpsilon * 10)] // value: (pi) + [InlineData(float.PositiveInfinity, float.NaN, 0.0f)] + public static void Cos(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Cos(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(-3.14159265f, 11.5919533f, CrossPlatformMachineEpsilon * 100)] // value: (pi) + [InlineData(-2.71828183f, 7.61012514f, CrossPlatformMachineEpsilon * 10)] // value: (e) + [InlineData(-2.30258509f, 5.05f, CrossPlatformMachineEpsilon * 10)] // value: (ln(10)) + [InlineData(-1.57079633f, 2.50917848f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 2) + [InlineData(-1.44269504f, 2.23418810f, CrossPlatformMachineEpsilon * 10)] // value: (log2(e)) + [InlineData(-1.41421356f, 2.17818356f, CrossPlatformMachineEpsilon * 10)] // value: (sqrt(2)) + [InlineData(-1.12837917f, 1.70710014f, CrossPlatformMachineEpsilon * 10)] // value: (2 / sqrt(pi)) + [InlineData(-1.0f, 1.54308063f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.785398163f, 1.32460909f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 4) + [InlineData(-0.707106781f, 1.26059184f, CrossPlatformMachineEpsilon * 10)] // value: (1 / sqrt(2)) + [InlineData(-0.693147181f, 1.25f, CrossPlatformMachineEpsilon * 10)] // value: (ln(2)) + [InlineData(-0.636619772f, 1.20957949f, CrossPlatformMachineEpsilon * 10)] // value: (2 / pi) + [InlineData(-0.434294482f, 1.09579746f, CrossPlatformMachineEpsilon * 10)] // value: (log10(e)) + [InlineData(-0.318309886f, 1.05108979f, CrossPlatformMachineEpsilon * 10)] // value: (1 / pi) + [InlineData(-0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.318309886f, 1.05108979f, CrossPlatformMachineEpsilon * 10)] // value: (1 / pi) + [InlineData(0.434294482f, 1.09579746f, CrossPlatformMachineEpsilon * 10)] // value: (log10(e)) + [InlineData(0.636619772f, 1.20957949f, CrossPlatformMachineEpsilon * 10)] // value: (2 / pi) + [InlineData(0.693147181f, 1.25f, CrossPlatformMachineEpsilon * 10)] // value: (ln(2)) + [InlineData(0.707106781f, 1.26059184f, CrossPlatformMachineEpsilon * 10)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 1.32460909f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 4) + [InlineData(1.0f, 1.54308063f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.12837917f, 1.70710014f, CrossPlatformMachineEpsilon * 10)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 2.17818356f, CrossPlatformMachineEpsilon * 10)] // value: (sqrt(2)) + [InlineData(1.44269504f, 2.23418810f, CrossPlatformMachineEpsilon * 10)] // value: (log2(e)) + [InlineData(1.57079633f, 2.50917848f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 2) + [InlineData(2.30258509f, 5.05f, CrossPlatformMachineEpsilon * 10)] // value: (ln(10)) + [InlineData(2.71828183f, 7.61012514f, CrossPlatformMachineEpsilon * 10)] // value: (e) + [InlineData(3.14159265f, 11.5919533f, CrossPlatformMachineEpsilon * 100)] // value: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Cosh(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Cosh(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, 0.0f, CrossPlatformMachineEpsilon)] + [InlineData(-3.14159265f, 0.0432139183f, CrossPlatformMachineEpsilon / 10)] // value: -(pi) + [InlineData(-2.71828183f, 0.0659880358f, CrossPlatformMachineEpsilon / 10)] // value: -(e) + [InlineData(-2.30258509f, 0.1f, CrossPlatformMachineEpsilon)] // value: -(ln(10)) + [InlineData(-1.57079633f, 0.207879576f, CrossPlatformMachineEpsilon)] // value: -(pi / 2) + [InlineData(-1.44269504f, 0.236290088f, CrossPlatformMachineEpsilon)] // value: -(log2(e)) + [InlineData(-1.41421356f, 0.243116734f, CrossPlatformMachineEpsilon)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, 0.323557264f, CrossPlatformMachineEpsilon)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, 0.367879441f, CrossPlatformMachineEpsilon)] + [InlineData(-0.785398163f, 0.455938128f, CrossPlatformMachineEpsilon)] // value: -(pi / 4) + [InlineData(-0.707106781f, 0.493068691f, CrossPlatformMachineEpsilon)] // value: -(1 / sqrt(2)) + [InlineData(-0.693147181f, 0.5f, CrossPlatformMachineEpsilon)] // value: -(ln(2)) + [InlineData(-0.636619772f, 0.529077808f, CrossPlatformMachineEpsilon)] // value: -(2 / pi) + [InlineData(-0.434294482f, 0.647721485f, CrossPlatformMachineEpsilon)] // value: -(log10(e)) + [InlineData(-0.318309886f, 0.727377349f, CrossPlatformMachineEpsilon)] // value: -(1 / pi) + [InlineData(-0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.318309886f, 1.37480223f, CrossPlatformMachineEpsilon * 10)] // value: (1 / pi) + [InlineData(0.434294482f, 1.54387344f, CrossPlatformMachineEpsilon * 10)] // value: (log10(e)) + [InlineData(0.636619772f, 1.89008116f, CrossPlatformMachineEpsilon * 10)] // value: (2 / pi) + [InlineData(0.693147181f, 2.0f, CrossPlatformMachineEpsilon * 10)] // value: (ln(2)) + [InlineData(0.707106781f, 2.02811498f, CrossPlatformMachineEpsilon * 10)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 2.19328005f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 4) + [InlineData(1.0f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: (e) + [InlineData(1.12837917f, 3.09064302f, CrossPlatformMachineEpsilon * 10)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 4.11325038f, CrossPlatformMachineEpsilon * 10)] // value: (sqrt(2)) + [InlineData(1.44269504f, 4.23208611f, CrossPlatformMachineEpsilon * 10)] // value: (log2(e)) + [InlineData(1.57079633f, 4.81047738f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 2) + [InlineData(2.30258509f, 10.0f, CrossPlatformMachineEpsilon * 100)] // value: (ln(10)) + [InlineData(2.71828183f, 15.1542622f, CrossPlatformMachineEpsilon * 100)] // value: (e) + [InlineData(3.14159265f, 23.1406926f, CrossPlatformMachineEpsilon * 100)] // value: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Exp(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Exp(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NegativeInfinity, 0.0f)] + [InlineData(-3.14159265f, -4.0f, 0.0f)] // value: -(pi) + [InlineData(-2.71828183f, -3.0f, 0.0f)] // value: -(e) + [InlineData(-2.30258509f, -3.0f, 0.0f)] // value: -(ln(10)) + [InlineData(-1.57079633f, -2.0f, 0.0f)] // value: -(pi / 2) + [InlineData(-1.44269504f, -2.0f, 0.0f)] // value: -(log2(e)) + [InlineData(-1.41421356f, -2.0f, 0.0f)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, -2.0f, 0.0f)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, -1.0f, 0.0f)] + [InlineData(-0.785398163f, -1.0f, 0.0f)] // value: -(pi / 4) + [InlineData(-0.707106781f, -1.0f, 0.0f)] // value: -(1 / sqrt(2)) + [InlineData(-0.693147181f, -1.0f, 0.0f)] // value: -(ln(2)) + [InlineData(-0.636619772f, -1.0f, 0.0f)] // value: -(2 / pi) + [InlineData(-0.434294482f, -1.0f, 0.0f)] // value: -(log10(e)) + [InlineData(-0.318309886f, -1.0f, 0.0f)] // value: -(1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 0.0f, 0.0f)] // value: (1 / pi) + [InlineData(0.434294482f, 0.0f, 0.0f)] // value: (log10(e)) + [InlineData(0.636619772f, 0.0f, 0.0f)] // value: (2 / pi) + [InlineData(0.693147181f, 0.0f, 0.0f)] // value: (ln(2)) + [InlineData(0.707106781f, 0.0f, 0.0f)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 0.0f, 0.0f)] // value: (pi / 4) + [InlineData(1.0f, 1.0f, 0.0f)] + [InlineData(1.12837917f, 1.0f, 0.0f)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 1.0f, 0.0f)] // value: (sqrt(2)) + [InlineData(1.44269504f, 1.0f, 0.0f)] // value: (log2(e)) + [InlineData(1.57079633f, 1.0f, 0.0f)] // value: (pi / 2) + [InlineData(2.30258509f, 2.0f, 0.0f)] // value: (ln(10)) + [InlineData(2.71828183f, 2.0f, 0.0f)] // value: (e) + [InlineData(3.14159265f, 3.0f, 0.0f)] // value: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Floor(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Floor(value), allowedVariance); + } + + [Fact] + public static void IEEERemainder() + { + Assert.Equal(-1.0f, MathF.IEEERemainder(3.0f, 2.0f)); + Assert.Equal(0.0f, MathF.IEEERemainder(4.0f, 2.0f)); + Assert.Equal(1.0f, MathF.IEEERemainder(10.0f, 3.0f)); + Assert.Equal(-1.0f, MathF.IEEERemainder(11.0f, 3.0f)); + Assert.Equal(-2.0f, MathF.IEEERemainder(28.0f, 5.0f)); + AssertExtensions.Equal(1.8f, MathF.IEEERemainder(17.8f, 4.0f), CrossPlatformMachineEpsilon * 10); + AssertExtensions.Equal(1.4f, MathF.IEEERemainder(17.8f, 4.1f), CrossPlatformMachineEpsilon * 10); + AssertExtensions.Equal(0.1000004f, MathF.IEEERemainder(-16.3f, 4.1f), CrossPlatformMachineEpsilon / 10); + AssertExtensions.Equal(1.4f, MathF.IEEERemainder(17.8f, -4.1f), CrossPlatformMachineEpsilon * 10); + AssertExtensions.Equal(-1.4f, MathF.IEEERemainder(-17.8f, -4.1f), CrossPlatformMachineEpsilon * 10); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, float.NaN, 0.0f)] // value: -(pi) + [InlineData(-2.71828183f, float.NaN, 0.0f)] // value: -(e) + [InlineData(-1.41421356f, float.NaN, 0.0f)] // value: -(sqrt(2)) + [InlineData(-1.0f, float.NaN, 0.0f)] + [InlineData(-0.693147181f, float.NaN, 0.0f)] // value: -(ln(2)) + [InlineData(-0.434294482f, float.NaN, 0.0f)] // value: -(log10(e)) + [InlineData(-0.0f, float.NegativeInfinity, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, float.NegativeInfinity, 0.0f)] + [InlineData(0.0432139183f, -3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi) + [InlineData(0.0659880358f, -2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: -(e) + [InlineData(0.1f, -2.30258509f, CrossPlatformMachineEpsilon * 10)] // expected: -(ln(10)) + [InlineData(0.207879576f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(0.236290088f, -1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: -(log2(e)) + [InlineData(0.243116734f, -1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: -(sqrt(2)) + [InlineData(0.323557264f, -1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: -(2 / sqrt(pi)) + [InlineData(0.367879441f, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.455938128f, -0.785398163f, CrossPlatformMachineEpsilon)] // expected: -(pi / 4) + [InlineData(0.493068691f, -0.707106781f, CrossPlatformMachineEpsilon)] // expected: -(1 / sqrt(2)) + [InlineData(0.5f, -0.693147181f, CrossPlatformMachineEpsilon)] // expected: -(ln(2)) + [InlineData(0.529077808f, -0.636619772f, CrossPlatformMachineEpsilon)] // expected: -(2 / pi) + [InlineData(0.647721485f, -0.434294482f, CrossPlatformMachineEpsilon)] // expected: -(log10(e)) + [InlineData(0.727377349f, -0.318309886f, CrossPlatformMachineEpsilon)] // expected: -(1 / pi) + [InlineData(1.0f, 0.0f, 0.0f)] + [InlineData(1.37480223f, 0.318309886f, CrossPlatformMachineEpsilon)] // expected: (1 / pi) + [InlineData(1.54387344f, 0.434294482f, CrossPlatformMachineEpsilon)] // expected: (log10(e)) + [InlineData(1.89008116f, 0.636619772f, CrossPlatformMachineEpsilon)] // expected: (2 / pi) + [InlineData(2.0f, 0.693147181f, CrossPlatformMachineEpsilon)] // expected: (ln(2)) + [InlineData(2.02811498f, 0.707106781f, CrossPlatformMachineEpsilon)] // expected: (1 / sqrt(2)) + [InlineData(2.19328005f, 0.785398163f, CrossPlatformMachineEpsilon)] // expected: (pi / 4) + [InlineData(2.71828183f, 1.0f, CrossPlatformMachineEpsilon * 10)] // value: (e) + [InlineData(3.09064302f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: (2 / sqrt(pi)) + [InlineData(4.11325038f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: (sqrt(2)) + [InlineData(4.23208611f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: (log2(e)) + [InlineData(4.81047738f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(10.0f, 2.30258509f, CrossPlatformMachineEpsilon * 10)] // expected: (ln(10)) + [InlineData(15.1542622f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: (e) + [InlineData(23.1406926f, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Log(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Log(value), allowedVariance); + } + + [Fact] + public static void LogWithBase() + { + Assert.Equal(1.0f, MathF.Log(3.0f, 3.0f)); + AssertExtensions.Equal(2.40217350f, MathF.Log(14.0f, 3.0f), CrossPlatformMachineEpsilon * 10); + Assert.Equal(float.NegativeInfinity, MathF.Log(0.0f, 3.0f)); + Assert.Equal(float.NaN, MathF.Log(-3.0f, 3.0f)); + Assert.Equal(float.NaN, MathF.Log(float.NaN, 3.0f)); + Assert.Equal(float.PositiveInfinity, MathF.Log(float.PositiveInfinity, 3.0f)); + Assert.Equal(float.NaN, MathF.Log(float.NegativeInfinity, 3.0f)); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, float.NaN, 0.0f)] // value: -(pi) + [InlineData(-2.71828183f, float.NaN, 0.0f)] // value: -(e) + [InlineData(-1.41421356f, float.NaN, 0.0f)] // value: -(sqrt(2)) + [InlineData(-1.0f, float.NaN, 0.0f)] + [InlineData(-0.693147181f, float.NaN, 0.0f)] // value: -(ln(2)) + [InlineData(-0.434294482f, float.NaN, 0.0f)] // value: -(log10(e)) + [InlineData(-0.0f, float.NegativeInfinity, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, float.NegativeInfinity, 0.0f)] + [InlineData(0.000721784159f, -3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi) + [InlineData(0.00191301410f, -2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: -(e) + [InlineData(0.00498212830f, -2.30258509f, CrossPlatformMachineEpsilon * 10)] // expected: -(ln(10)) + [InlineData(0.0268660410f, -1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: -(pi / 2) + [InlineData(0.0360831928f, -1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: -(log2(e)) + [InlineData(0.0385288847f, -1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: -(sqrt(2)) + [InlineData(0.0744082059f, -1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: -(2 / sqrt(pi)) + [InlineData(0.1f, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.163908636f, -0.785398163f, CrossPlatformMachineEpsilon)] // expected: -(pi / 4) + [InlineData(0.196287760f, -0.707106781f, CrossPlatformMachineEpsilon)] // expected: -(1 / sqrt(2)) + [InlineData(0.202699566f, -0.693147181f, CrossPlatformMachineEpsilon)] // expected: -(ln(2)) + [InlineData(0.230876765f, -0.636619772f, CrossPlatformMachineEpsilon)] // expected: -(2 / pi) + [InlineData(0.367879441f, -0.434294482f, CrossPlatformMachineEpsilon)] // expected: -(log10(e)) + [InlineData(0.480496373f, -0.318309886f, CrossPlatformMachineEpsilon)] // expected: -(1 / pi) + [InlineData(1.0f, 0.0f, 0.0f)] + [InlineData(2.08118116f, 0.318309886f, CrossPlatformMachineEpsilon)] // expected: (1 / pi) + [InlineData(2.71828183f, 0.434294482f, CrossPlatformMachineEpsilon)] // expected: (log10(e)) value: (e) + [InlineData(4.33131503f, 0.636619772f, CrossPlatformMachineEpsilon)] // expected: (2 / pi) + [InlineData(4.93340967f, 0.693147181f, CrossPlatformMachineEpsilon)] // expected: (ln(2)) + [InlineData(5.09456117f, 0.707106781f, CrossPlatformMachineEpsilon)] // expected: (1 / sqrt(2)) + [InlineData(6.10095980f, 0.785398163f, CrossPlatformMachineEpsilon)] // expected: (pi / 4) + [InlineData(10.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(13.4393779f, 1.12837917f, CrossPlatformMachineEpsilon * 10)] // expected: (2 / sqrt(pi)) + [InlineData(25.9545535f, 1.41421356f, CrossPlatformMachineEpsilon * 10)] // expected: (sqrt(2)) + [InlineData(27.7137338f, 1.44269504f, CrossPlatformMachineEpsilon * 10)] // expected: (log2(e)) + [InlineData(37.2217105f, 1.57079633f, CrossPlatformMachineEpsilon * 10)] // expected: (pi / 2) + [InlineData(200.717432f, 2.30258509f, CrossPlatformMachineEpsilon * 10)] // expected: (ln(10)) + [InlineData(522.735300f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // expected: (e) + [InlineData(1385.45573f, 3.14159265f, CrossPlatformMachineEpsilon * 10)] // expected: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Log10(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Log10(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.PositiveInfinity, float.PositiveInfinity)] + [InlineData(float.PositiveInfinity, float.NegativeInfinity, float.PositiveInfinity)] + [InlineData(float.MinValue, float.MaxValue, float.MaxValue)] + [InlineData(float.MaxValue, float.MinValue, float.MaxValue)] + [InlineData(float.NaN, float.NaN, float.NaN)] + [InlineData(float.NaN, 1.0f, float.NaN)] + [InlineData(1.0f, float.NaN, float.NaN)] + [InlineData(float.PositiveInfinity, float.NaN, float.NaN)] + [InlineData(float.NegativeInfinity, float.NaN, float.NaN)] + [InlineData(float.NaN, float.PositiveInfinity, float.NaN)] + [InlineData(float.NaN, float.NegativeInfinity, float.NaN)] + [InlineData(-0.0f, 0.0f, 0.0f)] +#if NETFRAMEWORK + [InlineData(0.0f, -0.0f, -0.0f)] +#else + [InlineData(0.0f, -0.0f, 0.0f)] +#endif + [InlineData(2.0f, -3.0f, 2.0f)] + [InlineData(-3.0f, 2.0f, 2.0f)] + [InlineData(3.0f, -2.0f, 3.0f)] + [InlineData(-2.0f, 3.0f, 3.0f)] + public static void Max(float x, float y, float expectedResult) + { + AssertExtensions.Equal(expectedResult, MathF.Max(x, y), 0.0f); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.PositiveInfinity, float.NegativeInfinity)] + [InlineData(float.PositiveInfinity, float.NegativeInfinity, float.NegativeInfinity)] + [InlineData(float.MinValue, float.MaxValue, float.MinValue)] + [InlineData(float.MaxValue, float.MinValue, float.MinValue)] + [InlineData(float.NaN, float.NaN, float.NaN)] + [InlineData(float.NaN, 1.0f, float.NaN)] + [InlineData(1.0f, float.NaN, float.NaN)] + [InlineData(float.PositiveInfinity, float.NaN, float.NaN)] + [InlineData(float.NegativeInfinity, float.NaN, float.NaN)] + [InlineData(float.NaN, float.PositiveInfinity, float.NaN)] + [InlineData(float.NaN, float.NegativeInfinity, float.NaN)] +#if NETFRAMEWORK + [InlineData(-0.0f, 0.0f, 0.0f)] +#else + [InlineData(-0.0f, 0.0f, -0.0f)] +#endif + [InlineData(0.0f, -0.0f, -0.0f)] + [InlineData(2.0f, -3.0f, -3.0f)] + [InlineData(-3.0f, 2.0f, -3.0f)] + [InlineData(3.0f, -2.0f, -2.0f)] + [InlineData(-2.0f, 3.0f, -2.0f)] + public static void Min(float x, float y, float expectedResult) + { + AssertExtensions.Equal(expectedResult, MathF.Min(x, y), 0.0f); + } + + + [Theory] + [InlineData(float.NegativeInfinity, float.NegativeInfinity, 0.0f, 0.0f)] + [InlineData(float.NegativeInfinity, -1.0f, -0.0f, 0.0f)] + [InlineData(float.NegativeInfinity, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(float.NegativeInfinity, float.NaN, float.NaN, 0.0f)] + [InlineData(float.NegativeInfinity, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(float.NegativeInfinity, 1.0f, float.NegativeInfinity, 0.0f)] + [InlineData(float.NegativeInfinity, float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(-10.0f, float.NegativeInfinity, 0.0f, 0.0f)] + [InlineData(-10.0f, -1.57079633f, float.NaN, 0.0f)] // y: -(pi / 2) + [InlineData(-10.0f, -1.0f, -0.1f, CrossPlatformMachineEpsilon)] + [InlineData(-10.0f, -0.785398163f, float.NaN, 0.0f)] // y: -(pi / 4) + [InlineData(-10.0f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-10.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(-10.0f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-10.0f, 0.785398163f, float.NaN, 0.0f)] // y: (pi / 4) + [InlineData(-10.0f, 1.0f, -10.0f, CrossPlatformMachineEpsilon * 100)] + [InlineData(-10.0f, 1.57079633f, float.NaN, 0.0f)] // y: (pi / 2) + [InlineData(-10.0f, float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(-2.71828183f, float.NegativeInfinity, 0.0f, 0.0f)] // x: -(e) + [InlineData(-2.71828183f, -1.57079633f, float.NaN, 0.0f)] // x: -(e) y: -(pi / 2) + [InlineData(-2.71828183f, -1.0f, -0.367879441f, CrossPlatformMachineEpsilon)] // x: -(e) + [InlineData(-2.71828183f, -0.785398163f, float.NaN, 0.0f)] // x: -(e) y: -(pi / 4) + [InlineData(-2.71828183f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] // x: -(e) + [InlineData(-2.71828183f, float.NaN, float.NaN, 0.0f)] + [InlineData(-2.71828183f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] // x: -(e) + [InlineData(-2.71828183f, 0.785398163f, float.NaN, 0.0f)] // x: -(e) y: (pi / 4) + [InlineData(-2.71828183f, 1.0f, -2.71828183f, CrossPlatformMachineEpsilon * 10)] // x: -(e) expected: (e) + [InlineData(-2.71828183f, 1.57079633f, float.NaN, 0.0f)] // x: -(e) y: (pi / 2) + [InlineData(-2.71828183f, float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(-1.0f, -1.0f, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-1.0f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-1.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(-1.0f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-1.0f, 1.0f, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.0f, float.NegativeInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(-0.0f, -3.0f, float.NegativeInfinity, 0.0f)] + [InlineData(-0.0f, -2.0f, float.PositiveInfinity, 0.0f)] + [InlineData(-0.0f, -1.57079633f, float.PositiveInfinity, 0.0f)] // y: -(pi / 2) + [InlineData(-0.0f, -1.0f, float.NegativeInfinity, 0.0f)] + [InlineData(-0.0f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(-0.0f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.0f, 1.0f, -0.0f, 0.0f)] + [InlineData(-0.0f, 1.57079633f, 0.0f, 0.0f)] // y: -(pi / 2) + [InlineData(-0.0f, 2.0f, 0.0f, 0.0f)] + [InlineData(-0.0f, 3.0f, -0.0f, 0.0f)] + [InlineData(-0.0f, float.PositiveInfinity, 0.0f, 0.0f)] + [InlineData(float.NaN, float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(float.NaN, -1.0f, float.NaN, 0.0f)] + [InlineData(float.NaN, float.NaN, float.NaN, 0.0f)] + [InlineData(float.NaN, 1.0f, float.NaN, 0.0f)] + [InlineData(float.NaN, float.PositiveInfinity, float.NaN, 0.0f)] + [InlineData(0.0f, float.NegativeInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(0.0f, -3.0f, float.PositiveInfinity, 0.0f)] + [InlineData(0.0f, -2.0f, float.PositiveInfinity, 0.0f)] + [InlineData(0.0f, -1.57079633f, float.PositiveInfinity, 0.0f)] // y: -(pi / 2) + [InlineData(0.0f, -1.0f, float.PositiveInfinity, 0.0f)] + [InlineData(0.0f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(0.0f, 1.0f, 0.0f, 0.0f)] + [InlineData(0.0f, 1.57079633f, 0.0f, 0.0f)] // y: -(pi / 2) + [InlineData(0.0f, 2.0f, 0.0f, 0.0f)] + [InlineData(0.0f, 3.0f, 0.0f, 0.0f)] + [InlineData(0.0f, float.PositiveInfinity, 0.0f, 0.0f)] + [InlineData(1.0f, float.NegativeInfinity, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.0f, -1.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.0f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.0f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.0f, 1.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.0f, float.PositiveInfinity, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(2.71828183f, float.NegativeInfinity, 0.0f, 0.0f)] + [InlineData(2.71828183f, -3.14159265f, 0.0432139183f, CrossPlatformMachineEpsilon / 10)] // x: (e) y: -(pi) + [InlineData(2.71828183f, -2.71828183f, 0.0659880358f, CrossPlatformMachineEpsilon / 10)] // x: (e) y: -(e) + [InlineData(2.71828183f, -2.30258509f, 0.1f, CrossPlatformMachineEpsilon)] // x: (e) y: -(ln(10)) + [InlineData(2.71828183f, -1.57079633f, 0.207879576f, CrossPlatformMachineEpsilon)] // x: (e) y: -(pi / 2) + [InlineData(2.71828183f, -1.44269504f, 0.236290088f, CrossPlatformMachineEpsilon)] // x: (e) y: -(log2(e)) + [InlineData(2.71828183f, -1.41421356f, 0.243116734f, CrossPlatformMachineEpsilon)] // x: (e) y: -(sqrt(2)) + [InlineData(2.71828183f, -1.12837917f, 0.323557264f, CrossPlatformMachineEpsilon)] // x: (e) y: -(2 / sqrt(pi)) + [InlineData(2.71828183f, -1.0f, 0.367879441f, CrossPlatformMachineEpsilon)] // x: (e) + [InlineData(2.71828183f, -0.785398163f, 0.455938128f, CrossPlatformMachineEpsilon)] // x: (e) y: -(pi / 4) + [InlineData(2.71828183f, -0.707106781f, 0.493068691f, CrossPlatformMachineEpsilon)] // x: (e) y: -(1 / sqrt(2)) + [InlineData(2.71828183f, -0.693147181f, 0.5f, CrossPlatformMachineEpsilon)] // x: (e) y: -(ln(2)) + [InlineData(2.71828183f, -0.636619772f, 0.529077808f, CrossPlatformMachineEpsilon)] // x: (e) y: -(2 / pi) + [InlineData(2.71828183f, -0.434294482f, 0.647721485f, CrossPlatformMachineEpsilon)] // x: (e) y: -(log10(e)) + [InlineData(2.71828183f, -0.318309886f, 0.727377349f, CrossPlatformMachineEpsilon)] // x: (e) y: -(1 / pi) + [InlineData(2.71828183f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] // x: (e) + [InlineData(2.71828183f, float.NaN, float.NaN, 0.0f)] + [InlineData(2.71828183f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] // x: (e) + [InlineData(2.71828183f, 0.318309886f, 1.37480223f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (1 / pi) + [InlineData(2.71828183f, 0.434294482f, 1.54387344f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (log10(e)) + [InlineData(2.71828183f, 0.636619772f, 1.89008116f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (2 / pi) + [InlineData(2.71828183f, 0.693147181f, 2.0f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (ln(2)) + [InlineData(2.71828183f, 0.707106781f, 2.02811498f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (1 / sqrt(2)) + [InlineData(2.71828183f, 0.785398163f, 2.19328005f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (pi / 4) + [InlineData(2.71828183f, 1.0f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // x: (e) expected: (e) + [InlineData(2.71828183f, 1.12837917f, 3.09064302f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (2 / sqrt(pi)) + [InlineData(2.71828183f, 1.41421356f, 4.11325038f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (sqrt(2)) + [InlineData(2.71828183f, 1.44269504f, 4.23208611f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (log2(e)) + [InlineData(2.71828183f, 1.57079633f, 4.81047738f, CrossPlatformMachineEpsilon * 10)] // x: (e) y: (pi / 2) + [InlineData(2.71828183f, 2.30258509f, 10.0f, CrossPlatformMachineEpsilon * 100)] // x: (e) y: (ln(10)) + [InlineData(2.71828183f, 2.71828183f, 15.1542622f, CrossPlatformMachineEpsilon * 100)] // x: (e) y: (e) + [InlineData(2.71828183f, 3.14159265f, 23.1406926f, CrossPlatformMachineEpsilon * 100)] // x: (e) y: (pi) + [InlineData(2.71828183f, float.PositiveInfinity, float.PositiveInfinity, 0.0f)] // x: (e) + [InlineData(10.0f, float.NegativeInfinity, 0.0f, 0.0f)] + [InlineData(10.0f, -3.14159265f, 0.000721784159f, CrossPlatformMachineEpsilon / 1000)] // y: -(pi) + [InlineData(10.0f, -2.71828183f, 0.00191301410f, CrossPlatformMachineEpsilon / 100)] // y: -(e) + [InlineData(10.0f, -2.30258509f, 0.00498212830f, CrossPlatformMachineEpsilon / 100)] // y: -(ln(10)) + [InlineData(10.0f, -1.57079633f, 0.0268660410f, CrossPlatformMachineEpsilon / 10)] // y: -(pi / 2) + [InlineData(10.0f, -1.44269504f, 0.0360831928f, CrossPlatformMachineEpsilon / 10)] // y: -(log2(e)) + [InlineData(10.0f, -1.41421356f, 0.0385288847f, CrossPlatformMachineEpsilon / 10)] // y: -(sqrt(2)) + [InlineData(10.0f, -1.12837917f, 0.0744082059f, CrossPlatformMachineEpsilon / 10)] // y: -(2 / sqrt(pi)) + [InlineData(10.0f, -1.0f, 0.1f, CrossPlatformMachineEpsilon)] + [InlineData(10.0f, -0.785398163f, 0.163908636f, CrossPlatformMachineEpsilon)] // y: -(pi / 4) + [InlineData(10.0f, -0.707106781f, 0.196287760f, CrossPlatformMachineEpsilon)] // y: -(1 / sqrt(2)) + [InlineData(10.0f, -0.693147181f, 0.202699566f, CrossPlatformMachineEpsilon)] // y: -(ln(2)) + [InlineData(10.0f, -0.636619772f, 0.230876765f, CrossPlatformMachineEpsilon)] // y: -(2 / pi) + [InlineData(10.0f, -0.434294482f, 0.367879441f, CrossPlatformMachineEpsilon)] // y: -(log10(e)) + [InlineData(10.0f, -0.318309886f, 0.480496373f, CrossPlatformMachineEpsilon)] // y: -(1 / pi) + [InlineData(10.0f, -0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(10.0f, float.NaN, float.NaN, 0.0f)] + [InlineData(10.0f, 0.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(10.0f, 0.318309886f, 2.08118116f, CrossPlatformMachineEpsilon * 10)] // y: (1 / pi) + [InlineData(10.0f, 0.434294482f, 2.71828183f, CrossPlatformMachineEpsilon * 10)] // y: (log10(e)) expected: (e) + [InlineData(10.0f, 0.636619772f, 4.33131503f, CrossPlatformMachineEpsilon * 10)] // y: (2 / pi) + [InlineData(10.0f, 0.693147181f, 4.93340967f, CrossPlatformMachineEpsilon * 10)] // y: (ln(2)) + [InlineData(10.0f, 0.707106781f, 5.09456117f, CrossPlatformMachineEpsilon * 10)] // y: (1 / sqrt(2)) + [InlineData(10.0f, 0.785398163f, 6.10095980f, CrossPlatformMachineEpsilon * 10)] // y: (pi / 4) + [InlineData(10.0f, 1.0f, 10.0f, CrossPlatformMachineEpsilon * 100)] + [InlineData(10.0f, 1.12837917f, 13.4393779f, CrossPlatformMachineEpsilon * 100)] // y: (2 / sqrt(pi)) + [InlineData(10.0f, 1.41421356f, 25.9545535f, CrossPlatformMachineEpsilon * 100)] // y: (sqrt(2)) + [InlineData(10.0f, 1.44269504f, 27.7137338f, CrossPlatformMachineEpsilon * 100)] // y: (log2(e)) + [InlineData(10.0f, 1.57079633f, 37.2217105f, CrossPlatformMachineEpsilon * 100)] // y: (pi / 2) + [InlineData(10.0f, 2.30258509f, 200.717432f, CrossPlatformMachineEpsilon * 1000)] // y: (ln(10)) + [InlineData(10.0f, 2.71828183f, 522.735300f, CrossPlatformMachineEpsilon * 1000)] // y: (e) + [InlineData(10.0f, 3.14159265f, 1385.45573f, CrossPlatformMachineEpsilon * 10000)] // y: (pi) + [InlineData(10.0f, float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + [InlineData(float.PositiveInfinity, float.NegativeInfinity, 0.0f, 0.0f)] + [InlineData(float.PositiveInfinity, -1.0f, 0.0f, 0.0f)] + [InlineData(float.PositiveInfinity, -0.0f, 1.0f, 0.0f)] + [InlineData(float.PositiveInfinity, float.NaN, float.NaN, 0.0f)] + [InlineData(float.PositiveInfinity, 0.0f, 1.0f, 0.0f)] + [InlineData(float.PositiveInfinity, 1.0f, float.PositiveInfinity, 0.0f)] + [InlineData(float.PositiveInfinity, float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Pow(float x, float y, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Pow(x, y), allowedVariance); + } + + public static IEnumerable Round_Digits_TestData + { + get + { + yield return new object[] { float.NaN, float.NaN, 3, MidpointRounding.ToEven }; + yield return new object[] { float.PositiveInfinity, float.PositiveInfinity, 3, MidpointRounding.ToEven }; + yield return new object[] { float.NegativeInfinity, float.NegativeInfinity, 3, MidpointRounding.ToEven }; + yield return new object[] { 0, 0, 3, MidpointRounding.ToEven }; + yield return new object[] { 3.42156f, 3.422f, 3, MidpointRounding.ToEven }; + yield return new object[] { -3.42156f, -3.422f, 3, MidpointRounding.ToEven }; + + yield return new object[] { float.NaN, float.NaN, 3, MidpointRounding.AwayFromZero }; + yield return new object[] { float.PositiveInfinity, float.PositiveInfinity, 3, MidpointRounding.AwayFromZero }; + yield return new object[] { float.NegativeInfinity, float.NegativeInfinity, 3, MidpointRounding.AwayFromZero }; + yield return new object[] { 0, 0, 3, MidpointRounding.AwayFromZero }; + yield return new object[] { 3.42156f, 3.422f, 3, MidpointRounding.AwayFromZero }; + yield return new object[] { -3.42156f, -3.422f, 3, MidpointRounding.AwayFromZero }; + } + } + + [Fact] + public static void Round() + { + Assert.Equal(0.0f, MathF.Round(0.0f)); + Assert.Equal(1.0f, MathF.Round(1.4f)); + Assert.Equal(2.0f, MathF.Round(1.5f)); + Assert.Equal(2e7f, MathF.Round(2e7f)); + Assert.Equal(0.0f, MathF.Round(-0.0f)); + Assert.Equal(-1.0f, MathF.Round(-1.4f)); + Assert.Equal(-2.0f, MathF.Round(-1.5f)); + Assert.Equal(-2e7f, MathF.Round(-2e7f)); + } + + [Theory] + [InlineData(MidpointRounding.ToEven)] + [InlineData(MidpointRounding.AwayFromZero)] + public static void Round_Digits_ByMidpointRounding(MidpointRounding mode) + { + Assert.Equal(float.PositiveInfinity, MathF.Round(float.PositiveInfinity, 3, mode)); + Assert.Equal(float.NegativeInfinity, MathF.Round(float.NegativeInfinity, 3, mode)); + } + + [Theory] + [MemberData(nameof(Round_Digits_TestData))] + public static void Round_Digits(float x, float expected, int digits, MidpointRounding mode) + { + AssertExtensions.Equal(expected, MathF.Round(x, digits, mode), CrossPlatformMachineEpsilon * 10); + } + + [Fact] + public static void Sign() + { + Assert.Equal(0, MathF.Sign(0.0f)); + Assert.Equal(0, MathF.Sign(-0.0f)); + Assert.Equal(-1, MathF.Sign(-3.14f)); + Assert.Equal(1, MathF.Sign(3.14f)); + Assert.Equal(-1, MathF.Sign(float.NegativeInfinity)); + Assert.Equal(1, MathF.Sign(float.PositiveInfinity)); + Assert.Throws(() => MathF.Sign(float.NaN)); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, -0.0f, CrossPlatformMachineEpsilon)] // value: -(pi) + [InlineData(-2.71828183f, -0.410781291f, CrossPlatformMachineEpsilon)] // value: -(e) + [InlineData(-2.30258509f, -0.743980337f, CrossPlatformMachineEpsilon)] // value: -(ln(10)) + [InlineData(-1.57079633f, -1.0f, CrossPlatformMachineEpsilon * 10)] // value: -(pi / 2) + [InlineData(-1.44269504f, -0.991806244f, CrossPlatformMachineEpsilon)] // value: -(log2(e)) + [InlineData(-1.41421356f, -0.987765946f, CrossPlatformMachineEpsilon)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, -0.903719457f, CrossPlatformMachineEpsilon)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, -0.841470985f, CrossPlatformMachineEpsilon)] + [InlineData(-0.785398163f, -0.707106781f, CrossPlatformMachineEpsilon)] // value: -(pi / 4), expected: -(1 / sqrt(2)) + [InlineData(-0.707106781f, -0.649636939f, CrossPlatformMachineEpsilon)] // value: -(1 / sqrt(2)) + [InlineData(-0.693147181f, -0.638961276f, CrossPlatformMachineEpsilon)] // value: -(ln(2)) + [InlineData(-0.636619772f, -0.594480769f, CrossPlatformMachineEpsilon)] // value: -(2 / pi) + [InlineData(-0.434294482f, -0.420770483f, CrossPlatformMachineEpsilon)] // value: -(log10(e)) + [InlineData(-0.318309886f, -0.312961796f, CrossPlatformMachineEpsilon)] // value: -(1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 0.312961796f, CrossPlatformMachineEpsilon)] // value: (1 / pi) + [InlineData(0.434294482f, 0.420770483f, CrossPlatformMachineEpsilon)] // value: (log10(e)) + [InlineData(0.636619772f, 0.594480769f, CrossPlatformMachineEpsilon)] // value: (2 / pi) + [InlineData(0.693147181f, 0.638961276f, CrossPlatformMachineEpsilon)] // value: (ln(2)) + [InlineData(0.707106781f, 0.649636939f, CrossPlatformMachineEpsilon)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 0.707106781f, CrossPlatformMachineEpsilon)] // value: (pi / 4), expected: (1 / sqrt(2)) + [InlineData(1.0f, 0.841470985f, CrossPlatformMachineEpsilon)] + [InlineData(1.12837917f, 0.903719457f, CrossPlatformMachineEpsilon)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 0.987765946f, CrossPlatformMachineEpsilon)] // value: (sqrt(2)) + [InlineData(1.44269504f, 0.991806244f, CrossPlatformMachineEpsilon)] // value: (log2(e)) + [InlineData(1.57079633f, 1.0f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 2) + [InlineData(2.30258509f, 0.743980337f, CrossPlatformMachineEpsilon)] // value: (ln(10)) + [InlineData(2.71828183f, 0.410781291f, CrossPlatformMachineEpsilon)] // value: (e) + [InlineData(3.14159265f, 0.0f, CrossPlatformMachineEpsilon)] // value: (pi) + [InlineData(float.PositiveInfinity, float.NaN, 0.0f)] + public static void Sin(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Sin(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NegativeInfinity, 0.0f)] + [InlineData(-3.14159265f, -11.5487394f, CrossPlatformMachineEpsilon * 100)] // value: -(pi) + [InlineData(-2.71828183f, -7.54413710f, CrossPlatformMachineEpsilon * 10)] // value: -(e) + [InlineData(-2.30258509f, -4.95f, CrossPlatformMachineEpsilon * 10)] // value: -(ln(10)) + [InlineData(-1.57079633f, -2.30129890f, CrossPlatformMachineEpsilon * 10)] // value: -(pi / 2) + [InlineData(-1.44269504f, -1.99789801f, CrossPlatformMachineEpsilon * 10)] // value: -(log2(e)) + [InlineData(-1.41421356f, -1.93506682f, CrossPlatformMachineEpsilon * 10)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, -1.38354288f, CrossPlatformMachineEpsilon * 10)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, -1.17520119f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.785398163f, -0.868670961f, CrossPlatformMachineEpsilon)] // value: -(pi / 4) + [InlineData(-0.707106781f, -0.767523145f, CrossPlatformMachineEpsilon)] // value: -(1 / sqrt(2)) + [InlineData(-0.693147181f, -0.75f, CrossPlatformMachineEpsilon)] // value: -(ln(2)) + [InlineData(-0.636619772f, -0.680501678f, CrossPlatformMachineEpsilon)] // value: -(2 / pi) + [InlineData(-0.434294482f, -0.448075979f, CrossPlatformMachineEpsilon)] // value: -(log10(e)) + [InlineData(-0.318309886f, -0.323712439f, CrossPlatformMachineEpsilon)] // value: -(1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 0.323712439f, CrossPlatformMachineEpsilon)] // value: (1 / pi) + [InlineData(0.434294482f, 0.448075979f, CrossPlatformMachineEpsilon)] // value: (log10(e)) + [InlineData(0.636619772f, 0.680501678f, CrossPlatformMachineEpsilon)] // value: (2 / pi) + [InlineData(0.693147181f, 0.75f, CrossPlatformMachineEpsilon)] // value: (ln(2)) + [InlineData(0.707106781f, 0.767523145f, CrossPlatformMachineEpsilon)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 0.868670961f, CrossPlatformMachineEpsilon)] // value: (pi / 4) + [InlineData(1.0f, 1.17520119f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.12837917f, 1.38354288f, CrossPlatformMachineEpsilon * 10)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 1.93506682f, CrossPlatformMachineEpsilon * 10)] // value: (sqrt(2)) + [InlineData(1.44269504f, 1.99789801f, CrossPlatformMachineEpsilon * 10)] // value: (log2(e)) + [InlineData(1.57079633f, 2.30129890f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 2) + [InlineData(2.30258509f, 4.95f, CrossPlatformMachineEpsilon * 10)] // value: (ln(10)) + [InlineData(2.71828183f, 7.54413710f, CrossPlatformMachineEpsilon * 10)] // value: (e) + [InlineData(3.14159265f, 11.5487394f, CrossPlatformMachineEpsilon * 100)] // value: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Sinh(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Sinh(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, float.NaN, 0.0f)] // value: (pi) + [InlineData(-2.71828183f, float.NaN, 0.0f)] // value: (e) + [InlineData(-2.30258509f, float.NaN, 0.0f)] // value: (ln(10)) + [InlineData(-1.57079633f, float.NaN, 0.0f)] // value: (pi / 2) + [InlineData(-1.44269504f, float.NaN, 0.0f)] // value: (log2(e)) + [InlineData(-1.41421356f, float.NaN, 0.0f)] // value: (sqrt(2)) + [InlineData(-1.12837917f, float.NaN, 0.0f)] // value: (2 / sqrt(pi)) + [InlineData(-1.0f, float.NaN, 0.0f)] + [InlineData(-0.785398163f, float.NaN, 0.0f)] // value: (pi / 4) + [InlineData(-0.707106781f, float.NaN, 0.0f)] // value: (1 / sqrt(2)) + [InlineData(-0.693147181f, float.NaN, 0.0f)] // value: (ln(2)) + [InlineData(-0.636619772f, float.NaN, 0.0f)] // value: (2 / pi) + [InlineData(-0.434294482f, float.NaN, 0.0f)] // value: (log10(e)) + [InlineData(-0.318309886f, float.NaN, 0.0f)] // value: (1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 0.564189584f, CrossPlatformMachineEpsilon)] // value: (1 / pi) + [InlineData(0.434294482f, 0.659010229f, CrossPlatformMachineEpsilon)] // value: (log10(e)) + [InlineData(0.636619772f, 0.797884561f, CrossPlatformMachineEpsilon)] // value: (2 / pi) + [InlineData(0.693147181f, 0.832554611f, CrossPlatformMachineEpsilon)] // value: (ln(2)) + [InlineData(0.707106781f, 0.840896415f, CrossPlatformMachineEpsilon)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 0.886226925f, CrossPlatformMachineEpsilon)] // value: (pi / 4) + [InlineData(1.0f, 1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.12837917f, 1.06225193f, CrossPlatformMachineEpsilon * 10)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 1.18920712f, CrossPlatformMachineEpsilon * 10)] // value: (sqrt(2)) + [InlineData(1.44269504f, 1.20112241f, CrossPlatformMachineEpsilon * 10)] // value: (log2(e)) + [InlineData(1.57079633f, 1.25331414f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 2) + [InlineData(2.30258509f, 1.51742713f, CrossPlatformMachineEpsilon * 10)] // value: (ln(10)) + [InlineData(2.71828183f, 1.64872127f, CrossPlatformMachineEpsilon * 10)] // value: (e) + [InlineData(3.14159265f, 1.77245385F, CrossPlatformMachineEpsilon * 10)] // value: (pi) + [InlineData(float.PositiveInfinity, float.PositiveInfinity, 0.0f)] + public static void Sqrt(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Sqrt(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, float.NaN, 0.0f)] + [InlineData(-3.14159265f, -0.0f, CrossPlatformMachineEpsilon)] // value: -(pi) + [InlineData(-2.71828183f, 0.450549534f, CrossPlatformMachineEpsilon)] // value: -(e) + [InlineData(-2.30258509f, 1.11340715f, CrossPlatformMachineEpsilon * 10)] // value: -(ln(10)) + [InlineData(-1.57079633f, 22877332.0f, 10.0f)] // value: -(pi / 2) + [InlineData(-1.44269504f, -7.76357567f, CrossPlatformMachineEpsilon * 10)] // value: -(log2(e)) + [InlineData(-1.41421356f, -6.33411917f, CrossPlatformMachineEpsilon * 10)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, -2.11087684f, CrossPlatformMachineEpsilon * 10)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, -1.55740772f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-0.785398163f, -1.0f, CrossPlatformMachineEpsilon * 10)] // value: -(pi / 4) + [InlineData(-0.707106781f, -0.854510432f, CrossPlatformMachineEpsilon)] // value: -(1 / sqrt(2)) + [InlineData(-0.693147181f, -0.830640878f, CrossPlatformMachineEpsilon)] // value: -(ln(2)) + [InlineData(-0.636619772f, -0.739302950f, CrossPlatformMachineEpsilon)] // value: -(2 / pi) + [InlineData(-0.434294482f, -0.463829067f, CrossPlatformMachineEpsilon)] // value: -(log10(e)) + [InlineData(-0.318309886f, -0.329514733f, CrossPlatformMachineEpsilon)] // value: -(1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 0.329514733f, CrossPlatformMachineEpsilon)] // value: (1 / pi) + [InlineData(0.434294482f, 0.463829067f, CrossPlatformMachineEpsilon)] // value: (log10(e)) + [InlineData(0.636619772f, 0.739302950f, CrossPlatformMachineEpsilon)] // value: (2 / pi) + [InlineData(0.693147181f, 0.830640878f, CrossPlatformMachineEpsilon)] // value: (ln(2)) + [InlineData(0.707106781f, 0.854510432f, CrossPlatformMachineEpsilon)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 1.0f, CrossPlatformMachineEpsilon * 10)] // value: (pi / 4) + [InlineData(1.0f, 1.55740772f, CrossPlatformMachineEpsilon * 10)] + [InlineData(1.12837917f, 2.11087684f, CrossPlatformMachineEpsilon * 10)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 6.33411917f, CrossPlatformMachineEpsilon * 10)] // value: (sqrt(2)) + [InlineData(1.44269504f, 7.76357567f, CrossPlatformMachineEpsilon * 10)] // value: (log2(e)) + [InlineData(1.57079633f, -22877332.0f, 10.0f)] // value: (pi / 2) + [InlineData(2.30258509f, -1.11340715f, CrossPlatformMachineEpsilon * 10)] // value: (ln(10)) + [InlineData(2.71828183f, -0.450549534f, CrossPlatformMachineEpsilon)] // value: (e) + [InlineData(3.14159265f, 0.0f, CrossPlatformMachineEpsilon)] // value: (pi) + [InlineData(float.PositiveInfinity, float.NaN, 0.0f)] + public static void Tan(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Tan(value), allowedVariance); + } + + [Theory] + [InlineData(float.NegativeInfinity, -1.0f, CrossPlatformMachineEpsilon * 10)] + [InlineData(-3.14159265f, -0.996272076f, CrossPlatformMachineEpsilon)] // value: -(pi) + [InlineData(-2.71828183f, -0.991328916f, CrossPlatformMachineEpsilon)] // value: -(e) + [InlineData(-2.30258509f, -0.980198020f, CrossPlatformMachineEpsilon)] // value: -(ln(10)) + [InlineData(-1.57079633f, -0.917152336f, CrossPlatformMachineEpsilon)] // value: -(pi / 2) + [InlineData(-1.44269504f, -0.894238946f, CrossPlatformMachineEpsilon)] // value: -(log2(e)) + [InlineData(-1.41421356f, -0.888385562f, CrossPlatformMachineEpsilon)] // value: -(sqrt(2)) + [InlineData(-1.12837917f, -0.810463806f, CrossPlatformMachineEpsilon)] // value: -(2 / sqrt(pi)) + [InlineData(-1.0f, -0.761594156f, CrossPlatformMachineEpsilon)] + [InlineData(-0.785398163f, -0.655794203f, CrossPlatformMachineEpsilon)] // value: -(pi / 4) + [InlineData(-0.707106781f, -0.608859365f, CrossPlatformMachineEpsilon)] // value: -(1 / sqrt(2)) + [InlineData(-0.693147181f, -0.6f, CrossPlatformMachineEpsilon)] // value: -(ln(2)) + [InlineData(-0.636619772f, -0.562593600f, CrossPlatformMachineEpsilon)] // value: -(2 / pi) + [InlineData(-0.434294482f, -0.408904012f, CrossPlatformMachineEpsilon)] // value: -(log10(e)) + [InlineData(-0.318309886f, -0.307977913f, CrossPlatformMachineEpsilon)] // value: -(1 / pi) + [InlineData(-0.0f, -0.0f, 0.0f)] + [InlineData(float.NaN, float.NaN, 0.0f)] + [InlineData(0.0f, 0.0f, 0.0f)] + [InlineData(0.318309886f, 0.307977913f, CrossPlatformMachineEpsilon)] // value: (1 / pi) + [InlineData(0.434294482f, 0.408904012f, CrossPlatformMachineEpsilon)] // value: (log10(e)) + [InlineData(0.636619772f, 0.562593600f, CrossPlatformMachineEpsilon)] // value: (2 / pi) + [InlineData(0.693147181f, 0.6f, CrossPlatformMachineEpsilon)] // value: (ln(2)) + [InlineData(0.707106781f, 0.608859365f, CrossPlatformMachineEpsilon)] // value: (1 / sqrt(2)) + [InlineData(0.785398163f, 0.655794203f, CrossPlatformMachineEpsilon)] // value: (pi / 4) + [InlineData(1.0f, 0.761594156f, CrossPlatformMachineEpsilon)] + [InlineData(1.12837917f, 0.810463806f, CrossPlatformMachineEpsilon)] // value: (2 / sqrt(pi)) + [InlineData(1.41421356f, 0.888385562f, CrossPlatformMachineEpsilon)] // value: (sqrt(2)) + [InlineData(1.44269504f, 0.894238946f, CrossPlatformMachineEpsilon)] // value: (log2(e)) + [InlineData(1.57079633f, 0.917152336f, CrossPlatformMachineEpsilon)] // value: (pi / 2) + [InlineData(2.30258509f, 0.980198020f, CrossPlatformMachineEpsilon)] // value: (ln(10)) + [InlineData(2.71828183f, 0.991328916f, CrossPlatformMachineEpsilon)] // value: (e) + [InlineData(3.14159265f, 0.996272076f, CrossPlatformMachineEpsilon)] // value: (pi) + [InlineData(float.PositiveInfinity, 1.0f, CrossPlatformMachineEpsilon * 10)] + public static void Tanh(float value, float expectedResult, float allowedVariance) + { + AssertExtensions.Equal(expectedResult, MathF.Tanh(value), allowedVariance); + } + + [Fact] + public static void Truncate() + { + Assert.Equal(0.0f, MathF.Truncate(0.12345f)); + Assert.Equal(3.0f, MathF.Truncate(3.14159f)); + Assert.Equal(-3.0f, MathF.Truncate(-3.14159f)); + } + + public static IEnumerable Round_ToEven_TestData() + { + yield return new object[] { 1f, 1f }; + yield return new object[] { 0.5f, 0f }; + yield return new object[] { 1.5f, 2f }; + yield return new object[] { 2.5f, 2f }; + yield return new object[] { 3.5f, 4f }; + yield return new object[] { 0.49999997f, 0f }; + yield return new object[] { 1.5f, 2f }; + yield return new object[] { 2.5f, 2f }; + yield return new object[] { 3.5f, 4f }; + yield return new object[] { 4.5f, 4f }; + yield return new object[] { 3.1415927f, 3f }; + yield return new object[] { 2.7182817f, 3f }; + yield return new object[] { 1385.4557f, 1385f }; + yield return new object[] { 3423.4343f, 3423f }; + yield return new object[] { 535345.5f, 535346f }; + yield return new object[] { 535345.5f, 535346f }; + yield return new object[] { 535345.5f, 535346f }; + yield return new object[] { 535345.4f, 535345f }; + yield return new object[] { 535345.6f, 535346f }; + yield return new object[] { -2.7182817f, -3f }; + yield return new object[] { 10f, 10f }; + yield return new object[] { -10f, -10f }; + yield return new object[] { -0f, -0f }; + yield return new object[] { 0f, 0f }; + yield return new object[] { float.NaN, float.NaN }; + yield return new object[] { float.PositiveInfinity, float.PositiveInfinity }; + yield return new object[] { float.NegativeInfinity, float.NegativeInfinity }; + yield return new object[] { 3.4028235E+38f, 3.4028235E+38f }; + yield return new object[] { -3.4028235E+38f, -3.4028235E+38f }; + } + + [Theory] + [MemberData(nameof(Round_ToEven_TestData))] + public static void Round_ToEven_0(float value, float expected) + { + // Math.Round has special fast paths when MidpointRounding is a const + // Don't replace it with a variable + Assert.Equal(expected, MathF.Round(value, MidpointRounding.ToEven)); + Assert.Equal(expected, MathF.Round(value, 0, MidpointRounding.ToEven)); + } + + public static IEnumerable Round_AwayFromZero_TestData() + { + yield return new object[] { 1f, 1f }; + yield return new object[] { 0.5f, 1f }; + yield return new object[] { 1.5f, 2f }; + yield return new object[] { 2.5f, 3f }; + yield return new object[] { 3.5f, 4f }; + yield return new object[] { 0.49999997f, 0f }; + yield return new object[] { 1.5f, 2f }; + yield return new object[] { 2.5f, 3f }; + yield return new object[] { 3.5f, 4f }; + yield return new object[] { 4.5f, 5f }; + yield return new object[] { 3.1415927f, 3f }; + yield return new object[] { 2.7182817f, 3f }; + yield return new object[] { 1385.4557f, 1385f }; + yield return new object[] { 3423.4343f, 3423f }; + yield return new object[] { 535345.5f, 535346f }; + yield return new object[] { 535345.5f, 535346f }; + yield return new object[] { 535345.5f, 535346f }; + yield return new object[] { 535345.4f, 535345f }; + yield return new object[] { 535345.6f, 535346f }; + yield return new object[] { -2.7182817f, -3f }; + yield return new object[] { 10f, 10f }; + yield return new object[] { -10f, -10f }; + yield return new object[] { -0f, -0f }; + yield return new object[] { 0f, 0f }; + yield return new object[] { float.NaN, float.NaN }; + yield return new object[] { float.PositiveInfinity, float.PositiveInfinity }; + yield return new object[] { float.NegativeInfinity, float.NegativeInfinity }; + yield return new object[] { 3.4028235E+38f, 3.4028235E+38f }; + yield return new object[] { -3.4028235E+38f, -3.4028235E+38f }; + } + + [Theory] + [MemberData(nameof(Round_AwayFromZero_TestData))] + public static void Round_AwayFromZero_0(float value, float expected) + { + // Math.Round has special fast paths when MidpointRounding is a const + // Don't replace it with a variable + Assert.Equal(expected, MathF.Round(value, MidpointRounding.AwayFromZero)); + Assert.Equal(expected, MathF.Round(value, 0, MidpointRounding.AwayFromZero)); + } + } +} diff --git a/src/libraries/Microsoft.Bcl.Numerics/tests/Microsoft.Bcl.Numerics.Tests.csproj b/src/libraries/Microsoft.Bcl.Numerics/tests/Microsoft.Bcl.Numerics.Tests.csproj new file mode 100644 index 00000000000000..8f2abec46041a7 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Numerics/tests/Microsoft.Bcl.Numerics.Tests.csproj @@ -0,0 +1,15 @@ + + + + $(NetFrameworkMinimum);$(NetCoreAppCurrent) + + + + + + + + + + + diff --git a/src/libraries/Microsoft.Bcl.TimeProvider/src/PACKAGE.md b/src/libraries/Microsoft.Bcl.TimeProvider/src/PACKAGE.md new file mode 100644 index 00000000000000..f3c9c372cf2caf --- /dev/null +++ b/src/libraries/Microsoft.Bcl.TimeProvider/src/PACKAGE.md @@ -0,0 +1,57 @@ +## About + +Microsoft.Bcl.TimeProvider provides time abstraction support for apps targeting .NET 7 and earlier, as well as those intended for the .NET Framework. For apps targeting .NET 8 and newer versions, referencing this package is unnecessary, as the types it contains are already included in the .NET 8 and higher platform versions. + +## Key Features + +* Provides a common abstraction for time-related operations. + +## How to Use + +```csharp +using System; + +// A class that uses TimeProvider to get the current time in Utc coordinates +public class UtcClock +{ + private readonly TimeProvider _timeProvider; + + // Constructor that takes a TimeProvider as a dependency + public Clock(TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + // A method that returns the current time as a string + public string GetTime() + { + return _timeProvider.GetLocalNow().ToString("HH:mm:ss"); + } +} + +// A class that inherits from TimeProvider and overrides the GetLocalNow method +public class UtcTimeProvider : TimeProvider +{ + // Override the GetLocalNow method to always return UTC time + public override DateTimeOffset GetLocalNow() + { + return TimeProvider.System.GetUtcNow(); + } +} + +``` + +## Main Types + +The main types provided by this library are: + +* `TimeProvider` +* `TimeProviderTaskExtensions` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.timeprovider) + +## Feedback & Contributing + +Microsoft.Bcl.TimeProvider is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Caching.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Caching.Abstractions/src/PACKAGE.md new file mode 100644 index 00000000000000..eb8a9beacbc44b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Caching.Abstractions/src/PACKAGE.md @@ -0,0 +1,53 @@ +## About + + + +Provides the abstractions to create and use in-memory and distributed caching in your applications. + +This library defines how in-memory and distributed caches should be implemented; it doesn’t contain any cache implementation. +With the abstractions provided in this library, various types of caches can be built and used interchangeably, whether the data is kept in memory, in files, or even across a network. + +## Key Features + + + +* Interfaces for building and using in-memory and distributed caches. + +## How to Use + + + +This package is typically used with an implementation of the caching abstractions, such as `Microsoft.Extensions.Caching.Memory` or `Microsoft.Extensions.Caching.SqlServer`. + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Caching.Abstractions.ICacheEntry` +* `Microsoft.Extensions.Caching.Abstractions.IMemoryCache` +* `Microsoft.Extensions.Caching.Abstractions.IDistributedCache` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/caching) +* API documentation + * [Microsoft.Extensions.Caching.Memory](https://learn.microsoft.com/dotnet/api/microsoft.extensions.caching.memory) + * [Microsoft.Extensions.Caching.Distributed](https://learn.microsoft.com/dotnet/api/microsoft.extensions.caching.distributed) + +## Related Packages + + + +* In-memory caching: [Microsoft.Extensions.Caching.Memory](https://www.nuget.org/packages/Microsoft.Extensions.Caching.Memory/) +* SQL Server caching: [Microsoft.Extensions.Caching.SqlServer](https://www.nuget.org/packages/Microsoft.Extensions.Caching.SqlServer/) +* Redis caching: [Microsoft.Extensions.Caching.StackExchangeRedis](https://www.nuget.org/packages/Microsoft.Extensions.Caching.StackExchangeRedis/) + +## Feedback & Contributing + + + +Microsoft.Extensions.Caching.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Caching.Memory/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Caching.Memory/src/PACKAGE.md new file mode 100644 index 00000000000000..a9e1c0e00615d6 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Caching.Memory/src/PACKAGE.md @@ -0,0 +1,89 @@ +## About + + + +Provides implementations for local and distributed in-memory cache. It stores and retrieves data in a fast and efficient way. + +## Key Features + + + +* A concrete implementation of the IMemoryCache interface, which represents a local in-memory cache that stores and retrieves data in a fast and efficient way +* A distributed cache that supports higher scale-out than local cache +* Expiration and eviction policies for its entries +* Entry prioritization for when the cache size limit is exceeded and needs to be compacted by entry eviction +* Track of cache statictics + +## How to Use + + + +Use Microsoft.Extensions.Caching.Memory over System.Runtime.Caching when working with ASP.NET Core as it provides better integration support. For example, IMemoryCache works natively with ASP.NET Core dependency injection. + +Local in-memory serialization: +```csharp +using Microsoft.Extensions.Caching.Memory; + +using MemoryCache cache = new(new MemoryCacheOptions()); + +object valueToCache = new(); +string key = "key"; + +using (ICacheEntry entry = cache.CreateEntry(key)) +{ + // Entries are committed after they are disposed therefore it does not exist yet. + Console.WriteLine($"Exists: {cache.TryGetValue(key, out _)}\n"); + + entry.Value = valueToCache; + entry.SlidingExpiration = TimeSpan.FromSeconds(2); +} + +bool exists = cache.TryGetValue(key, out object? cachedValue); +Console.WriteLine($"Exists: {exists}" ); +Console.WriteLine($"cachedValue is valueToCache? {object.ReferenceEquals(cachedValue, valueToCache)}\n"); + +Console.WriteLine("Wait for the sliding expiration..."); +Thread.Sleep(TimeSpan.FromSeconds(2)); + +Console.WriteLine("Exists: " + cache.TryGetValue(key, out _)); + +// You can also use the acceleration extensions to set and get entries +string key2 = "key2"; +object value2 = new(); + +cache.Set("key2", value2); + +object? cachedValue2 = cache.Get(key2); +Console.WriteLine($"cachedValue2 is value2? {object.ReferenceEquals(cachedValue2, value2)}"); +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Caching.Memory.MemoryCache` +* `Microsoft.Extensions.Caching.Memory.MemoryCacheOptions` +* `Microsoft.Extensions.Caching.Distributed.MemoryDistributedCache` +* `Microsoft.Extensions.Caching.Memory.MemoryDistributedCacheOptions` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/caching) +* [Cache in-memory in ASP.NET Core](https://learn.microsoft.com/aspnet/core/performance/caching/memory) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.caching.memory) + +## Related Packages + + + +[Microsoft.Extensions.Caching.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.Caching.Abstractions) + +## Feedback & Contributing + + + +Microsoft.Extensions.Caching.Memory is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Configuration.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.Abstractions/src/PACKAGE.md index 9a93ad1f76eb74..e744e1b9e73df1 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Abstractions/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.Abstractions/src/PACKAGE.md @@ -1,18 +1,21 @@ ## About + + Provides abstractions of key-value pair based configuration. Interfaces defined in this package are implemented by classes in [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/) and other configuration packages. -Commonly used types: +## Key Features + + -- [Microsoft.Extensions.Configuration.IConfiguration](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.iconfiguration) -- [Microsoft.Extensions.Configuration.IConfigurationBuilder](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.iconfigurationbuilder) -- [Microsoft.Extensions.Configuration.IConfigurationProvider](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.iconfigurationprovider) -- [Microsoft.Extensions.Configuration.IConfigurationRoot](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.iconfigurationroot) -- [Microsoft.Extensions.Configuration.IConfigurationSection](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.iconfigurationsection) +* Abstractions for string key-value pair configuration sources and sections +* Path conventions of keys establishing a heirachy of values +* Support for multiple configuration sources, aggregating and defining precdence for values +* Support for reload on change -For more information, see the documentation: [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration). +## How to Use -## Example + The example below shows a small code sample using this library and trying out the `ConfigurationKeyName` attribute available since .NET 6: @@ -39,3 +42,41 @@ var config = new ConfigurationBuilder() var options = config.Get(); Console.WriteLine(options.NamedProperty); // returns "value for named property" ``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Configuration.IConfiguration` +* `Microsoft.Extensions.Configuration.IConfigurationBuilder` +* `Microsoft.Extensions.Configuration.IConfigurationProvider` +* `Microsoft.Extensions.Configuration.IConfigurationRoot` +* `Microsoft.Extensions.Configuration.IConfigurationSection` + +## Additional Documentation + + + +* [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration) + +## Related Packages + + +* [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration) +* [Microsoft.Extensions.Configuration.Binder](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Binder) +* [Microsoft.Extensions.Configuration.CommandLine](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.CommandLine) +* [Microsoft.Extensions.Configuration.EnvironmentVariables](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.EnvironmentVariables) +* [Microsoft.Extensions.Configuration.FileExtensions](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.FileExtensions) +* [Microsoft.Extensions.Configuration.Ini](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Ini) +* [Microsoft.Extensions.Configuration.Json](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Json) +* [Microsoft.Extensions.Configuration.UserSecrets](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.UserSecrets) +* [Microsoft.Extensions.Configuration.Xml](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Xml) + +## Feedback & Contributing + + + +Microsoft.Extensions.Caching.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs index a40cf2976b31fc..1721a124dead95 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs @@ -1,9 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Diagnostics; -using System.Text.RegularExpressions; using Microsoft.CodeAnalysis; using SourceGenerators; @@ -13,26 +10,22 @@ public sealed partial class ConfigurationBindingGenerator : IIncrementalGenerato { private sealed partial class Emitter { - private readonly SourceProductionContext _context; - private readonly SourceGenerationSpec _sourceGenSpec; - - private bool _emitBlankLineBeforeNextStatement; - private bool _useFullyQualifiedNames; - private int _valueSuffixIndex; - - private static readonly Regex s_arrayBracketsRegex = new(Regex.Escape("[]")); + private readonly InterceptorInfo _interceptorInfo; + private readonly BindingHelperInfo _bindingHelperInfo; + private readonly TypeIndex _typeIndex; private readonly SourceWriter _writer = new(); - public Emitter(SourceProductionContext context, SourceGenerationSpec sourceGenSpec) + public Emitter(SourceGenerationSpec sourceGenSpec) { - _context = context; - _sourceGenSpec = sourceGenSpec; + _interceptorInfo = sourceGenSpec.InterceptorInfo; + _bindingHelperInfo = sourceGenSpec.BindingHelperInfo; + _typeIndex = new TypeIndex(sourceGenSpec.ConfigTypes); } - public void Emit() + public void Emit(SourceProductionContext context) { - if (!ShouldEmitBinders()) + if (!ShouldEmitMethods(MethodsToGen.Any)) { return; } @@ -42,214 +35,56 @@ public void Emit() #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. """); - _writer.WriteLine(); - - _useFullyQualifiedNames = true; - EmitBinder_Extensions_IConfiguration(); - EmitBinder_Extensions_OptionsBuilder(); - EmitBinder_Extensions_IServiceCollection(); - - _useFullyQualifiedNames = false; - Emit_CoreBindingHelper(); - - _context.AddSource($"{Identifier.GeneratedConfigurationBinder}.g.cs", _writer.ToSourceText()); - } - - private void EmitBindCoreCall( - TypeSpec type, - string memberAccessExpr, - string configArgExpr, - InitializationKind initKind, - Action? writeOnSuccess = null) - { - Debug.Assert(type.CanInitialize); - - if (!type.NeedsMemberBinding) - { - EmitObjectInit(memberAccessExpr, initKind); - return; - } - - string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); - if (initKind is InitializationKind.AssignmentWithNullCheck) - { - Debug.Assert(!type.IsValueType); - _writer.WriteLine($"{type.MinimalDisplayString}? {tempIdentifier} = {memberAccessExpr};"); - EmitBindCoreCall(tempIdentifier, InitializationKind.AssignmentWithNullCheck); - } - else if (initKind is InitializationKind.None && type.IsValueType) - { - EmitBindCoreCall(tempIdentifier, InitializationKind.Declaration); - _writer.WriteLine($"{memberAccessExpr} = {tempIdentifier};"); - } - else - { - EmitBindCoreCall(memberAccessExpr, initKind); - } - - void EmitBindCoreCall(string objExpression, InitializationKind initKind) - { - string methodDisplayString = GetHelperMethodDisplayString(nameof(MethodsToGen_CoreBindingHelper.BindCore)); - string bindCoreCall = $@"{methodDisplayString}({configArgExpr}, ref {objExpression}, {Identifier.binderOptions});"; - EmitObjectInit(objExpression, initKind); - _writer.WriteLine(bindCoreCall); - writeOnSuccess?.Invoke(objExpression); - } + EmitInterceptsLocationAttrDecl(); - void EmitObjectInit(string objExpression, InitializationKind initKind) - { - if (initKind is not InitializationKind.None) - { - this.EmitObjectInit(type, objExpression, initKind, configArgExpr); - } - } - } + EmitStartBlock($"namespace {ProjectName}"); + EmitUsingStatements(); - private void EmitBindLogicFromString( - ParsableFromStringSpec type, - string sectionValueExpr, - string sectionPathExpr, - Action? writeOnSuccess, - bool checkForNullSectionValue, - bool useIncrementalStringValueIdentifier) - { - StringParsableTypeKind typeKind = type.StringParsableTypeKind; - Debug.Assert(typeKind is not StringParsableTypeKind.None); - - string nonNull_StringValue_Identifier = useIncrementalStringValueIdentifier ? GetIncrementalIdentifier(Identifier.value) : Identifier.value; - string stringValueToParse_Expr = checkForNullSectionValue ? nonNull_StringValue_Identifier : sectionValueExpr; - - string parsedValueExpr; - if (typeKind is StringParsableTypeKind.AssignFromSectionValue) - { - parsedValueExpr = stringValueToParse_Expr; - } - else if (typeKind is StringParsableTypeKind.Enum) - { - parsedValueExpr = $"ParseEnum<{type.MinimalDisplayString}>({stringValueToParse_Expr}, () => {sectionPathExpr})"; - } - else - { - string helperMethodDisplayString = GetHelperMethodDisplayString(type.ParseMethodName); - parsedValueExpr = $"{helperMethodDisplayString}({stringValueToParse_Expr}, () => {sectionPathExpr})"; - } + _writer.WriteLine(); + EmitStartBlock($$""" + {{Expression.GeneratedCodeAnnotation}} + file static class {{Identifier.BindingExtensions}} + """); + EmitBindingExtensions_IConfiguration(); + EmitBindingExtensions_OptionsBuilder(); + EmitBindingExtensions_IServiceCollection(); + EmitCoreBindingHelpers(); + EmitEndBlock(); // BindingExtensions class - if (!checkForNullSectionValue) - { - InvokeWriteOnSuccess(); - } - else - { - EmitStartBlock($"if ({sectionValueExpr} is string {nonNull_StringValue_Identifier})"); - InvokeWriteOnSuccess(); - EmitEndBlock(); - } + EmitEndBlock(); // Binding namespace. - void InvokeWriteOnSuccess() => writeOnSuccess?.Invoke(parsedValueExpr); + context.AddSource($"{Identifier.BindingExtensions}.g.cs", _writer.ToSourceText()); } - private bool EmitObjectInit(TypeSpec type, string memberAccessExpr, InitializationKind initKind, string configArgExpr) + private void EmitInterceptsLocationAttrDecl() { - Debug.Assert(type.CanInitialize && initKind is not InitializationKind.None); - - string initExpr; - CollectionSpec? collectionType = type as CollectionSpec; - - string effectiveDisplayString = GetTypeDisplayString(type); - if (collectionType is not null) - { - if (collectionType is EnumerableSpec { InitializationStrategy: InitializationStrategy.Array }) - { - initExpr = $"new {s_arrayBracketsRegex.Replace(effectiveDisplayString, "[0]", 1)}"; - } - else + _writer.WriteLine(); + _writer.WriteLine($$""" + namespace System.Runtime.CompilerServices { - effectiveDisplayString = GetTypeDisplayString(collectionType.ConcreteType ?? collectionType); - initExpr = $"new {effectiveDisplayString}()"; - } - } - else if (type.InitializationStrategy is InitializationStrategy.ParameterlessConstructor) - { - initExpr = $"new {effectiveDisplayString}()"; - } - else - { - Debug.Assert(type.InitializationStrategy is InitializationStrategy.ParameterizedConstructor); - string initMethodIdentifier = GetInitalizeMethodDisplayString(((ObjectSpec)type)); - initExpr = $"{initMethodIdentifier}({configArgExpr}, {Identifier.binderOptions})"; - } + using System; + using System.CodeDom.Compiler; - if (initKind == InitializationKind.Declaration) - { - Debug.Assert(!memberAccessExpr.Contains(".")); - _writer.WriteLine($"var {memberAccessExpr} = {initExpr};"); - } - else if (initKind == InitializationKind.AssignmentWithNullCheck) - { - if (collectionType is CollectionSpec - { - InitializationStrategy: InitializationStrategy.ParameterizedConstructor or InitializationStrategy.ToEnumerableMethod - }) - { - if (collectionType.InitializationStrategy is InitializationStrategy.ParameterizedConstructor) + {{Expression.GeneratedCodeAnnotation}} + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute { - _writer.WriteLine($"{memberAccessExpr} = {memberAccessExpr} is null ? {initExpr} : new {effectiveDisplayString}({memberAccessExpr});"); - } - else - { - _writer.WriteLine($"{memberAccessExpr} = {memberAccessExpr} is null ? {initExpr} : {memberAccessExpr}.{collectionType.ToEnumerableMethodCall!};"); + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } } } - else - { - _writer.WriteLine($"{memberAccessExpr} ??= {initExpr};"); - } - } - else - { - Debug.Assert(initKind is InitializationKind.SimpleAssignment); - _writer.WriteLine($"{memberAccessExpr} = {initExpr};"); - } - - return true; + """); + _writer.WriteLine(); } - private void EmitCastToIConfigurationSection() + private void EmitUsingStatements() { - string sectionTypeDisplayString; - string exceptionTypeDisplayString; - if (_useFullyQualifiedNames) + foreach (string @namespace in _bindingHelperInfo.Namespaces) { - sectionTypeDisplayString = "global::Microsoft.Extensions.Configuration.IConfigurationSection"; - exceptionTypeDisplayString = FullyQualifiedDisplayString.InvalidOperationException; + _writer.WriteLine($"using {@namespace};"); } - else - { - sectionTypeDisplayString = Identifier.IConfigurationSection; - exceptionTypeDisplayString = nameof(InvalidOperationException); - } - - _writer.WriteLine($$""" - if ({{Identifier.configuration}} is not {{sectionTypeDisplayString}} {{Identifier.section}}) - { - throw new {{exceptionTypeDisplayString}}(); - } - """); - } - - private void EmitIConfigurationHasValueOrChildrenCheck(bool voidReturn) - { - string returnPostfix = voidReturn ? string.Empty : " null"; - string methodDisplayString = GetHelperMethodDisplayString(Identifier.HasValueOrChildren); - - _writer.WriteLine($$""" - if (!{{methodDisplayString}}({{Identifier.configuration}})) - { - return{{returnPostfix}}; - } - """); - _writer.WriteLine(); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs index 758311958c4515..d01c5dbae13f3c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs @@ -1,202 +1,212 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Threading; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Operations; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator : IIncrementalGenerator { - private sealed partial class Parser + internal sealed partial class Parser(CompilationData compilationData) { - private record struct InvocationDiagnosticInfo(DiagnosticDescriptor Descriptor, object[]? MessageArgs); + private readonly KnownTypeSymbols _typeSymbols = compilationData.TypeSymbols!; + private readonly bool _langVersionIsSupported = compilationData.LanguageVersionIsSupported; - private readonly SourceProductionContext _context; - private readonly SourceGenerationSpec _sourceGenSpec = new(); - private readonly KnownTypeSymbols _typeSymbols; - private readonly ImmutableArray _invocations; + private readonly List _invocationTypeParseInfo = new(); + private readonly Queue _typesToParse = new(); + private readonly Dictionary _createdTypeSpecs = new(SymbolEqualityComparer.Default); - private readonly Dictionary _createdSpecs = new(SymbolEqualityComparer.Default); - private readonly HashSet _unsupportedTypes = new(SymbolEqualityComparer.Default); + private readonly InterceptorInfo.Builder _interceptorInfoBuilder = new(); + private BindingHelperInfo.Builder? _helperInfoBuilder; // Init'ed with type index when registering interceptors, after creating type specs. - private readonly List _invocationTargetTypeDiags = new(); - private readonly Dictionary> _typeDiagnostics = new(SymbolEqualityComparer.Default); + public List? Diagnostics { get; private set; } - public Parser(SourceProductionContext context, KnownTypeSymbols typeSymbols, ImmutableArray invocations) + public SourceGenerationSpec? GetSourceGenerationSpec(ImmutableArray invocations, CancellationToken cancellationToken) { - _context = context; - _typeSymbols = typeSymbols; - _invocations = invocations; - } + if (!_langVersionIsSupported) + { + RecordDiagnostic(DiagnosticDescriptors.LanguageVersionNotSupported, trimmedLocation: Location.None); + return null; + } - public SourceGenerationSpec? GetSourceGenerationSpec() - { if (_typeSymbols is not { IConfiguration: { }, ConfigurationBinder: { } }) { return null; } - foreach (BinderInvocation invocation in _invocations) + ParseInvocations(invocations); + CreateTypeSpecs(cancellationToken); + RegisterInterceptors(); + + return new SourceGenerationSpec { - IInvocationOperation invocationOperation = invocation.Operation!; - if (!invocationOperation.TargetMethod.IsExtensionMethod) - { - continue; - } + InterceptorInfo = _interceptorInfoBuilder.ToIncrementalValue(), + BindingHelperInfo = _helperInfoBuilder!.ToIncrementalValue(), + ConfigTypes = _createdTypeSpecs.Values.OrderBy(s => s.TypeRef.FullyQualifiedName).ToImmutableEquatableArray(), + }; + } + + private bool IsValidRootConfigType([NotNullWhen(true)] ITypeSymbol? type) + { + if (type is null || + type.SpecialType is SpecialType.System_Object or SpecialType.System_Void || + !_typeSymbols.Compilation.IsSymbolAccessibleWithin(type, _typeSymbols.Compilation.Assembly) || + type.TypeKind is TypeKind.TypeParameter or TypeKind.Pointer or TypeKind.Error || + type.IsRefLikeType || + ContainsGenericParameters(type)) + { + return false; + } + + return true; + } + + private void ParseInvocations(ImmutableArray invocations) + { + foreach (BinderInvocation? invocation in invocations) + { + Debug.Assert(invocation is not null); + IMethodSymbol targetMethod = invocation.Operation.TargetMethod; + INamedTypeSymbol? candidateBinderType = targetMethod.ContainingType; + Debug.Assert(targetMethod.IsExtensionMethod); - INamedTypeSymbol? candidateBinderType = invocationOperation.TargetMethod.ContainingType; if (SymbolEqualityComparer.Default.Equals(candidateBinderType, _typeSymbols.ConfigurationBinder)) { - RegisterMethodInvocation_ConfigurationBinder(invocation); + ParseInvocation_ConfigurationBinder(invocation); } else if (SymbolEqualityComparer.Default.Equals(candidateBinderType, _typeSymbols.OptionsBuilderConfigurationExtensions)) { - RegisterMethodInvocation_OptionsBuilderExt(invocation); + ParseInvocation_OptionsBuilderExt(invocation); } else if (SymbolEqualityComparer.Default.Equals(candidateBinderType, _typeSymbols.OptionsConfigurationServiceCollectionExtensions)) { - RegisterMethodInvocation_ServiceCollectionExt(invocation); + ParseInvocation_ServiceCollectionExt(invocation); } } - - return _sourceGenSpec; } - private static bool IsValidRootConfigType(ITypeSymbol? type) + private void CreateTypeSpecs(CancellationToken cancellationToken) { - if (type is null || - type.SpecialType is SpecialType.System_Object or SpecialType.System_Void || - type.TypeKind is TypeKind.TypeParameter or TypeKind.Pointer or TypeKind.Error || - type.IsRefLikeType || - ContainsGenericParameters(type)) + while (_typesToParse.Count > 0) { - return false; - } + cancellationToken.ThrowIfCancellationRequested(); - return true; + TypeParseInfo typeParseInfo = _typesToParse.Dequeue(); + ITypeSymbol typeSymbol = typeParseInfo.TypeSymbol; + + if (!_createdTypeSpecs.ContainsKey(typeSymbol)) + { + _createdTypeSpecs.Add(typeSymbol, CreateTypeSpec(typeParseInfo)); + } + } } - private TypeSpec? GetTargetTypeForRootInvocation(ITypeSymbol? type, Location? invocationLocation) + private void RegisterInterceptors() { - if (!IsValidRootConfigType(type)) + TypeIndex typeIndex = new(_createdTypeSpecs.Values); + _helperInfoBuilder = new(typeIndex); + + foreach (TypeParseInfo typeParseInfo in _invocationTypeParseInfo) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.CouldNotDetermineTypeInfo, invocationLocation)); - return null; + TypeSpec typeSpec = _createdTypeSpecs[typeParseInfo.TypeSymbol]; + MethodsToGen overload = typeParseInfo.BindingOverload; + + if ((MethodsToGen.ConfigBinder_Any & overload) is not 0) + { + RegisterInterceptor_ConfigurationBinder(typeParseInfo, typeSpec); + } + else if ((MethodsToGen.OptionsBuilderExt_Any & overload) is not 0) + { + RegisterInterceptor_OptionsBuilderExt(typeParseInfo, typeSpec); + } + else + { + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & overload) is not 0); + RegisterInterceptor_ServiceCollectionExt(typeParseInfo, typeSpec); + } } + } - return GetTargetTypeForRootInvocationCore(type, invocationLocation); + private void EnqueueTargetTypeForRootInvocation(ITypeSymbol? typeSymbol, MethodsToGen overload, BinderInvocation invocation) + { + if (!IsValidRootConfigType(typeSymbol)) + { + RecordDiagnostic(DiagnosticDescriptors.CouldNotDetermineTypeInfo, invocation.Location); + } + else + { + TypeParseInfo typeParseInfo = TypeParseInfo.Create(typeSymbol, overload, invocation, containingTypeDiagInfo: null); + _typesToParse.Enqueue(typeParseInfo); + _invocationTypeParseInfo.Add(typeParseInfo); + } } - public TypeSpec? GetTargetTypeForRootInvocationCore(ITypeSymbol type, Location? invocationLocation) + private TypeRef EnqueueTransitiveType(TypeParseInfo containingTypeParseInfo, ITypeSymbol memberTypeSymbol, DiagnosticDescriptor diagDescriptor, string? memberName = null) { - TypeSpec? spec = GetOrCreateTypeSpec(type); + TypeParseInfo memberTypeParseInfo = containingTypeParseInfo.ToTransitiveTypeParseInfo(memberTypeSymbol, diagDescriptor, memberName); - foreach (InvocationDiagnosticInfo diag in _invocationTargetTypeDiags) + if (_createdTypeSpecs.TryGetValue(memberTypeSymbol, out TypeSpec? memberTypeSpec)) { - _context.ReportDiagnostic(Diagnostic.Create(diag.Descriptor, invocationLocation, diag.MessageArgs)); + RecordTypeDiagnosticIfRequired(memberTypeParseInfo, memberTypeSpec); + return memberTypeSpec.TypeRef; } - _invocationTargetTypeDiags.Clear(); - return spec; + _typesToParse.Enqueue(memberTypeParseInfo); + return new TypeRef(memberTypeSymbol); } - private TypeSpec? GetOrCreateTypeSpec(ITypeSymbol type) + private TypeSpec CreateTypeSpec(TypeParseInfo typeParseInfo) { - if (_createdSpecs.TryGetValue(type, out TypeSpec? spec)) - { - if (_typeDiagnostics.TryGetValue(type, out HashSet? typeDiags)) - { - _invocationTargetTypeDiags.AddRange(typeDiags); - } - - return spec; - } + ITypeSymbol type = typeParseInfo.TypeSymbol; + TypeSpec spec; if (IsNullable(type, out ITypeSymbol? underlyingType)) { - spec = TryGetTypeSpec(underlyingType, Diagnostics.NullableUnderlyingTypeNotSupported, out TypeSpec? underlyingTypeSpec) - ? new NullableSpec(type, underlyingTypeSpec) - : null; + TypeRef underlyingTypeRef = EnqueueTransitiveType( + typeParseInfo, + underlyingType, + DiagnosticDescriptors.NullableUnderlyingTypeNotSupported); + + spec = new NullableSpec(type, underlyingTypeRef); } else if (IsParsableFromString(type, out StringParsableTypeKind specialTypeKind)) { ParsableFromStringSpec stringParsableSpec = new(type) { StringParsableTypeKind = specialTypeKind }; - - if (stringParsableSpec.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue) - { - _sourceGenSpec.PrimitivesForHelperGen.Add(stringParsableSpec); - } - spec = stringParsableSpec; } - else if (IsSupportedArrayType(type)) + else if (type.TypeKind is TypeKind.Array) { - spec = CreateArraySpec((type as IArrayTypeSymbol)); + spec = CreateArraySpec(typeParseInfo); + Debug.Assert(spec is ArraySpec or UnsupportedTypeSpec); } else if (IsCollection(type)) { - spec = CreateCollectionSpec((INamedTypeSymbol)type); + spec = CreateCollectionSpec(typeParseInfo); } else if (SymbolEqualityComparer.Default.Equals(type, _typeSymbols.IConfigurationSection)) { spec = new ConfigurationSectionSpec(type); } - else if (type is INamedTypeSymbol namedType) + else if (type is INamedTypeSymbol) { - // List is used in generated code as a temp holder for formatting - // an error for config properties that don't map to object properties. - _sourceGenSpec.TypeNamespaces.Add("System.Collections.Generic"); - - spec = CreateObjectSpec(namedType); + spec = CreateObjectSpec(typeParseInfo); } else { - RegisterUnsupportedType(type, Diagnostics.TypeNotSupported); + spec = CreateUnsupportedTypeSpec(typeParseInfo, NotSupportedReason.UnknownType); } - foreach (InvocationDiagnosticInfo diag in _invocationTargetTypeDiags) - { - RegisterTypeDiagnostic(type, diag); - } - - if (spec is null) - { - return null; - } - - string @namespace = spec.Namespace; - if (@namespace is not null and not "") - { - _sourceGenSpec.TypeNamespaces.Add(@namespace); - } - - return _createdSpecs[type] = spec; - } + RecordTypeDiagnosticIfRequired(typeParseInfo, spec); - private void RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper method, TypeSpec type) - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(method, out HashSet? types)) - { - _sourceGenSpec.TypesForGen_CoreBindingHelper_Methods[method] = types = new HashSet(); - } - - types.Add(type); - _sourceGenSpec.MethodsToGen_CoreBindingHelper |= method; - } - - private void RegisterTypeForBindCoreUntypedGen(TypeSpec typeSpec) - { - if (typeSpec.NeedsMemberBinding) - { - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, typeSpec); - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCoreUntyped, typeSpec); - } + return spec; } private static bool IsNullable(ITypeSymbol type, [NotNullWhen(true)] out ITypeSymbol? underlyingType) @@ -317,265 +327,197 @@ private bool IsParsableFromString(ITypeSymbol type, out StringParsableTypeKind t } } - private bool TryGetTypeSpec(ITypeSymbol type, DiagnosticDescriptor descriptor, out TypeSpec? spec) + private TypeSpec CreateArraySpec(TypeParseInfo typeParseInfo) { - spec = GetOrCreateTypeSpec(type); + IArrayTypeSymbol typeSymbol = (IArrayTypeSymbol)typeParseInfo.TypeSymbol; - if (spec is null) + if (typeSymbol.Rank > 1) { - RegisterUnsupportedType(type, descriptor); - return false; + return CreateUnsupportedTypeSpec(typeParseInfo, NotSupportedReason.MultiDimArraysNotSupported); } - return true; - } - - private EnumerableSpec? CreateArraySpec(IArrayTypeSymbol arrayType) - { - if (!TryGetTypeSpec(arrayType.ElementType, Diagnostics.ElementTypeNotSupported, out TypeSpec elementSpec)) - { - return null; - } - - // We want a BindCore method for List as a temp holder for the array values. We know the element type is supported. - EnumerableSpec listSpec = (GetOrCreateTypeSpec(_typeSymbols.List.Construct(arrayType.ElementType)) as EnumerableSpec)!; - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, listSpec); + TypeRef elementTypeRef = EnqueueTransitiveType( + typeParseInfo, + typeSymbol.ElementType, + DiagnosticDescriptors.ElementTypeNotSupported); - EnumerableSpec spec = new EnumerableSpec(arrayType) + return new ArraySpec(typeSymbol) { - ElementType = elementSpec, - ConcreteType = listSpec, - InitializationStrategy = InitializationStrategy.Array, - PopulationStrategy = CollectionPopulationStrategy.Cast_Then_Add, // Using the concrete list type as a temp holder. - ToEnumerableMethodCall = null, + ElementTypeRef = elementTypeRef, }; - - Debug.Assert(spec.CanInitialize); - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, spec); - - return spec; } - private bool IsSupportedArrayType(ITypeSymbol type) + private TypeSpec CreateCollectionSpec(TypeParseInfo typeParseInfo) { - if (type is not IArrayTypeSymbol arrayType) - { - return false; - } + INamedTypeSymbol type = (INamedTypeSymbol)typeParseInfo.TypeSymbol; - if (arrayType.Rank > 1) + TypeSpec spec; + if (IsCandidateDictionary(type, out ITypeSymbol? keyType, out ITypeSymbol? elementType)) { - RegisterUnsupportedType(arrayType, Diagnostics.MultiDimArraysNotSupported); - return false; - } - - return true; - } - - private CollectionSpec? CreateCollectionSpec(INamedTypeSymbol type) - { - CollectionSpec? spec; - if (IsCandidateDictionary(type, out ITypeSymbol keyType, out ITypeSymbol elementType)) - { - spec = CreateDictionarySpec(type, keyType, elementType); - Debug.Assert(spec is null or DictionarySpec { KeyType: null or ParsableFromStringSpec }); + spec = CreateDictionarySpec(typeParseInfo, keyType, elementType); + Debug.Assert(spec is DictionarySpec or UnsupportedTypeSpec); } else { - spec = CreateEnumerableSpec(type); - } - - if (spec is not null) - { - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, spec); - spec.InitExceptionMessage ??= spec.ElementType.InitExceptionMessage; + spec = CreateEnumerableSpec(typeParseInfo); + Debug.Assert(spec is EnumerableSpec or UnsupportedTypeSpec); } return spec; } - private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, ITypeSymbol keyType, ITypeSymbol elementType) + private TypeSpec CreateDictionarySpec(TypeParseInfo typeParseInfo, ITypeSymbol keyTypeSymbol, ITypeSymbol elementTypeSymbol) { - if (!TryGetTypeSpec(keyType, Diagnostics.DictionaryKeyNotSupported, out TypeSpec keySpec) || - !TryGetTypeSpec(elementType, Diagnostics.ElementTypeNotSupported, out TypeSpec elementSpec)) - { - return null; - } + INamedTypeSymbol type = (INamedTypeSymbol)typeParseInfo.TypeSymbol; - if (keySpec.SpecKind != TypeSpecKind.ParsableFromString) - { - RegisterUnsupportedType(type, Diagnostics.DictionaryKeyNotSupported); - return null; - } - - InitializationStrategy constructionStrategy; - CollectionPopulationStrategy populationStrategy; - INamedTypeSymbol? concreteType = null; - INamedTypeSymbol? populationCastType = null; - string? toEnumerableMethodCall = null; + CollectionInstantiationStrategy instantiationStrategy; + CollectionInstantiationConcreteType instantiationConcreteType; + CollectionPopulationCastType populationCastType; if (HasPublicParameterLessCtor(type)) { - constructionStrategy = InitializationStrategy.ParameterlessConstructor; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.Self; - if (HasAddMethod(type, keyType, elementType)) + if (HasAddMethod(type, keyTypeSymbol, elementTypeSymbol)) { - populationStrategy = CollectionPopulationStrategy.Add; + populationCastType = CollectionPopulationCastType.NotApplicable; } - else if (GetInterface(type, _typeSymbols.GenericIDictionary_Unbound) is not null) + else if (_typeSymbols.GenericIDictionary is not null && GetInterface(type, _typeSymbols.GenericIDictionary_Unbound) is not null) { - populationCastType = _typeSymbols.GenericIDictionary; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + populationCastType = CollectionPopulationCastType.IDictionary; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } } - else if (IsInterfaceMatch(type, _typeSymbols.GenericIDictionary_Unbound) || IsInterfaceMatch(type, _typeSymbols.IDictionary)) + else if (_typeSymbols.Dictionary is not null && + (IsInterfaceMatch(type, _typeSymbols.GenericIDictionary_Unbound) || IsInterfaceMatch(type, _typeSymbols.IDictionary))) { - concreteType = _typeSymbols.Dictionary; - constructionStrategy = InitializationStrategy.ParameterlessConstructor; - populationStrategy = CollectionPopulationStrategy.Add; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.Dictionary; + populationCastType = CollectionPopulationCastType.NotApplicable; } - else if (IsInterfaceMatch(type, _typeSymbols.IReadOnlyDictionary_Unbound)) + else if (_typeSymbols.Dictionary is not null && IsInterfaceMatch(type, _typeSymbols.IReadOnlyDictionary_Unbound)) { - concreteType = _typeSymbols.Dictionary; - populationCastType = _typeSymbols.GenericIDictionary; - constructionStrategy = InitializationStrategy.ToEnumerableMethod; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; - toEnumerableMethodCall = "ToDictionary(pair => pair.Key, pair => pair.Value)"; - _sourceGenSpec.TypeNamespaces.Add("System.Linq"); + instantiationStrategy = CollectionInstantiationStrategy.LinqToDictionary; + instantiationConcreteType = CollectionInstantiationConcreteType.Dictionary; + populationCastType = CollectionPopulationCastType.IDictionary; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } - DictionarySpec spec = new(type) + TypeRef keyTypeRef = EnqueueTransitiveType(typeParseInfo, keyTypeSymbol, DiagnosticDescriptors.DictionaryKeyNotSupported); + TypeRef elementTypeRef = EnqueueTransitiveType(typeParseInfo, elementTypeSymbol, DiagnosticDescriptors.ElementTypeNotSupported); + + return new DictionarySpec(type) { - KeyType = (ParsableFromStringSpec)keySpec, - ElementType = elementSpec, - InitializationStrategy = constructionStrategy, - PopulationStrategy = populationStrategy, - ToEnumerableMethodCall = toEnumerableMethodCall, + KeyTypeRef = keyTypeRef, + ElementTypeRef = elementTypeRef, + InstantiationStrategy = instantiationStrategy, + InstantiationConcreteType = instantiationConcreteType, + PopulationCastType = populationCastType, }; - - Debug.Assert(!(populationStrategy is CollectionPopulationStrategy.Cast_Then_Add && populationCastType is null)); - spec.ConcreteType = ConstructGenericCollectionSpecIfRequired(concreteType, keyType, elementType); - spec.PopulationCastType = ConstructGenericCollectionSpecIfRequired(populationCastType, keyType, elementType); - - return spec; } - private EnumerableSpec? CreateEnumerableSpec(INamedTypeSymbol type) + private TypeSpec CreateEnumerableSpec(TypeParseInfo typeParseInfo) { - if (!TryGetElementType(type, out ITypeSymbol? elementType) || - !TryGetTypeSpec(elementType, Diagnostics.ElementTypeNotSupported, out TypeSpec elementSpec)) + INamedTypeSymbol type = (INamedTypeSymbol)typeParseInfo.TypeSymbol; + + if (!TryGetElementType(type, out ITypeSymbol? elementType)) { - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } - InitializationStrategy constructionStrategy; - CollectionPopulationStrategy populationStrategy; - INamedTypeSymbol? concreteType = null; - INamedTypeSymbol? populationCastType = null; + CollectionInstantiationStrategy instantiationStrategy; + CollectionInstantiationConcreteType instantiationConcreteType; + CollectionPopulationCastType populationCastType; if (HasPublicParameterLessCtor(type)) { - constructionStrategy = InitializationStrategy.ParameterlessConstructor; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.Self; if (HasAddMethod(type, elementType)) { - populationStrategy = CollectionPopulationStrategy.Add; + populationCastType = CollectionPopulationCastType.NotApplicable; } - else if (GetInterface(type, _typeSymbols.GenericICollection_Unbound) is not null) + else if (_typeSymbols.GenericICollection is not null && GetInterface(type, _typeSymbols.GenericICollection_Unbound) is not null) { - populationCastType = _typeSymbols.GenericICollection; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + populationCastType = CollectionPopulationCastType.ICollection; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } } - else if (IsInterfaceMatch(type, _typeSymbols.GenericICollection_Unbound) || - IsInterfaceMatch(type, _typeSymbols.GenericIList_Unbound)) + else if ((IsInterfaceMatch(type, _typeSymbols.GenericICollection_Unbound) || IsInterfaceMatch(type, _typeSymbols.GenericIList_Unbound))) { - concreteType = _typeSymbols.List; - constructionStrategy = InitializationStrategy.ParameterlessConstructor; - populationStrategy = CollectionPopulationStrategy.Add; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.List; + populationCastType = CollectionPopulationCastType.NotApplicable; } else if (IsInterfaceMatch(type, _typeSymbols.GenericIEnumerable_Unbound)) { - concreteType = _typeSymbols.List; - populationCastType = _typeSymbols.GenericICollection; - constructionStrategy = InitializationStrategy.ParameterizedConstructor; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.List; + populationCastType = CollectionPopulationCastType.ICollection; } else if (IsInterfaceMatch(type, _typeSymbols.ISet_Unbound)) { - concreteType = _typeSymbols.HashSet; - constructionStrategy = InitializationStrategy.ParameterlessConstructor; - populationStrategy = CollectionPopulationStrategy.Add; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.HashSet; + populationCastType = CollectionPopulationCastType.NotApplicable; } else if (IsInterfaceMatch(type, _typeSymbols.IReadOnlySet_Unbound)) { - concreteType = _typeSymbols.HashSet; - populationCastType = _typeSymbols.ISet; - constructionStrategy = InitializationStrategy.ParameterizedConstructor; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.HashSet; + populationCastType = CollectionPopulationCastType.ISet; } else if (IsInterfaceMatch(type, _typeSymbols.IReadOnlyList_Unbound) || IsInterfaceMatch(type, _typeSymbols.IReadOnlyCollection_Unbound)) { - concreteType = _typeSymbols.List; - populationCastType = _typeSymbols.GenericICollection; - constructionStrategy = InitializationStrategy.ParameterizedConstructor; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.List; + populationCastType = CollectionPopulationCastType.ICollection; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } - Register_AsConfigWithChildren_HelperForGen_IfRequired(elementSpec); + TypeRef elementTypeRef = EnqueueTransitiveType(typeParseInfo, elementType, DiagnosticDescriptors.ElementTypeNotSupported); - EnumerableSpec spec = new(type) + return new EnumerableSpec(type) { - ElementType = elementSpec, - InitializationStrategy = constructionStrategy, - PopulationStrategy = populationStrategy, - ToEnumerableMethodCall = null, + ElementTypeRef = elementTypeRef, + InstantiationStrategy = instantiationStrategy, + InstantiationConcreteType = instantiationConcreteType, + PopulationCastType = populationCastType, }; - - Debug.Assert(!(populationStrategy is CollectionPopulationStrategy.Cast_Then_Add && populationCastType is null)); - spec.ConcreteType = ConstructGenericCollectionSpecIfRequired(concreteType, elementType); - spec.PopulationCastType = ConstructGenericCollectionSpecIfRequired(populationCastType, elementType); - - return spec; } - private ObjectSpec? CreateObjectSpec(INamedTypeSymbol type) + private ObjectSpec CreateObjectSpec(TypeParseInfo typeParseInfo) { - // Add spec to cache before traversing properties to avoid stack overflow. - ObjectSpec objectSpec = new(type); - _createdSpecs.Add(type, objectSpec); + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)typeParseInfo.TypeSymbol; + string typeName = typeSymbol.GetTypeName().Name; + + ObjectInstantiationStrategy initializationStrategy = ObjectInstantiationStrategy.None; + DiagnosticDescriptor? initDiagDescriptor = null; + string? initExceptionMessage = null; - string typeName = objectSpec.Name; IMethodSymbol? ctor = null; - DiagnosticDescriptor? diagnosticDescriptor = null; - if (!(type.IsAbstract || type.TypeKind is TypeKind.Interface)) + if (!(typeSymbol.IsAbstract || typeSymbol.TypeKind is TypeKind.Interface)) { IMethodSymbol? parameterlessCtor = null; IMethodSymbol? parameterizedCtor = null; bool hasMultipleParameterizedCtors = false; - foreach (IMethodSymbol candidate in type.InstanceConstructors) + foreach (IMethodSymbol candidate in typeSymbol.InstanceConstructors) { if (candidate.DeclaredAccessibility is not Accessibility.Public) { @@ -596,35 +538,38 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, ITypeSymbol k } } - bool hasPublicParameterlessCtor = type.IsValueType || parameterlessCtor is not null; + bool hasPublicParameterlessCtor = typeSymbol.IsValueType || parameterlessCtor is not null; if (!hasPublicParameterlessCtor && hasMultipleParameterizedCtors) { - diagnosticDescriptor = Diagnostics.MultipleParameterizedConstructors; - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.MultipleParameterizedConstructors, typeName); + initDiagDescriptor = DiagnosticDescriptors.MultipleParameterizedConstructors; + initExceptionMessage = string.Format(Emitter.ExceptionMessages.MultipleParameterizedConstructors, typeName); } - ctor = type.IsValueType + ctor = typeSymbol.IsValueType // Roslyn ctor fetching APIs include paramerterless ctors for structs, unlike System.Reflection. ? parameterizedCtor ?? parameterlessCtor : parameterlessCtor ?? parameterizedCtor; } - objectSpec.InitializationStrategy = ctor?.Parameters.Length is 0 ? InitializationStrategy.ParameterlessConstructor : InitializationStrategy.ParameterizedConstructor; - if (ctor is null) { - diagnosticDescriptor = Diagnostics.MissingPublicInstanceConstructor; - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.MissingPublicInstanceConstructor, typeName); + initDiagDescriptor = DiagnosticDescriptors.MissingPublicInstanceConstructor; + initExceptionMessage = string.Format(Emitter.ExceptionMessages.MissingPublicInstanceConstructor, typeName); + } + else + { + initializationStrategy = ctor.Parameters.Length is 0 ? ObjectInstantiationStrategy.ParameterlessConstructor : ObjectInstantiationStrategy.ParameterizedConstructor; } - if (diagnosticDescriptor is not null) + if (initDiagDescriptor is not null) { - Debug.Assert(objectSpec.InitExceptionMessage is not null); - RegisterUnsupportedType(type, diagnosticDescriptor); - return objectSpec; + Debug.Assert(initExceptionMessage is not null); + RecordTypeDiagnostic(typeParseInfo, initDiagDescriptor); } - INamedTypeSymbol current = type; + Dictionary? properties = null; + + INamedTypeSymbol? current = typeSymbol; while (current is not null) { ImmutableArray members = current.GetMembers(); @@ -633,106 +578,90 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, ITypeSymbol k if (member is IPropertySymbol { IsIndexer: false, IsImplicitlyDeclared: false } property) { string propertyName = property.Name; - TypeSpec? propertyTypeSpec = GetOrCreateTypeSpec(property.Type); + TypeRef propertyTypeRef = EnqueueTransitiveType(typeParseInfo, property.Type, DiagnosticDescriptors.PropertyNotSupported, propertyName); - if (propertyTypeSpec?.CanInitialize is not true) - { - InvocationDiagnosticInfo propertyDiagnostic = new InvocationDiagnosticInfo(Diagnostics.PropertyNotSupported, new string[] { propertyName, type.ToDisplayString() }); - RegisterTypeDiagnostic(causingType: type, propertyDiagnostic); - _invocationTargetTypeDiags.Add(propertyDiagnostic); - } + AttributeData? attributeData = property.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, _typeSymbols.ConfigurationKeyNameAttribute)); + string configKeyName = attributeData?.ConstructorArguments.FirstOrDefault().Value as string ?? propertyName; - if (propertyTypeSpec is not null) + PropertySpec spec = new(property) { - AttributeData? attributeData = property.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, _typeSymbols.ConfigurationKeyNameAttribute)); - string configKeyName = attributeData?.ConstructorArguments.FirstOrDefault().Value as string ?? propertyName; + TypeRef = propertyTypeRef, + ConfigurationKeyName = configKeyName + }; - PropertySpec spec = new(property) { Type = propertyTypeSpec, ConfigurationKeyName = configKeyName }; - objectSpec.Properties[propertyName] = spec; - Register_AsConfigWithChildren_HelperForGen_IfRequired(propertyTypeSpec); - } + (properties ??= new(StringComparer.OrdinalIgnoreCase))[propertyName] = spec; } } current = current.BaseType; } - if (objectSpec.InitializationStrategy is InitializationStrategy.ParameterizedConstructor) + List? ctorParams = null; + + if (initializationStrategy is ObjectInstantiationStrategy.ParameterizedConstructor) { - List missingParameters = new(); - List invalidParameters = new(); + Debug.Assert(ctor is not null); + List? missingParameters = null; + List? invalidParameters = null; foreach (IParameterSymbol parameter in ctor.Parameters) { string parameterName = parameter.Name; - if (!objectSpec.Properties.TryGetValue(parameterName, out PropertySpec? propertySpec)) + if (properties?.TryGetValue(parameterName, out PropertySpec? propertySpec) is not true) { - missingParameters.Add(parameterName); + (missingParameters ??= new()).Add(parameterName); } else if (parameter.RefKind is not RefKind.None) { - invalidParameters.Add(parameterName); + (invalidParameters ??= new()).Add(parameterName); } else { ParameterSpec paramSpec = new ParameterSpec(parameter) { - Type = propertySpec.Type, + TypeRef = propertySpec.TypeRef, ConfigurationKeyName = propertySpec.ConfigurationKeyName, }; propertySpec.MatchingCtorParam = paramSpec; - objectSpec.ConstructorParameters.Add(paramSpec); + (ctorParams ??= new()).Add(paramSpec); } } - if (invalidParameters.Count > 0) + if (invalidParameters?.Count > 0) { - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.CannotBindToConstructorParameter, typeName, FormatParams(invalidParameters)); + initExceptionMessage = string.Format(Emitter.ExceptionMessages.CannotBindToConstructorParameter, typeName, FormatParams(invalidParameters)); } - else if (missingParameters.Count > 0) + else if (missingParameters?.Count > 0) { - if (type.IsValueType) + if (typeSymbol.IsValueType) { - objectSpec.InitializationStrategy = InitializationStrategy.ParameterlessConstructor; + initializationStrategy = ObjectInstantiationStrategy.ParameterlessConstructor; } else { - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.ConstructorParametersDoNotMatchProperties, typeName, FormatParams(missingParameters)); + initExceptionMessage = string.Format(Emitter.ExceptionMessages.ConstructorParametersDoNotMatchProperties, typeName, FormatParams(missingParameters)); } } - if (objectSpec.CanInitialize) - { - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.Initialize, objectSpec); - } - static string FormatParams(List names) => string.Join(",", names); } - Debug.Assert((objectSpec.CanInitialize && objectSpec.InitExceptionMessage is null) || - (!objectSpec.CanInitialize && objectSpec.InitExceptionMessage is not null)); - - if (objectSpec.NeedsMemberBinding) - { - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, objectSpec); - } - - return objectSpec; + return new ObjectSpec( + typeSymbol, + initializationStrategy, + properties: properties?.Values.ToImmutableEquatableArray(), + constructorParameters: ctorParams?.ToImmutableEquatableArray(), + initExceptionMessage); } - private void Register_AsConfigWithChildren_HelperForGen_IfRequired(TypeSpec type) - { - if (type.SpecKind is TypeSpecKind.Object or - TypeSpecKind.Enumerable or - TypeSpecKind.Dictionary) - { + private static UnsupportedTypeSpec CreateUnsupportedCollectionSpec(TypeParseInfo typeParseInfo) + => CreateUnsupportedTypeSpec(typeParseInfo, NotSupportedReason.CollectionNotSupported); - _sourceGenSpec.MethodsToGen_CoreBindingHelper |= MethodsToGen_CoreBindingHelper.AsConfigWithChildren; - } - } + private static UnsupportedTypeSpec CreateUnsupportedTypeSpec(TypeParseInfo typeParseInfo, NotSupportedReason reason) => + new(typeParseInfo.TypeSymbol) { NotSupportedReason = reason }; - private bool TryGetElementType(INamedTypeSymbol type, out ITypeSymbol? elementType) + private bool TryGetElementType(INamedTypeSymbol type, [NotNullWhen(true)] out ITypeSymbol? elementType) { INamedTypeSymbol? candidate = GetInterface(type, _typeSymbols.GenericIEnumerable_Unbound); @@ -746,7 +675,7 @@ private bool TryGetElementType(INamedTypeSymbol type, out ITypeSymbol? elementTy return false; } - private bool IsCandidateDictionary(INamedTypeSymbol type, out ITypeSymbol? keyType, out ITypeSymbol? elementType) + private bool IsCandidateDictionary(INamedTypeSymbol type, [NotNullWhen(true)] out ITypeSymbol? keyType, [NotNullWhen(true)] out ITypeSymbol? elementType) { INamedTypeSymbol? candidate = GetInterface(type, _typeSymbols.GenericIDictionary_Unbound) ?? GetInterface(type, _typeSymbols.IReadOnlyDictionary_Unbound); @@ -772,8 +701,13 @@ private bool IsCandidateDictionary(INamedTypeSymbol type, out ITypeSymbol? keyTy private bool IsCollection(ITypeSymbol type) => type is INamedTypeSymbol namedType && GetInterface(namedType, _typeSymbols.IEnumerable) is not null; - private static INamedTypeSymbol? GetInterface(INamedTypeSymbol type, INamedTypeSymbol @interface) + private static INamedTypeSymbol? GetInterface(INamedTypeSymbol type, INamedTypeSymbol? @interface) { + if (@interface is null) + { + return null; + } + if (IsInterfaceMatch(type, @interface)) { return type; @@ -790,8 +724,13 @@ private bool IsCollection(ITypeSymbol type) => return type.AllInterfaces.FirstOrDefault(candidate => SymbolEqualityComparer.Default.Equals(candidate, @interface)); } - private static bool IsInterfaceMatch(INamedTypeSymbol type, INamedTypeSymbol @interface) + private static bool IsInterfaceMatch(INamedTypeSymbol type, INamedTypeSymbol? @interface) { + if (@interface is null) + { + return false; + } + if (type.IsGenericType) { INamedTypeSymbol unbound = type.ConstructUnboundGenericType(); @@ -825,8 +764,8 @@ private static bool HasPublicParameterLessCtor(INamedTypeSymbol type) => private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol element) { - INamedTypeSymbol current = type; - while (current != null) + INamedTypeSymbol? current = type; + while (current is not null) { if (current.GetMembers("Add").Any(member => member is IMethodSymbol { Parameters.Length: 1 } method && @@ -841,8 +780,8 @@ private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol element) private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol key, ITypeSymbol element) { - INamedTypeSymbol current = type; - while (current != null) + INamedTypeSymbol? current = type; + while (current is not null) { if (current.GetMembers("Add").Any(member => member is IMethodSymbol { Parameters.Length: 2 } method && @@ -858,40 +797,51 @@ private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol key, ITypeSy private static bool IsEnum(ITypeSymbol type) => type is INamedTypeSymbol { EnumUnderlyingType: INamedTypeSymbol { } }; - private CollectionSpec? ConstructGenericCollectionSpecIfRequired(INamedTypeSymbol? collectionType, params ITypeSymbol[] parameters) => - (collectionType is not null ? ConstructGenericCollectionSpec(collectionType, parameters) : null); - - private CollectionSpec? ConstructGenericCollectionSpec(INamedTypeSymbol type, params ITypeSymbol[] parameters) - { - Debug.Assert(type.IsGenericType); - INamedTypeSymbol constructedType = type.Construct(parameters); - return CreateCollectionSpec(constructedType); - } - - private void RegisterUnsupportedType(ITypeSymbol type, DiagnosticDescriptor descriptor = null) + private void RecordTypeDiagnosticIfRequired(TypeParseInfo typeParseInfo, TypeSpec typeSpec) { - InvocationDiagnosticInfo diagInfo = new(descriptor, new string[] { type.ToDisplayString() }); + ContainingTypeDiagnosticInfo? containingTypeDiagInfo = typeParseInfo.ContainingTypeDiagnosticInfo; - if (!_unsupportedTypes.Contains(type)) + if (typeSpec is UnsupportedTypeSpec unsupportedTypeSpec) + { + DiagnosticDescriptor descriptor = DiagnosticDescriptors.GetNotSupportedDescriptor(unsupportedTypeSpec.NotSupportedReason); + RecordTypeDiagnostic(typeParseInfo, descriptor); + } + else if (containingTypeDiagInfo?.Descriptor == DiagnosticDescriptors.DictionaryKeyNotSupported && + typeSpec is not ParsableFromStringSpec) { - RegisterTypeDiagnostic(type, diagInfo); - _unsupportedTypes.Add(type); + ReportContainingTypeDiagnosticIfRequired(typeParseInfo); } + } - _invocationTargetTypeDiags.Add(diagInfo); + private void RecordTypeDiagnostic(TypeParseInfo typeParseInfo, DiagnosticDescriptor descriptor) + { + RecordDiagnostic(descriptor, typeParseInfo.BinderInvocation.Location, new object?[] { typeParseInfo.TypeName }); + ReportContainingTypeDiagnosticIfRequired(typeParseInfo); } - private void RegisterTypeDiagnostic(ITypeSymbol causingType, InvocationDiagnosticInfo info) + private void ReportContainingTypeDiagnosticIfRequired(TypeParseInfo typeParseInfo) { - bool typeHadDiags = _typeDiagnostics.TryGetValue(causingType, out HashSet? typeDiags); - typeDiags ??= new HashSet(); - typeDiags.Add(info); + ContainingTypeDiagnosticInfo? containingTypeDiagInfo = typeParseInfo.ContainingTypeDiagnosticInfo; - if (!typeHadDiags) + while (containingTypeDiagInfo is not null) { - _typeDiagnostics[causingType] = typeDiags; + string containingTypeName = containingTypeDiagInfo.TypeName; + + object[] messageArgs = containingTypeDiagInfo.MemberName is string memberName + ? new[] { memberName, containingTypeName } + : new[] { containingTypeName }; + + RecordDiagnostic(containingTypeDiagInfo.Descriptor, typeParseInfo.BinderInvocation.Location, messageArgs); + + containingTypeDiagInfo = containingTypeDiagInfo.ContainingTypeInfo; } } + + private void RecordDiagnostic(DiagnosticDescriptor descriptor, Location trimmedLocation, params object?[]? messageArgs) + { + Diagnostics ??= new List(); + Diagnostics.Add(DiagnosticInfo.Create(descriptor, trimmedLocation, messageArgs)); + } } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Suppressor.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Suppressor.cs new file mode 100644 index 00000000000000..13158753c3f075 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Suppressor.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed partial class ConfigurationBindingGenerator + { + /// + /// Supresses false-positive diagnostics emitted by the linker + /// when analyzing binding invocations that we have intercepted. + /// Workaround for https://github.com/dotnet/roslyn/issues/68669. + /// + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public sealed class Suppressor : DiagnosticSuppressor + { + private const string Justification = "The target method has been intercepted by a generated static variant."; + + /// + /// Suppression descriptor for IL2026: Members attributed with RequiresUnreferencedCode may break when trimming. + /// + private static readonly SuppressionDescriptor RUCDiagnostic = new(id: "SYSLIBSUPPRESS0002", suppressedDiagnosticId: "IL2026", Justification); + + /// + /// Suppression descriptor for IL3050: Avoid calling members annotated with 'RequiresDynamicCodeAttribute' when publishing as native AOT. + /// + private static readonly SuppressionDescriptor RDCDiagnostic = new(id: "SYSLIBSUPPRESS0003", suppressedDiagnosticId: "IL3050", Justification); + + public override ImmutableArray SupportedSuppressions => ImmutableArray.Create(RUCDiagnostic, RDCDiagnostic); + + public override void ReportSuppressions(SuppressionAnalysisContext context) + { + foreach (Diagnostic diagnostic in context.ReportedDiagnostics) + { + string diagnosticId = diagnostic.Id; + + if (diagnosticId != RDCDiagnostic.SuppressedDiagnosticId && diagnosticId != RUCDiagnostic.SuppressedDiagnosticId) + { + continue; + } + + Location location = diagnostic.AdditionalLocations.Count > 0 + ? diagnostic.AdditionalLocations[0] + : diagnostic.Location; + + bool shouldSuppressDiagnostic = + location.SourceTree is SyntaxTree sourceTree && + sourceTree.GetRoot().FindNode(location.SourceSpan) is SyntaxNode syntaxNode && + BinderInvocation.IsCandidateSyntaxNode(syntaxNode) && + context.GetSemanticModel(sourceTree) + .GetOperation((InvocationExpressionSyntax)syntaxNode, context.CancellationToken) is IInvocationOperation operation && + BinderInvocation.IsBindingOperation(operation); + + if (shouldSuppressDiagnostic) + { + SuppressionDescriptor targetSuppression = diagnosticId == RUCDiagnostic.SuppressedDiagnosticId + ? RUCDiagnostic + : RDCDiagnostic; + context.ReportSuppression(Suppression.Create(targetSuppression, diagnostic)); + } + } + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs index 5d7a830c729542..ec4b234a61045c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs @@ -2,10 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. //#define LAUNCH_DEBUGGER -using System.Collections.Immutable; +using System; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { @@ -15,7 +15,9 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration [Generator] public sealed partial class ConfigurationBindingGenerator : IIncrementalGenerator { - internal const string ProjectName = "Microsoft.Extensions.Configuration.Binder.SourceGeneration"; + private static readonly string ProjectName = Emitter.s_assemblyName.Name!; + + public const string GenSpecTrackingName = nameof(SourceGenerationSpec); public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -31,49 +33,72 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ? new CompilationData((CSharpCompilation)compilation) : null); - IncrementalValuesProvider inputCalls = context.SyntaxProvider + IncrementalValueProvider<(SourceGenerationSpec?, ImmutableEquatableArray?)> genSpec = context.SyntaxProvider .CreateSyntaxProvider( - (node, _) => node is InvocationExpressionSyntax invocation, + (node, _) => BinderInvocation.IsCandidateSyntaxNode(node), BinderInvocation.Create) - .Where(operation => operation is not null); + .Where(invocation => invocation is not null) + .Collect() + .Combine(compilationData) + .Select((tuple, cancellationToken) => + { + if (tuple.Right is not CompilationData compilationData) + { + return (null, null); + } - IncrementalValueProvider<(CompilationData?, ImmutableArray)> inputData = compilationData.Combine(inputCalls.Collect()); + try + { + Parser parser = new(compilationData); + SourceGenerationSpec? spec = parser.GetSourceGenerationSpec(tuple.Left, cancellationToken); + ImmutableEquatableArray? diagnostics = parser.Diagnostics?.ToImmutableEquatableArray(); + return (spec, diagnostics); + } + catch (Exception ex) + { + throw ex; + } + }) + .WithTrackingName(GenSpecTrackingName); - context.RegisterSourceOutput(inputData, (spc, source) => Execute(source.Item1, source.Item2, spc)); + context.RegisterSourceOutput(genSpec, ReportDiagnosticsAndEmitSource); } /// - /// Generates source code to optimize binding with ConfigurationBinder. + /// Instrumentation helper for unit tests. /// - private static void Execute(CompilationData compilationData, ImmutableArray inputCalls, SourceProductionContext context) - { - if (inputCalls.IsDefaultOrEmpty) - { - return; - } + public Action? OnSourceEmitting { get; init; } - if (compilationData?.LanguageVersionIsSupported is not true) + private void ReportDiagnosticsAndEmitSource(SourceProductionContext sourceProductionContext, (SourceGenerationSpec? SourceGenerationSpec, ImmutableEquatableArray? Diagnostics) input) + { + if (input.Diagnostics is ImmutableEquatableArray diagnostics) { - context.ReportDiagnostic(Diagnostic.Create(Parser.Diagnostics.LanguageVersionNotSupported, location: null)); - return; + foreach (DiagnosticInfo diagnostic in diagnostics) + { + sourceProductionContext.ReportDiagnostic(diagnostic.CreateDiagnostic()); + } } - Parser parser = new(context, compilationData.TypeSymbols!, inputCalls); - if (parser.GetSourceGenerationSpec() is SourceGenerationSpec { } spec) + if (input.SourceGenerationSpec is SourceGenerationSpec spec) { - Emitter emitter = new(context, spec); - emitter.Emit(); + OnSourceEmitting?.Invoke(spec); + Emitter emitter = new(spec); + emitter.Emit(sourceProductionContext); } } - private sealed record CompilationData + internal sealed class CompilationData { public bool LanguageVersionIsSupported { get; } public KnownTypeSymbols? TypeSymbols { get; } public CompilationData(CSharpCompilation compilation) { - LanguageVersionIsSupported = compilation.LanguageVersion >= LanguageVersion.CSharp11; + // We don't have a CSharp21 value available yet. Polyfill the value here for forward compat, rather than use the LangugeVersion.Preview enum value. + // https://github.com/dotnet/roslyn/blob/168689931cb4e3150641ec2fb188a64ce4b3b790/src/Compilers/CSharp/Portable/LanguageVersion.cs#L218-L232 + const int LangVersion_CSharp12 = 1200; + LanguageVersionIsSupported = (int)compilation.LanguageVersion >= LangVersion_CSharp12; + if (LanguageVersionIsSupported) { TypeSymbols = new KnownTypeSymbols(compilation); diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ConfigurationBinder.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ConfigurationBinder.cs new file mode 100644 index 00000000000000..7d723139bde3e4 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ConfigurationBinder.cs @@ -0,0 +1,168 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed partial class ConfigurationBindingGenerator + { + private sealed partial class Emitter + { + private void EmitBindingExtensions_IConfiguration() + { + if (!ShouldEmitMethods(MethodsToGen.ConfigBinder_Any)) + { + return; + } + + EmitBindingExtStartRegion(Identifier.IConfiguration); + EmitGetMethods(); + EmitGetValueMethods(); + EmitBindMethods_ConfigurationBinder(); + EmitBindingExtEndRegion(); + } + + private void EmitGetMethods() + { + const string expressionForGetCore = nameof(MethodsToGen_CoreBindingHelper.GetCore); + const string documentation = "Attempts to bind the configuration instance to a new instance of type T."; + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_T)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_T, documentation); + _writer.WriteLine($"public static T? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}) => " + + $"(T?)({expressionForGetCore}({Identifier.configuration}, typeof(T), {Identifier.configureOptions}: null) ?? default(T));"); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_T_BinderOptions)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_T_BinderOptions, documentation); + _writer.WriteLine($"public static T? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}) => " + + $"(T?)({expressionForGetCore}({Identifier.configuration}, typeof(T), {Identifier.configureOptions}) ?? default(T));"); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_TypeOf)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_TypeOf, documentation); + _writer.WriteLine($"public static object? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}) => " + + $"{expressionForGetCore}({Identifier.configuration}, {Identifier.type}, {Identifier.configureOptions}: null);"); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_TypeOf_BinderOptions)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_TypeOf_BinderOptions, documentation); + _writer.WriteLine($"public static object? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}) => " + + $"{expressionForGetCore}({Identifier.configuration}, {Identifier.type}, {Identifier.configureOptions});"); + } + } + + private void EmitGetValueMethods() + { + const string expressionForGetValueCore = $"{Identifier.BindingExtensions}.{nameof(MethodsToGen_CoreBindingHelper.GetValueCore)}"; + const string documentation = "Extracts the value with the specified key and converts it to the specified type."; + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_T_key)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_T_key, documentation); + _writer.WriteLine($"public static T? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, string {Identifier.key}) => " + + $"(T?)({expressionForGetValueCore}({Identifier.configuration}, typeof(T), {Identifier.key}) ?? default(T));"); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_T_key_defaultValue)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_T_key_defaultValue, documentation); + _writer.WriteLine($"public static T? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, string {Identifier.key}, T {Identifier.defaultValue}) => " + + $"(T?)({expressionForGetValueCore}({Identifier.configuration}, typeof(T), {Identifier.key}) ?? {Identifier.defaultValue});"); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_TypeOf_key)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_TypeOf_key, documentation); + _writer.WriteLine($"public static object? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, string {Identifier.key}) => " + + $"{expressionForGetValueCore}({Identifier.configuration}, {Identifier.type}, {Identifier.key});"); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_TypeOf_key_defaultValue)) + { + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_TypeOf_key_defaultValue, documentation); + _writer.WriteLine($"public static object? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, string {Identifier.key}, object? {Identifier.defaultValue}) => " + + $"{expressionForGetValueCore}({Identifier.configuration}, {Identifier.type}, {Identifier.key}) ?? {Identifier.defaultValue};"); + } + } + + private void EmitBindMethods_ConfigurationBinder() + { + if (!ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind)) + { + return; + } + + string instanceParamExpr = $"object? {Identifier.instance}"; + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_instance)) + { + EmitMethods( + _interceptorInfo.ConfigBinder_Bind_instance, + additionalParams: instanceParamExpr, + configExpression: Identifier.configuration, + configureOptions: false); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_instance_BinderOptions)) + { + EmitMethods( + _interceptorInfo.ConfigBinder_Bind_instance_BinderOptions, + additionalParams: $"{instanceParamExpr}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}", + configExpression: Identifier.configuration, + configureOptions: true); + } + + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_key_instance)) + { + EmitMethods( + _interceptorInfo.ConfigBinder_Bind_key_instance, + additionalParams: $"string {Identifier.key}, {instanceParamExpr}", + configExpression: $"{Expression.configurationGetSection}({Identifier.key})", + configureOptions: false); + } + + void EmitMethods(ImmutableEquatableArray? interceptorInfo, string additionalParams, string configExpression, bool configureOptions) + { + Debug.Assert(interceptorInfo is not null); + + foreach ((ComplexTypeSpec type, ImmutableEquatableArray locations) in interceptorInfo) + { + EmitBlankLineIfRequired(); + _writer.WriteLine($"/// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively."); + EmitInterceptsLocationAnnotations(locations); + EmitStartBlock($"public static void {Identifier.Bind}_{type.IdentifierCompatibleSubstring}(this {Identifier.IConfiguration} {Identifier.configuration}, {additionalParams})"); + + if (_typeIndex.HasBindableMembers(type)) + { + Debug.Assert(!type.IsValueType); + string binderOptionsArg = configureOptions ? $"{Identifier.GetBinderOptions}({Identifier.configureOptions})" : $"{Identifier.binderOptions}: null"; + + EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); + EmitCheckForNullArgument_WithBlankLine(Identifier.instance, voidReturn: true); + _writer.WriteLine($$""" + var {{Identifier.typedObj}} = ({{type.DisplayString}}){{Identifier.instance}}; + {{nameof(MethodsToGen_CoreBindingHelper.BindCore)}}({{configExpression}}, ref {{Identifier.typedObj}}, defaultValueIfNotFound: false, {{binderOptionsArg}}); + """); + } + + EmitEndBlock(); + } + } + } + + private void EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen overload, string documentation) + { + EmitBlankLineIfRequired(); + _writer.WriteLine($"/// {documentation}"); + EmitInterceptsLocationAnnotations(overload); + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/CoreBindingHelpers.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/CoreBindingHelpers.cs new file mode 100644 index 00000000000000..499d4085bbd362 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/CoreBindingHelpers.cs @@ -0,0 +1,1265 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.CodeAnalysis; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed partial class ConfigurationBindingGenerator + { + private sealed partial class Emitter + { + private int _valueSuffixIndex; + private bool _emitBlankLineBeforeNextStatement; + private static readonly Regex s_arrayBracketsRegex = new(Regex.Escape("[]")); + + private bool ShouldEmitMethods(MethodsToGen_CoreBindingHelper methods) => (_bindingHelperInfo.MethodsToGen & methods) != 0; + + private void EmitCoreBindingHelpers() + { + Debug.Assert(_emitBlankLineBeforeNextStatement); + EmitBindingExtStartRegion("Core binding"); + EmitConfigurationKeyCaches(); + EmitGetCoreMethod(); + EmitGetValueCoreMethod(); + EmitBindCoreMainMethod(); + EmitBindCoreMethods(); + EmitInitializeMethods(); + EmitHelperMethods(); + EmitBindingExtEndRegion(); + } + + private void EmitConfigurationKeyCaches() + { + if (_bindingHelperInfo.TypesForGen_BindCore is not { Count: not 0 } types) + { + return; + } + + EmitBlankLineIfRequired(); + + foreach (TypeSpec type in types) + { + if (type is not ObjectSpec objectType) + { + continue; + } + + Debug.Assert(_typeIndex.HasBindableMembers(objectType)); + + HashSet? keys = null; + static string GetCacheElement(MemberSpec member) => $@"""{member.ConfigurationKeyName}"""; + + if (objectType.ConstructorParameters?.Select(m => GetCacheElement(m)) is IEnumerable paramNames) + { + keys = new(paramNames); + } + + if (objectType.Properties?.Select(m => GetCacheElement(m)) is IEnumerable propNames) + { + if (keys is null) + { + keys = new(propNames); + } + else + { + keys.UnionWith(propNames); + } + } + + // Type has bindable members. + Debug.Assert(keys is not null); + + string configKeysSource = string.Join(", ", keys); + string fieldName = TypeIndex.GetConfigKeyCacheFieldName(objectType); + _writer.WriteLine($@"private readonly static Lazy<{TypeDisplayString.HashSetOfString}> {fieldName} = new(() => new {TypeDisplayString.HashSetOfString}(StringComparer.OrdinalIgnoreCase) {{ {configKeysSource} }});"); + } + } + + private void EmitGetCoreMethod() + { + if (_bindingHelperInfo.TypesForGen_GetCore is not { Count: not 0 } targetTypes) + { + return; + } + + EmitBlankLineIfRequired(); + EmitStartBlock($"public static object? {nameof(MethodsToGen_CoreBindingHelper.GetCore)}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, Action<{Identifier.BinderOptions}>? {Identifier.configureOptions})"); + + EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); + + _writer.WriteLine($"{Identifier.BinderOptions}? {Identifier.binderOptions} = {Identifier.GetBinderOptions}({Identifier.configureOptions});"); + _writer.WriteLine(); + + EmitIConfigurationHasValueOrChildrenCheck(voidReturn: false); + + bool isFirstType = true; + foreach (TypeSpec type in targetTypes) + { + Debug.Assert(_typeIndex.CanBindTo(type.TypeRef)); + + TypeSpec effectiveType = _typeIndex.GetEffectiveTypeSpec(type); + string conditionKindExpr = GetConditionKindExpr(ref isFirstType); + + EmitStartBlock($"{conditionKindExpr} ({Identifier.type} == typeof({type.DisplayString}))"); + + switch (effectiveType) + { + case ParsableFromStringSpec stringParsableType: + { + EmitCastToIConfigurationSection(); + EmitBindingLogic( + stringParsableType, + Expression.sectionValue, + Expression.sectionPath, + writeOnSuccess: parsedValueExpr => _writer.WriteLine($"return {parsedValueExpr};"), + checkForNullSectionValue: stringParsableType.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue, + useDefaultValueIfSectionValueIsNull: false, + useIncrementalStringValueIdentifier: false); + } + break; + case ConfigurationSectionSpec: + { + EmitCastToIConfigurationSection(); + _writer.WriteLine($"return {Identifier.section};"); + } + break; + case ComplexTypeSpec complexType: + { + if (_typeIndex.CanInstantiate(complexType)) + { + EmitBindingLogic(complexType, Identifier.instance, Identifier.configuration, InitializationKind.Declaration, ValueDefaulting.CallSetter); + _writer.WriteLine($"return {Identifier.instance};"); + } + else if (type is ObjectSpec { InitExceptionMessage: string exMsg }) + { + _writer.WriteLine($@"throw new {Identifier.InvalidOperationException}(""{exMsg}"");"); + } +#if DEBUG + else + { + Debug.Fail($"Complex should not be included for GetCore gen: {complexType.DisplayString}"); + } +#endif + } + break; + } + + EmitEndBlock(); // End if-check for input type. + } + + _writer.WriteLine(); + Emit_NotSupportedException_TypeNotDetectedAsInput(); + EmitEndBlock(); + _emitBlankLineBeforeNextStatement = true; + + void EmitCastToIConfigurationSection() => + _writer.WriteLine($$""" + if ({{Identifier.configuration}} is not {{Identifier.IConfigurationSection}} {{Identifier.section}}) + { + throw new {{Identifier.InvalidOperationException}}(); + } + """); + } + + private void EmitGetValueCoreMethod() + { + if (_bindingHelperInfo.TypesForGen_GetValueCore is not { Count: not 0 } targetTypes) + { + return; + } + + EmitBlankLineIfRequired(); + EmitStartBlock($"public static object? {nameof(MethodsToGen_CoreBindingHelper.GetValueCore)}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, string {Identifier.key})"); + + EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); + _writer.WriteLine($@"{Identifier.IConfigurationSection} {Identifier.section} = {GetSectionFromConfigurationExpression(Identifier.key, addQuotes: false)};"); + _writer.WriteLine(); + + _writer.WriteLine($$""" + if ({{Expression.sectionValue}} is not string {{Identifier.value}}) + { + return null; + } + """); + + _writer.WriteLine(); + + bool isFirstType = true; + foreach (TypeSpec type in targetTypes) + { + string conditionKindExpr = GetConditionKindExpr(ref isFirstType); + EmitStartBlock($"{conditionKindExpr} ({Identifier.type} == typeof({type.DisplayString}))"); + + EmitBindingLogic( + (ParsableFromStringSpec)_typeIndex.GetEffectiveTypeSpec(type), + Identifier.value, + Expression.sectionPath, + writeOnSuccess: (parsedValueExpr) => _writer.WriteLine($"return {parsedValueExpr};"), + checkForNullSectionValue: false, + useDefaultValueIfSectionValueIsNull: false, + useIncrementalStringValueIdentifier: false); + + EmitEndBlock(); + } + + _writer.WriteLine(); + _writer.WriteLine("return null;"); + EmitEndBlock(); + _emitBlankLineBeforeNextStatement = true; + } + + private void EmitBindCoreMainMethod() + { + if (_bindingHelperInfo.TypesForGen_BindCoreMain is not { Count: not 0 } targetTypes) + { + return; + } + + EmitBlankLineIfRequired(); + EmitStartBlock($"public static void {nameof(MethodsToGen_CoreBindingHelper.BindCoreMain)}({Identifier.IConfiguration} {Identifier.configuration}, object {Identifier.instance}, Type {Identifier.type}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions})"); + EmitCheckForNullArgument_WithBlankLine(Identifier.instance, voidReturn: true); + EmitIConfigurationHasValueOrChildrenCheck(voidReturn: true); + _writer.WriteLine($"{Identifier.BinderOptions}? {Identifier.binderOptions} = {Identifier.GetBinderOptions}({Identifier.configureOptions});"); + _writer.WriteLine(); + + bool isFirstType = true; + foreach (ComplexTypeSpec type in targetTypes) + { + ComplexTypeSpec effectiveType = (ComplexTypeSpec)_typeIndex.GetEffectiveTypeSpec(type); + Debug.Assert(_typeIndex.HasBindableMembers(effectiveType)); + string conditionKindExpr = GetConditionKindExpr(ref isFirstType); + + EmitStartBlock($"{conditionKindExpr} ({Identifier.type} == typeof({type.DisplayString}))"); + _writer.WriteLine($"var {Identifier.temp} = ({effectiveType.DisplayString}){Identifier.instance};"); + EmitBindingLogic(type, Identifier.temp, Identifier.configuration, InitializationKind.None, ValueDefaulting.None); + _writer.WriteLine($"return;"); + EmitEndBlock(); + } + + _writer.WriteLine(); + Emit_NotSupportedException_TypeNotDetectedAsInput(); + EmitEndBlock(); + } + + private void EmitBindCoreMethods() + { + if (_bindingHelperInfo.TypesForGen_BindCore is not ImmutableEquatableArray types) + { + return; + } + + foreach (ComplexTypeSpec type in types) + { + Debug.Assert(_typeIndex.HasBindableMembers(type)); + EmitBlankLineIfRequired(); + EmitBindCoreMethod(type); + } + } + + private void EmitBindCoreMethod(ComplexTypeSpec type) + { + string objParameterExpression = $"ref {type.DisplayString} {Identifier.instance}"; + EmitStartBlock(@$"public static void {nameof(MethodsToGen_CoreBindingHelper.BindCore)}({Identifier.IConfiguration} {Identifier.configuration}, {objParameterExpression}, bool defaultValueIfNotFound, {Identifier.BinderOptions}? {Identifier.binderOptions})"); + + ComplexTypeSpec effectiveType = (ComplexTypeSpec)_typeIndex.GetEffectiveTypeSpec(type); + + switch (effectiveType) + { + case ArraySpec arrayType: + { + EmitBindCoreImplForArray(arrayType); + } + break; + case EnumerableSpec enumerableType: + { + EmitBindCoreImplForEnumerableWithAdd(enumerableType); + } + break; + case DictionarySpec dictionaryType: + { + EmitBindCoreImplForDictionary(dictionaryType); + } + break; + case ObjectSpec objectType: + { + EmitBindCoreImplForObject(objectType); + } + break; + default: + { + Debug.Fail($"Unsupported spec for bind core gen: {effectiveType.GetType()}"); + } + break; + } + + EmitEndBlock(); + } + + private void EmitInitializeMethods() + { + if (_bindingHelperInfo.TypesForGen_Initialize is not ImmutableEquatableArray types) + { + return; + } + + foreach (ObjectSpec type in types) + { + EmitBlankLineIfRequired(); + EmitInitializeMethod(type); + } + } + + private void EmitInitializeMethod(ObjectSpec type) + { + Debug.Assert(type.InstantiationStrategy is ObjectInstantiationStrategy.ParameterizedConstructor); + Debug.Assert(_typeIndex.CanInstantiate(type)); + Debug.Assert( + type is { Properties: not null, ConstructorParameters: not null }, + $"Expecting type for init method, {type.DisplayString}, to have both properties and ctor params."); + + IEnumerable initOnlyProps = type.Properties.Where(prop => prop is { SetOnInit: true }); + List ctorArgList = new(); + string displayString = type.DisplayString; + + EmitStartBlock($"public static {type.DisplayString} {GetInitalizeMethodDisplayString(type)}({Identifier.IConfiguration} {Identifier.configuration}, {Identifier.BinderOptions}? {Identifier.binderOptions})"); + _emitBlankLineBeforeNextStatement = false; + + foreach (ParameterSpec parameter in type.ConstructorParameters) + { + string name = parameter.Name; + string argExpr = parameter.RefKind switch + { + RefKind.None => name, + RefKind.Ref => $"ref {name}", + RefKind.Out => "out _", + RefKind.In => $"in {name}", + _ => throw new InvalidOperationException() + }; + + ctorArgList.Add(argExpr); + EmitBindImplForMember(parameter); + } + + foreach (PropertySpec property in initOnlyProps) + { + if (_typeIndex.ShouldBindTo(property) && property.MatchingCtorParam is null) + { + EmitBindImplForMember(property); + } + } + + string returnExpression = $"return new {displayString}({string.Join(", ", ctorArgList)})"; + if (!initOnlyProps.Any()) + { + _writer.WriteLine($"{returnExpression};"); + } + else + { + EmitStartBlock(returnExpression); + foreach (PropertySpec property in initOnlyProps) + { + string propertyName = property.Name; + _writer.WriteLine($@"{propertyName} = {propertyName},"); + } + EmitEndBlock(endBraceTrailingSource: ";"); + } + + // End method. + EmitEndBlock(); + _emitBlankLineBeforeNextStatement = true; + + void EmitBindImplForMember(MemberSpec member) + { + TypeSpec memberType = _typeIndex.GetTypeSpec(member.TypeRef); + string parsedMemberDeclarationLhs = $"{memberType.DisplayString} {member.Name}"; + string configKeyName = member.ConfigurationKeyName; + string parsedMemberAssignmentLhsExpr; + + switch (memberType) + { + case ParsableFromStringSpec { StringParsableTypeKind: StringParsableTypeKind.AssignFromSectionValue }: + { + if (member is ParameterSpec parameter && parameter.ErrorOnFailedBinding) + { + string condition = $@"if ({Identifier.configuration}[""{configKeyName}""] is not {parsedMemberDeclarationLhs})"; + EmitThrowBlock(condition); + _writer.WriteLine(); + return; + } + + parsedMemberAssignmentLhsExpr = parsedMemberDeclarationLhs; + } + break; + case ConfigurationSectionSpec: + { + _writer.WriteLine($"{parsedMemberDeclarationLhs} = {GetSectionFromConfigurationExpression(configKeyName)};"); + return; + } + default: + { + string bangExpr = memberType.IsValueType ? string.Empty : "!"; + string parsedMemberIdentifierDeclaration = $"{parsedMemberDeclarationLhs} = {member.DefaultValueExpr}{bangExpr};"; + + _writer.WriteLine(parsedMemberIdentifierDeclaration); + _emitBlankLineBeforeNextStatement = false; + + parsedMemberAssignmentLhsExpr = member.Name; + } + break; + } + + bool canBindToMember = this.EmitBindImplForMember( + member, + parsedMemberAssignmentLhsExpr, + sectionPathExpr: GetSectionPathFromConfigurationExpression(configKeyName), + canSet: true, + InitializationKind.None); + + if (canBindToMember) + { + if (member is ParameterSpec parameter && parameter.ErrorOnFailedBinding) + { + // Add exception logic for parameter ctors; must be present in configuration object. + EmitThrowBlock(condition: "else"); + } + + _writer.WriteLine(); + } + + void EmitThrowBlock(string condition) => + _writer.WriteLine($$""" + {{condition}} + { + throw new {{Identifier.InvalidOperationException}}("{{string.Format(ExceptionMessages.ParameterHasNoMatchingConfig, type.Name, member.Name)}}"); + } + """); + } + } + + private void EmitHelperMethods() + { + // Emitted if we are to bind objects with complex members, or if we're emitting BindCoreMain or GetCore methods. + bool emitAsConfigWithChildren = ShouldEmitMethods(MethodsToGen_CoreBindingHelper.AsConfigWithChildren); + + if (ShouldEmitMethods(MethodsToGen_CoreBindingHelper.BindCore)) + { + EmitBlankLineIfRequired(); + EmitValidateConfigurationKeysMethod(); + } + + if (ShouldEmitMethods(MethodsToGen_CoreBindingHelper.BindCoreMain | MethodsToGen_CoreBindingHelper.GetCore)) + { + // HasValueOrChildren references this method. + Debug.Assert(emitAsConfigWithChildren); + EmitBlankLineIfRequired(); + EmitHasValueOrChildrenMethod(); + } + + if (emitAsConfigWithChildren) + { + EmitBlankLineIfRequired(); + EmitAsConfigWithChildrenMethod(); + } + + if (ShouldEmitMethods(MethodsToGen_CoreBindingHelper.BindCoreMain | MethodsToGen_CoreBindingHelper.GetCore) || + ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_instance_BinderOptions)) + { + EmitBlankLineIfRequired(); + EmitGetBinderOptionsHelper(); + } + + if (_bindingHelperInfo.TypesForGen_ParsePrimitive is { Count: not 0 } stringParsableTypes) + { + bool enumTypeExists = false; + + foreach (ParsableFromStringSpec type in stringParsableTypes) + { + EmitBlankLineIfRequired(); + + if (type.StringParsableTypeKind == StringParsableTypeKind.Enum) + { + if (!enumTypeExists) + { + EmitEnumParseMethod(); + enumTypeExists = true; + } + } + else + { + EmitPrimitiveParseMethod(type); + } + } + } + } + + private void EmitValidateConfigurationKeysMethod() + { + const string keysIdentifier = "keys"; + string exceptionMessage = string.Format(ExceptionMessages.MissingConfig, Identifier.ErrorOnUnknownConfiguration, Identifier.BinderOptions, $"{{{Identifier.type}}}", $@"{{string.Join("", "", {Identifier.temp})}}"); + + EmitBlankLineIfRequired(); + _writer.WriteLine($$""" + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. + public static void {{Identifier.ValidateConfigurationKeys}}(Type {{Identifier.type}}, {{TypeDisplayString.LazyHashSetOfString}} {{keysIdentifier}}, {{Identifier.IConfiguration}} {{Identifier.configuration}}, {{Identifier.BinderOptions}}? {{Identifier.binderOptions}}) + { + if ({{Identifier.binderOptions}}?.{{Identifier.ErrorOnUnknownConfiguration}} is true) + { + {{TypeDisplayString.ListOfString}}? {{Identifier.temp}} = null; + + foreach ({{Identifier.IConfigurationSection}} {{Identifier.section}} in {{Identifier.configuration}}.{{Identifier.GetChildren}}()) + { + if (!{{keysIdentifier}}.Value.Contains({{Expression.sectionKey}})) + { + ({{Identifier.temp}} ??= new {{TypeDisplayString.ListOfString}}()).Add($"'{{{Expression.sectionKey}}}'"); + } + } + + if ({{Identifier.temp}} is not null) + { + throw new InvalidOperationException($"{{exceptionMessage}}"); + } + } + } + """); + } + + private void EmitHasValueOrChildrenMethod() + { + _writer.WriteLine($$""" + public static bool {{Identifier.HasValueOrChildren}}({{Identifier.IConfiguration}} {{Identifier.configuration}}) + { + if (({{Identifier.configuration}} as {{Identifier.IConfigurationSection}})?.{{Identifier.Value}} is not null) + { + return true; + } + return {{MethodsToGen_CoreBindingHelper.AsConfigWithChildren}}({{Identifier.configuration}}) is not null; + } + """); + } + + private void EmitAsConfigWithChildrenMethod() + { + _writer.WriteLine($$""" + public static {{Identifier.IConfiguration}}? {{MethodsToGen_CoreBindingHelper.AsConfigWithChildren}}({{Identifier.IConfiguration}} {{Identifier.configuration}}) + { + foreach ({{Identifier.IConfigurationSection}} _ in {{Identifier.configuration}}.{{Identifier.GetChildren}}()) + { + return {{Identifier.configuration}}; + } + return null; + } + """); + } + + private void EmitGetBinderOptionsHelper() + { + _writer.WriteLine($$""" + public static {{Identifier.BinderOptions}}? {{Identifier.GetBinderOptions}}({{TypeDisplayString.NullableActionOfBinderOptions}} {{Identifier.configureOptions}}) + { + if ({{Identifier.configureOptions}} is null) + { + return null; + } + + {{Identifier.BinderOptions}} {{Identifier.binderOptions}} = new(); + {{Identifier.configureOptions}}({{Identifier.binderOptions}}); + + if ({{Identifier.binderOptions}}.BindNonPublicProperties) + { + throw new NotSupportedException($"{{string.Format(ExceptionMessages.CannotSpecifyBindNonPublicProperties)}}"); + } + + return {{Identifier.binderOptions}}; + } + """); + } + + private void EmitEnumParseMethod() + { + string exceptionArg1 = string.Format(ExceptionMessages.FailedBinding, $"{{{Identifier.getPath}()}}", $"{{typeof(T)}}"); + + _writer.WriteLine($$""" + public static T ParseEnum(string value, Func getPath) where T : struct + { + try + { + #if NETFRAMEWORK || NETSTANDARD2_0 + return (T)Enum.Parse(typeof(T), value, ignoreCase: true); + #else + return Enum.Parse(value, ignoreCase: true); + #endif + } + catch ({{Identifier.Exception}} {{Identifier.exception}}) + { + throw new {{Identifier.InvalidOperationException}}($"{{exceptionArg1}}", {{Identifier.exception}}); + } + } + """); + } + + private void EmitPrimitiveParseMethod(ParsableFromStringSpec type) + { + StringParsableTypeKind typeKind = type.StringParsableTypeKind; + string typeDisplayString = type.DisplayString; + + string invariantCultureExpression = $"{Identifier.CultureInfo}.InvariantCulture"; + string parsedValueExpr; + + switch (typeKind) + { + case StringParsableTypeKind.Enum: + return; + case StringParsableTypeKind.ByteArray: + { + parsedValueExpr = $"Convert.FromBase64String({Identifier.value})"; + } + break; + case StringParsableTypeKind.Integer: + { + parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value}, {Identifier.NumberStyles}.Integer, {invariantCultureExpression})"; + } + break; + case StringParsableTypeKind.Float: + { + parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value}, {Identifier.NumberStyles}.Float, {invariantCultureExpression})"; + } + break; + case StringParsableTypeKind.Parse: + { + parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value})"; + } + break; + case StringParsableTypeKind.ParseInvariant: + { + parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value}, {invariantCultureExpression})"; ; + } + break; + case StringParsableTypeKind.CultureInfo: + { + parsedValueExpr = $"{Identifier.CultureInfo}.GetCultureInfo({Identifier.value})"; + } + break; + case StringParsableTypeKind.Uri: + { + parsedValueExpr = $"new Uri({Identifier.value}, UriKind.RelativeOrAbsolute)"; + } + break; + default: + { + Debug.Fail($"Invalid string parsable kind: {typeKind}"); + return; + } + } + + string exceptionArg1 = string.Format(ExceptionMessages.FailedBinding, $"{{{Identifier.getPath}()}}", $"{{typeof({typeDisplayString})}}"); + + EmitStartBlock($"public static {typeDisplayString} {TypeIndex.GetParseMethodName(type)}(string {Identifier.value}, Func {Identifier.getPath})"); + EmitEndBlock($$""" + try + { + return {{parsedValueExpr}}; + } + catch ({{Identifier.Exception}} {{Identifier.exception}}) + { + throw new {{Identifier.InvalidOperationException}}($"{{exceptionArg1}}", {{Identifier.exception}}); + } + """); + } + + private void EmitBindCoreImplForArray(ArraySpec type) + { + TypeRef elementTypeRef = type.ElementTypeRef; + string elementTypeDisplayString = _typeIndex.GetTypeSpec(elementTypeRef).DisplayString; + string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); + + // Create temp list. + _writer.WriteLine($"var {tempIdentifier} = new List<{elementTypeDisplayString}>();"); + _writer.WriteLine(); + + // Bind elements to temp list. + EmitBindingLogicForEnumerableWithAdd(elementTypeRef, tempIdentifier); + _writer.WriteLine(); + + // Resize array and add binded elements. + _writer.WriteLine($$""" + {{Identifier.Int32}} {{Identifier.originalCount}} = {{Identifier.instance}}.{{Identifier.Length}}; + {{Identifier.Array}}.{{Identifier.Resize}}(ref {{Identifier.instance}}, {{Identifier.originalCount}} + {{tempIdentifier}}.{{Identifier.Count}}); + {{tempIdentifier}}.{{Identifier.CopyTo}}({{Identifier.instance}}, {{Identifier.originalCount}}); + """); + } + + private void EmitBindCoreImplForEnumerableWithAdd(EnumerableSpec type) + { + EmitCollectionCastIfRequired(type, out string instanceIdentifier); + EmitBindingLogicForEnumerableWithAdd(type.ElementTypeRef, instanceIdentifier); + } + + private void EmitBindingLogicForEnumerableWithAdd(TypeRef elementTypeRef, string enumerableIdentifier) + { + Emit_Foreach_Section_In_ConfigChildren_StartBlock(); + + string addExpr = $"{enumerableIdentifier}.{Identifier.Add}"; + + switch (_typeIndex.GetEffectiveTypeSpec(elementTypeRef)) + { + case ParsableFromStringSpec stringParsableType: + { + EmitBindingLogic( + stringParsableType, + Expression.sectionValue, + Expression.sectionPath, + (parsedValueExpr) => _writer.WriteLine($"{addExpr}({parsedValueExpr});"), + checkForNullSectionValue: true, + useDefaultValueIfSectionValueIsNull: false, + useIncrementalStringValueIdentifier: false); + } + break; + case ConfigurationSectionSpec: + { + _writer.WriteLine($"{addExpr}({Identifier.section});"); + } + break; + case ComplexTypeSpec complexType when _typeIndex.CanInstantiate(complexType): + { + EmitBindingLogic(complexType, Identifier.value, Identifier.section, InitializationKind.Declaration, ValueDefaulting.None); + _writer.WriteLine($"{addExpr}({Identifier.value});"); + } + break; + } + + EmitEndBlock(); + } + + private void EmitBindCoreImplForDictionary(DictionarySpec type) + { + EmitCollectionCastIfRequired(type, out string instanceIdentifier); + + Emit_Foreach_Section_In_ConfigChildren_StartBlock(); + + ParsableFromStringSpec keyType = (ParsableFromStringSpec)_typeIndex.GetEffectiveTypeSpec(type.KeyTypeRef); + TypeSpec elementType = _typeIndex.GetTypeSpec(type.ElementTypeRef); + + // Parse key + EmitBindingLogic( + keyType, + Expression.sectionKey, + Expression.sectionPath, + Emit_BindAndAddLogic_ForElement, + checkForNullSectionValue: false, + useDefaultValueIfSectionValueIsNull: false, + useIncrementalStringValueIdentifier: false); + + void Emit_BindAndAddLogic_ForElement(string parsedKeyExpr) + { + switch (elementType) + { + case ParsableFromStringSpec stringParsableElementType: + { + EmitBindingLogic( + stringParsableElementType, + Expression.sectionValue, + Expression.sectionPath, + writeOnSuccess: parsedValueExpr => _writer.WriteLine($"{instanceIdentifier}[{parsedKeyExpr}] = {parsedValueExpr};"), + checkForNullSectionValue: true, + useDefaultValueIfSectionValueIsNull: false, + useIncrementalStringValueIdentifier: false); + } + break; + case ConfigurationSectionSpec: + { + _writer.WriteLine($"{instanceIdentifier}[{parsedKeyExpr}] = {Identifier.section};"); + } + break; + case ComplexTypeSpec complexElementType: + { + if (keyType.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue) + { + // Save value to local to avoid parsing twice - during look-up and during add. + _writer.WriteLine($"{keyType.DisplayString} {Identifier.key} = {parsedKeyExpr};"); + parsedKeyExpr = Identifier.key; + } + + bool isValueType = complexElementType.IsValueType; + string expressionForElementIsNotNull = $"{Identifier.element} is not null"; + string elementTypeDisplayString = complexElementType.DisplayString + (complexElementType.IsValueType ? string.Empty : "?"); + + string expressionForElementExists = $"{instanceIdentifier}.{Identifier.TryGetValue}({parsedKeyExpr}, out {elementTypeDisplayString} {Identifier.element})"; + string conditionToUseExistingElement = expressionForElementExists; + + // If key already exists, bind to existing element instance if not null (for ref types). + if (!isValueType) + { + conditionToUseExistingElement += $" && {expressionForElementIsNotNull}"; + } + + if (_typeIndex.CanInstantiate(complexElementType)) + { + EmitStartBlock($"if (!({conditionToUseExistingElement}))"); + EmitObjectInit(complexElementType, Identifier.element, InitializationKind.SimpleAssignment, Identifier.section); + EmitEndBlock(); + + EmitBindingLogic(); + } + else + { + EmitStartBlock($"if ({conditionToUseExistingElement})"); + EmitBindingLogic(); + EmitEndBlock(); + } + + void EmitBindingLogic() + { + this.EmitBindingLogic( + complexElementType, + Identifier.element, + Identifier.section, + InitializationKind.None, + ValueDefaulting.None); + + _writer.WriteLine($"{instanceIdentifier}[{parsedKeyExpr}] = {Identifier.element};"); + } + } + break; + } + } + + EmitEndBlock(); + } + + private void EmitBindCoreImplForObject(ObjectSpec type) + { + Debug.Assert(_typeIndex.HasBindableMembers(type)); + + string keyCacheFieldName = TypeIndex.GetConfigKeyCacheFieldName(type); + string validateMethodCallExpr = $"{Identifier.ValidateConfigurationKeys}(typeof({type.DisplayString}), {keyCacheFieldName}, {Identifier.configuration}, {Identifier.binderOptions});"; + _writer.WriteLine(validateMethodCallExpr); + + foreach (PropertySpec property in type.Properties!) + { + if (_typeIndex.ShouldBindTo(property)) + { + string containingTypeRef = property.IsStatic ? type.DisplayString : Identifier.instance; + EmitBindImplForMember( + property, + memberAccessExpr: $"{containingTypeRef}.{property.Name}", + GetSectionPathFromConfigurationExpression(property.ConfigurationKeyName), + canSet: property.CanSet, + InitializationKind.Declaration); + } + } + } + + private bool EmitBindImplForMember( + MemberSpec member, + string memberAccessExpr, + string sectionPathExpr, + bool canSet, + InitializationKind initializationKind) + { + string sectionParseExpr = GetSectionFromConfigurationExpression(member.ConfigurationKeyName); + + switch (_typeIndex.GetEffectiveTypeSpec(member.TypeRef)) + { + case ParsableFromStringSpec stringParsableType: + { + if (canSet) + { + bool useDefaultValueIfSectionValueIsNull = + initializationKind == InitializationKind.Declaration && + member is PropertySpec && + member.TypeRef.IsValueType && + _typeIndex.GetTypeSpec(member.TypeRef) is not NullableSpec; + + EmitBlankLineIfRequired(); + EmitBindingLogic( + stringParsableType, + $@"{Identifier.configuration}[""{member.ConfigurationKeyName}""]", + sectionPathExpr, + writeOnSuccess: parsedValueExpr => _writer.WriteLine($"{memberAccessExpr} = {parsedValueExpr};"), + checkForNullSectionValue: true, + useDefaultValueIfSectionValueIsNull, + useIncrementalStringValueIdentifier: true); + } + + return true; + } + case ConfigurationSectionSpec: + { + if (canSet) + { + EmitBlankLineIfRequired(); + _writer.WriteLine($"{memberAccessExpr} = {sectionParseExpr};"); + } + + return true; + } + case ComplexTypeSpec complexType: + { + string sectionValidationCall = $"{MethodsToGen_CoreBindingHelper.AsConfigWithChildren}({sectionParseExpr})"; + string sectionIdentifier = GetIncrementalIdentifier(Identifier.section); + + EmitBlankLineIfRequired(); + EmitStartBlock($"if ({sectionValidationCall} is {Identifier.IConfigurationSection} {sectionIdentifier})"); + EmitBindingLogicForComplexMember(member, memberAccessExpr, sectionIdentifier, canSet); + EmitEndBlock(); + + return _typeIndex.CanInstantiate(complexType); + } + default: + return false; + } + } + + private void EmitBindingLogicForComplexMember( + MemberSpec member, + string memberAccessExpr, + string configArgExpr, + bool canSet) + { + + TypeSpec memberType = _typeIndex.GetTypeSpec(member.TypeRef); + ComplexTypeSpec effectiveMemberType = (ComplexTypeSpec)_typeIndex.GetEffectiveTypeSpec(memberType); + + string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); + InitializationKind initKind; + string targetObjAccessExpr; + + if (effectiveMemberType.IsValueType) + { + if (!canSet) + { + return; + } + + Debug.Assert(canSet); + string effectiveMemberTypeDisplayString = effectiveMemberType.DisplayString; + initKind = InitializationKind.None; + + if (memberType is NullableSpec) + { + string nullableTempIdentifier = GetIncrementalIdentifier(Identifier.temp); + + _writer.WriteLine($"{memberType.DisplayString} {nullableTempIdentifier} = {memberAccessExpr};"); + + _writer.WriteLine( + $"{effectiveMemberTypeDisplayString} {tempIdentifier} = {nullableTempIdentifier}.{Identifier.HasValue} ? {nullableTempIdentifier}.{Identifier.Value} : new {effectiveMemberTypeDisplayString}();"); + } + else + { + _writer.WriteLine($"{effectiveMemberTypeDisplayString} {tempIdentifier} = {memberAccessExpr};"); + } + + targetObjAccessExpr = tempIdentifier; + } + else if (member.CanGet) + { + targetObjAccessExpr = memberAccessExpr; + initKind = InitializationKind.AssignmentWithNullCheck; + } + else + { + targetObjAccessExpr = memberAccessExpr; + initKind = InitializationKind.SimpleAssignment; + } + + Action? writeOnSuccess = !canSet + ? null + : bindedValueIdentifier => + { + if (memberAccessExpr != bindedValueIdentifier) + { + _writer.WriteLine($"{memberAccessExpr} = {bindedValueIdentifier};"); + } + }; + + EmitBindingLogic( + effectiveMemberType, + targetObjAccessExpr, + configArgExpr, + initKind, + ValueDefaulting.None, + writeOnSuccess + ); + } + + private void EmitBindingLogic( + ComplexTypeSpec type, + string memberAccessExpr, + string configArgExpr, + InitializationKind initKind, + ValueDefaulting valueDefaulting, + Action? writeOnSuccess = null) + { + if (!_typeIndex.HasBindableMembers(type)) + { + if (initKind is not InitializationKind.None) + { + if (_typeIndex.CanInstantiate(type)) + { + EmitObjectInit(type, memberAccessExpr, initKind, configArgExpr); + } + else if (type is ObjectSpec { InitExceptionMessage: string exMsg }) + { + _writer.WriteLine($@"throw new {Identifier.InvalidOperationException}(""{exMsg}"");"); + } + } + + return; + } + + string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); + if (initKind is InitializationKind.AssignmentWithNullCheck) + { + Debug.Assert(!type.IsValueType); + _writer.WriteLine($"{type.DisplayString}? {tempIdentifier} = {memberAccessExpr};"); + EmitBindingLogic(tempIdentifier, InitializationKind.AssignmentWithNullCheck); + } + else if (initKind is InitializationKind.None && type.IsValueType) + { + EmitBindingLogic(tempIdentifier, InitializationKind.Declaration); + _writer.WriteLine($"{memberAccessExpr} = {tempIdentifier};"); + } + else + { + EmitBindingLogic(memberAccessExpr, initKind); + } + + void EmitBindingLogic(string instanceToBindExpr, InitializationKind initKind) + { + string bindCoreCall = $@"{nameof(MethodsToGen_CoreBindingHelper.BindCore)}({configArgExpr}, ref {instanceToBindExpr}, defaultValueIfNotFound: {FormatDefaultValueIfNotFound()}, {Identifier.binderOptions});"; + + if (_typeIndex.CanInstantiate(type)) + { + if (initKind is not InitializationKind.None) + { + EmitObjectInit(type, instanceToBindExpr, initKind, configArgExpr); + } + + EmitBindCoreCall(); + } + else + { + Debug.Assert(!type.IsValueType); + EmitStartBlock($"if ({instanceToBindExpr} is not null)"); + EmitBindCoreCall(); + EmitEndBlock(); + if (type is ObjectSpec { InitExceptionMessage: string exMsg }) + { + EmitStartBlock("else"); + _writer.WriteLine($@"throw new {Identifier.InvalidOperationException}(""{exMsg}"");"); + EmitEndBlock(); + } + } + + void EmitBindCoreCall() + { + _writer.WriteLine(bindCoreCall); + writeOnSuccess?.Invoke(instanceToBindExpr); + } + + string FormatDefaultValueIfNotFound() => valueDefaulting == ValueDefaulting.CallSetter ? "true" : "false"; + } + } + + private void EmitBindingLogic( + ParsableFromStringSpec type, + string sectionValueExpr, + string sectionPathExpr, + Action? writeOnSuccess, + bool checkForNullSectionValue, + bool useDefaultValueIfSectionValueIsNull, + bool useIncrementalStringValueIdentifier) + { + StringParsableTypeKind typeKind = type.StringParsableTypeKind; + Debug.Assert(typeKind is not StringParsableTypeKind.None); + + string nonNull_StringValue_Identifier = useIncrementalStringValueIdentifier ? GetIncrementalIdentifier(Identifier.value) : Identifier.value; + string stringValueToParse_Expr = checkForNullSectionValue ? nonNull_StringValue_Identifier : sectionValueExpr; + string parsedValueExpr = typeKind switch + { + StringParsableTypeKind.AssignFromSectionValue => stringValueToParse_Expr, + StringParsableTypeKind.Enum => $"ParseEnum<{type.DisplayString}>({stringValueToParse_Expr}, () => {sectionPathExpr})", + _ => $"{TypeIndex.GetParseMethodName(type)}({stringValueToParse_Expr}, () => {sectionPathExpr})", + }; + + if (!checkForNullSectionValue) + { + InvokeWriteOnSuccess(); + } + else + { + EmitStartBlock($"if ({sectionValueExpr} is string {nonNull_StringValue_Identifier})"); + InvokeWriteOnSuccess(); + EmitEndBlock(); + } + + if (useDefaultValueIfSectionValueIsNull) + { + parsedValueExpr = $"default"; + EmitStartBlock($"else if (defaultValueIfNotFound)"); + InvokeWriteOnSuccess(); + EmitEndBlock(); + } + + void InvokeWriteOnSuccess() => writeOnSuccess?.Invoke(parsedValueExpr); + } + + private bool EmitObjectInit(ComplexTypeSpec type, string memberAccessExpr, InitializationKind initKind, string configArgExpr) + { + CollectionSpec? collectionType = type as CollectionSpec; + ObjectSpec? objectType = type as ObjectSpec; + + string? castExpr = null; + string initExpr; + + string effectiveDisplayString = type.DisplayString; + if (collectionType is not null) + { + if (collectionType is ArraySpec) + { + initExpr = $"new {s_arrayBracketsRegex.Replace(effectiveDisplayString, "[0]", 1)}"; + } + else + { + CollectionWithCtorInitSpec collectionWithCtorInitType = (CollectionWithCtorInitSpec)collectionType; + + if (collectionWithCtorInitType.InstantiationConcreteType is not CollectionInstantiationConcreteType.Self) + { + castExpr = $"({collectionWithCtorInitType.DisplayString})"; + } + + effectiveDisplayString = _typeIndex.GetInstantiationTypeDisplayString(collectionWithCtorInitType); + initExpr = $"{castExpr}new {effectiveDisplayString}()"; + } + } + else + { + Debug.Assert(objectType is not null); + ObjectInstantiationStrategy strategy = objectType.InstantiationStrategy; + + if (strategy is ObjectInstantiationStrategy.ParameterlessConstructor) + { + initExpr = $"new {effectiveDisplayString}()"; + } + else + { + Debug.Assert(strategy is ObjectInstantiationStrategy.ParameterizedConstructor); + string initMethodIdentifier = GetInitalizeMethodDisplayString(((ObjectSpec)type)); + initExpr = $"{initMethodIdentifier}({configArgExpr}, {Identifier.binderOptions})"; + } + } + + switch (initKind) + { + case InitializationKind.Declaration: + { + Debug.Assert(!memberAccessExpr.Contains('.')); + _writer.WriteLine($"var {memberAccessExpr} = {initExpr};"); + } + break; + case InitializationKind.AssignmentWithNullCheck: + { + + if (collectionType is CollectionWithCtorInitSpec + { + InstantiationStrategy: CollectionInstantiationStrategy.CopyConstructor or CollectionInstantiationStrategy.LinqToDictionary + } collectionWithCtorInitType) + { + string assignmentValueIfMemberNull = collectionWithCtorInitType.InstantiationStrategy is CollectionInstantiationStrategy.CopyConstructor + ? $"new {effectiveDisplayString}({memberAccessExpr})" + : $"{memberAccessExpr}.ToDictionary(pair => pair.Key, pair => pair.Value)"; + + Debug.Assert(castExpr is not null || collectionWithCtorInitType.InstantiationConcreteType is CollectionInstantiationConcreteType.Self); + assignmentValueIfMemberNull = $"{castExpr}{assignmentValueIfMemberNull}"; + + _writer.WriteLine($"{memberAccessExpr} = {memberAccessExpr} is null ? {initExpr} : {assignmentValueIfMemberNull};"); + } + else + { + _writer.WriteLine($"{memberAccessExpr} ??= {initExpr};"); + } + } + break; + case InitializationKind.SimpleAssignment: + { + _writer.WriteLine($"{memberAccessExpr} = {initExpr};"); + } + break; + default: + { + Debug.Fail($"Invaild initialization kind: {initKind}"); + } + break; + } + + return true; + } + + private void EmitIConfigurationHasValueOrChildrenCheck(bool voidReturn) + { + string returnPostfix = voidReturn ? string.Empty : " null"; + _writer.WriteLine($$""" + if (!{{Identifier.HasValueOrChildren}}({{Identifier.configuration}})) + { + return{{returnPostfix}}; + } + """); + _writer.WriteLine(); + } + + private void EmitCollectionCastIfRequired(CollectionWithCtorInitSpec type, out string instanceIdentifier) + { + if (type.PopulationCastType is CollectionPopulationCastType.NotApplicable) + { + instanceIdentifier = Identifier.instance; + return; + } + + string castTypeDisplayString = _typeIndex.GetPopulationCastTypeDisplayString(type); + instanceIdentifier = Identifier.temp; + + _writer.WriteLine($$""" + if ({{Identifier.instance}} is not {{castTypeDisplayString}} {{instanceIdentifier}}) + { + return; + } + """); + _writer.WriteLine(); + + } + + private void Emit_Foreach_Section_In_ConfigChildren_StartBlock() => + EmitStartBlock($"foreach ({Identifier.IConfigurationSection} {Identifier.section} in {Identifier.configuration}.{Identifier.GetChildren}())"); + + private void Emit_NotSupportedException_TypeNotDetectedAsInput() => + _writer.WriteLine(@$"throw new NotSupportedException($""{string.Format(ExceptionMessages.TypeNotDetectedAsInput, "{type}")}"");"); + + private static string GetSectionPathFromConfigurationExpression(string configurationKeyName) + => $@"{GetSectionFromConfigurationExpression(configurationKeyName)}.{Identifier.Path}"; + + private static string GetSectionFromConfigurationExpression(string configurationKeyName, bool addQuotes = true) + { + string argExpr = addQuotes ? $@"""{configurationKeyName}""" : configurationKeyName; + return $@"{Identifier.configuration}.{Identifier.GetSection}({argExpr})"; + } + + private static string GetConditionKindExpr(ref bool isFirstType) + { + if (isFirstType) + { + isFirstType = false; + return "if"; + } + + return "else if"; + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/ExceptionMessages.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ExceptionMessages.cs similarity index 100% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/ExceptionMessages.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ExceptionMessages.cs diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/Helpers.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/Helpers.cs similarity index 60% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/Helpers.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/Helpers.cs index e0e6a36aabaa7c..34a97d3c64c76c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/Helpers.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/Helpers.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.Diagnostics; using System.Reflection; @@ -10,7 +11,9 @@ public sealed partial class ConfigurationBindingGenerator { private sealed partial class Emitter { - private static readonly AssemblyName s_assemblyName = typeof(Emitter).Assembly.GetName(); + internal static readonly AssemblyName s_assemblyName = typeof(ConfigurationBindingGenerator).Assembly.GetName(); + + private string? _emittedExtsTargetType; private enum InitializationKind { @@ -19,6 +22,24 @@ private enum InitializationKind AssignmentWithNullCheck = 2, Declaration = 3, } + + /// + /// The type of defaulting for a property if it does not have a config entry. + /// This should only be applied for "Get" cases, not "Bind" and is also conditioned + /// on the source generated for a particular property as to whether it uses this value. + /// Note this is different than "InitializationKind.Declaration" since it only applied to + /// complex types and not arrays\enumerables. + /// + private enum ValueDefaulting + { + None = 0, + + /// + /// Call the setter with the default value for the property's Type. + /// + CallSetter = 1, + } + private static class Expression { public const string configurationGetSection = "configuration.GetSection"; @@ -26,29 +47,13 @@ private static class Expression public const string sectionPath = "section.Path"; public const string sectionValue = "section.Value"; - public const string GetBinderOptions = $"{FullyQualifiedDisplayString.CoreBindingHelper}.{Identifier.GetBinderOptions}"; + public static string GeneratedCodeAnnotation = $@"[GeneratedCode(""{s_assemblyName.Name}"", ""{s_assemblyName.Version}"")]"; } - private static class FullyQualifiedDisplayString - { - public const string ActionOfBinderOptions = $"global::System.Action"; - public const string AddSingleton = $"{ServiceCollectionServiceExtensions}.AddSingleton"; - public const string ConfigurationChangeTokenSource = "global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource"; - public const string CoreBindingHelper = $"global::{ProjectName}.{Identifier.CoreBindingHelper}"; - public const string IConfiguration = "global::Microsoft.Extensions.Configuration.IConfiguration"; - public const string IConfigurationSection = IConfiguration + "Section"; - public const string IOptionsChangeTokenSource = "global::Microsoft.Extensions.Options.IOptionsChangeTokenSource"; - public const string InvalidOperationException = "global::System.InvalidOperationException"; - public const string IServiceCollection = "global::Microsoft.Extensions.DependencyInjection.IServiceCollection"; - public const string NotSupportedException = "global::System.NotSupportedException"; - public const string OptionsBuilderOfTOptions = $"global::Microsoft.Extensions.Options.OptionsBuilder<{Identifier.TOptions}>"; - public const string ServiceCollectionServiceExtensions = "global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions"; - public const string Type = $"global::System.Type"; - } - - private static class MinimalDisplayString + private static class TypeDisplayString { public const string NullableActionOfBinderOptions = "Action?"; + public const string OptionsBuilderOfTOptions = $"OptionsBuilder<{Identifier.TOptions}>"; public const string HashSetOfString = "HashSet"; public const string LazyHashSetOfString = "Lazy>"; public const string ListOfString = "List"; @@ -57,6 +62,8 @@ private static class MinimalDisplayString private static class Identifier { public const string binderOptions = nameof(binderOptions); + public const string config = nameof(config); + public const string configureBinder = nameof(configureBinder); public const string configureOptions = nameof(configureOptions); public const string configuration = nameof(configuration); public const string configSectionPath = nameof(configSectionPath); @@ -67,7 +74,7 @@ private static class Identifier public const string getPath = nameof(getPath); public const string key = nameof(key); public const string name = nameof(name); - public const string obj = nameof(obj); + public const string instance = nameof(instance); public const string optionsBuilder = nameof(optionsBuilder); public const string originalCount = nameof(originalCount); public const string section = nameof(section); @@ -75,6 +82,7 @@ private static class Identifier public const string services = nameof(services); public const string temp = nameof(temp); public const string type = nameof(type); + public const string typedObj = nameof(typedObj); public const string validateKeys = nameof(validateKeys); public const string value = nameof(value); @@ -82,21 +90,19 @@ private static class Identifier public const string AddSingleton = nameof(AddSingleton); public const string Any = nameof(Any); public const string Array = nameof(Array); - public const string AsConfigWithChildren = nameof(AsConfigWithChildren); public const string Bind = nameof(Bind); public const string BinderOptions = nameof(BinderOptions); + public const string BindingExtensions = nameof(BindingExtensions); + public const string ConfigurationChangeTokenSource = nameof(ConfigurationChangeTokenSource); public const string Configure = nameof(Configure); public const string CopyTo = nameof(CopyTo); public const string ContainsKey = nameof(ContainsKey); - public const string CoreBindingHelper = nameof(CoreBindingHelper); public const string Count = nameof(Count); public const string CultureInfo = nameof(CultureInfo); public const string CultureNotFoundException = nameof(CultureNotFoundException); public const string Enum = nameof(Enum); public const string ErrorOnUnknownConfiguration = nameof(ErrorOnUnknownConfiguration); - public const string GeneratedConfigurationBinder = nameof(GeneratedConfigurationBinder); - public const string GeneratedOptionsBuilderBinder = nameof(GeneratedOptionsBuilderBinder); - public const string GeneratedServiceCollectionBinder = nameof(GeneratedServiceCollectionBinder); + public const string Exception = nameof(Exception); public const string Get = nameof(Get); public const string GetBinderOptions = nameof(GetBinderOptions); public const string GetChildren = nameof(GetChildren); @@ -108,9 +114,13 @@ private static class Identifier public const string IConfiguration = nameof(IConfiguration); public const string IConfigurationSection = nameof(IConfigurationSection); public const string Int32 = "int"; + public const string InterceptsLocation = nameof(InterceptsLocation); public const string InvalidOperationException = nameof(InvalidOperationException); public const string InvariantCulture = nameof(InvariantCulture); + public const string IOptionsChangeTokenSource = nameof(IOptionsChangeTokenSource); + public const string IServiceCollection = nameof(IServiceCollection); public const string Length = nameof(Length); + public const string NumberStyles = nameof(NumberStyles); public const string Parse = nameof(Parse); public const string Path = nameof(Path); public const string Resize = nameof(Resize); @@ -118,16 +128,64 @@ private static class Identifier public const string TOptions = nameof(TOptions); public const string TryCreate = nameof(TryCreate); public const string TryGetValue = nameof(TryGetValue); - public const string TryParse = nameof(TryParse); + public const string Type = nameof(Type); public const string Uri = nameof(Uri); public const string ValidateConfigurationKeys = nameof(ValidateConfigurationKeys); public const string Value = nameof(Value); } - private bool ShouldEmitBinders() => - ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Any) || - ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Any) || - ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Any); + private bool ShouldEmitMethods(MethodsToGen methods) => (_interceptorInfo.MethodsToGen & methods) != 0; + + private void EmitInterceptsLocationAnnotations(MethodsToGen overload) + { + IEnumerable? infoList = _interceptorInfo.GetInfo(overload); + bool interceptsCalls = infoList is not null; + + // The only time a generated binding method won't have any locations to + // intercept is when either of these methods are used as helpers for + // other generated OptionsBuilder or ServiceCollection binding extensions. + Debug.Assert(interceptsCalls || + overload is MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions || + overload is MethodsToGen.OptionsBuilderExt_Bind_T_BinderOptions); + + if (interceptsCalls) + { + EmitInterceptsLocationAnnotations(infoList!); + } + } + + private void EmitInterceptsLocationAnnotations(IEnumerable infoList) + { + foreach (InvocationLocationInfo info in infoList) + { + _writer.WriteLine($@"[{Identifier.InterceptsLocation}(@""{info.FilePath}"", {info.LineNumber}, {info.CharacterNumber})]"); + } + } + + private void EmitBindingExtStartRegion(string targetType) + { + Debug.Assert(_emittedExtsTargetType is null); + + EmitBlankLineIfRequired(); + _emittedExtsTargetType = targetType; + EmitBindingExtRegionText(isStart: true); + _emitBlankLineBeforeNextStatement = false; + } + + private void EmitBindingExtEndRegion() + { + Debug.Assert(_emittedExtsTargetType is not null); + + EmitBindingExtRegionText(isStart: false); + _emittedExtsTargetType = null; + _emitBlankLineBeforeNextStatement = true; + } + + private void EmitBindingExtRegionText(bool isStart) + { + string endSource = isStart ? string.Empty : "end"; + _writer.WriteLine($"#{endSource}region {_emittedExtsTargetType} extensions."); + } /// /// Starts a block of source code. @@ -171,81 +229,26 @@ private void EmitBlankLineIfRequired() _emitBlankLineBeforeNextStatement = true; } - private void EmitCheckForNullArgument_WithBlankLine_IfRequired(bool isValueType) + private void EmitCheckForNullArgument_WithBlankLine(string paramName, bool voidReturn = false) { - if (!isValueType) - { - EmitCheckForNullArgument_WithBlankLine(Identifier.obj); - } - } - - private void EmitCheckForNullArgument_WithBlankLine(string paramName) - { - string exceptionTypeDisplayString = _useFullyQualifiedNames - ? "global::System.ArgumentNullException" - : "ArgumentNullException"; + string returnExpr = voidReturn + ? "return" + : $"throw new ArgumentNullException(nameof({paramName}))"; _writer.WriteLine($$""" if ({{paramName}} is null) { - throw new {{exceptionTypeDisplayString}}(nameof({{paramName}})); + {{returnExpr}}; } """); _writer.WriteLine(); } - private bool EmitInitException(TypeSpec type) - { - Debug.Assert(type.InitializationStrategy is not InitializationStrategy.None); - - if (!type.CanInitialize) - { - _writer.WriteLine(GetInitException(type.InitExceptionMessage) + ";"); - return true; - } - - return false; - } - - private void EmitRootBindingClassStartBlock(string className) - { - EmitBlankLineIfRequired(); - EmitStartBlock($$""" - /// Generated helper providing an AOT and linking compatible implementation for configuration binding. - {{GetGeneratedCodeAttributeSrc()}} - internal static class {{className}} - """); - - _emitBlankLineBeforeNextStatement = false; - } - - private string GetGeneratedCodeAttributeSrc() - { - string attributeRefExpr = _useFullyQualifiedNames ? $"global::System.CodeDom.Compiler.GeneratedCodeAttribute" : "GeneratedCode"; - return $@"[{attributeRefExpr}(""{s_assemblyName.Name}"", ""{s_assemblyName.Version}"")]"; - } - - private string GetInitException(string message) => $@"throw new {GetInvalidOperationDisplayName()}(""{message}"")"; - private string GetIncrementalIdentifier(string prefix) => $"{prefix}{_valueSuffixIndex++}"; - private string GetInitalizeMethodDisplayString(ObjectSpec type) => - GetHelperMethodDisplayString($"{nameof(MethodsToGen_CoreBindingHelper.Initialize)}{type.DisplayStringWithoutSpecialCharacters}"); - - private string GetTypeDisplayString(TypeSpec type) => _useFullyQualifiedNames ? type.FullyQualifiedDisplayString : type.MinimalDisplayString; - - private string GetHelperMethodDisplayString(string methodName) - { - if (_useFullyQualifiedNames) - { - methodName = FullyQualifiedDisplayString.CoreBindingHelper + "." + methodName; - } - - return methodName; - } - - private string GetInvalidOperationDisplayName() => _useFullyQualifiedNames ? FullyQualifiedDisplayString.InvalidOperationException : Identifier.InvalidOperationException; + private static string GetInitalizeMethodDisplayString(ObjectSpec type) => + $"{nameof(MethodsToGen_CoreBindingHelper.Initialize)}{type.IdentifierCompatibleSubstring}"; } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsBuilderConfigurationExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsBuilderConfigurationExtensions.cs new file mode 100644 index 00000000000000..fdc4286e34c559 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsBuilderConfigurationExtensions.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed partial class ConfigurationBindingGenerator + { + private sealed partial class Emitter + { + private void EmitBindingExtensions_OptionsBuilder() + { + if (!ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_Any)) + { + return; + } + + EmitBindingExtStartRegion(TypeDisplayString.OptionsBuilderOfTOptions); + EmitBindMethods_Extensions_OptionsBuilder(); + EmitBindConfigurationMethod(); + EmitBindingExtEndRegion(); + } + + private void EmitBindMethods_Extensions_OptionsBuilder() + { + if (!ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_Bind)) + { + return; + } + + const string documentation = @"/// Registers a configuration instance which will bind against."; + const string paramList = $"{Identifier.IConfiguration} {Identifier.config}"; + + if (ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_Bind_T)) + { + EmitMethodStartBlock(MethodsToGen.OptionsBuilderExt_Bind_T, "Bind", paramList, documentation); + _writer.WriteLine($"return Bind({Identifier.optionsBuilder}, {Identifier.config}, {Identifier.configureBinder}: null);"); + EmitEndBlock(); + } + + EmitMethodStartBlock( + MethodsToGen.OptionsBuilderExt_Bind_T_BinderOptions, + "Bind", + paramList + $", {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureBinder}", + documentation); + + EmitCheckForNullArgument_WithBlankLine(Identifier.optionsBuilder); + + _writer.WriteLine($$""" + {{Identifier.Configure}}<{{Identifier.TOptions}}>({{Identifier.optionsBuilder}}.{{Identifier.Services}}, {{Identifier.optionsBuilder}}.Name, {{Identifier.config}}, {{Identifier.configureBinder}}); + return {{Identifier.optionsBuilder}}; + """); + + EmitEndBlock(); + } + + private void EmitBindConfigurationMethod() + { + if (!ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_BindConfiguration_T_path_BinderOptions)) + { + return; + } + + const string documentation = $@"/// Registers the dependency injection container to bind against the obtained from the DI service provider."; + string paramList = $"string {Identifier.configSectionPath}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureBinder} = null"; + + EmitMethodStartBlock(MethodsToGen.OptionsBuilderExt_BindConfiguration, "BindConfiguration", paramList, documentation); + + EmitCheckForNullArgument_WithBlankLine(Identifier.optionsBuilder); + EmitCheckForNullArgument_WithBlankLine(Identifier.configSectionPath); + + EmitStartBlock($"{Identifier.optionsBuilder}.{Identifier.Configure}<{Identifier.IConfiguration}>(({Identifier.instance}, {Identifier.config}) =>"); + EmitCheckForNullArgument_WithBlankLine(Identifier.config); + _writer.WriteLine($$""" + {{Identifier.IConfiguration}} {{Identifier.section}} = string.Equals(string.Empty, {{Identifier.configSectionPath}}, StringComparison.OrdinalIgnoreCase) ? {{Identifier.config}} : {{Identifier.config}}.{{Identifier.GetSection}}({{Identifier.configSectionPath}}); + {{nameof(MethodsToGen_CoreBindingHelper.BindCoreMain)}}({{Identifier.section}}, {{Identifier.instance}}, typeof({{Identifier.TOptions}}), {{Identifier.configureBinder}}); + """); + + EmitEndBlock(endBraceTrailingSource: ");"); + + _writer.WriteLine(); + + _writer.WriteLine($$""" + {{Identifier.optionsBuilder}}.{{Identifier.Services}}.{{Identifier.AddSingleton}}<{{Identifier.IOptionsChangeTokenSource}}<{{Identifier.TOptions}}>, {{Identifier.ConfigurationChangeTokenSource}}<{{Identifier.TOptions}}>>(); + return {{Identifier.optionsBuilder}}; + """); + + EmitEndBlock(); + } + + private void EmitMethodStartBlock(MethodsToGen method, string methodName, string paramList, string documentation) + { + paramList = $"this {TypeDisplayString.OptionsBuilderOfTOptions} {Identifier.optionsBuilder}, {paramList}"; + EmitBlankLineIfRequired(); + _writer.WriteLine(documentation); + EmitInterceptsLocationAnnotations(method); + EmitStartBlock($"public static {TypeDisplayString.OptionsBuilderOfTOptions} {methodName}<{Identifier.TOptions}>({paramList}) where {Identifier.TOptions} : class"); + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsConfigurationServiceCollectionExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsConfigurationServiceCollectionExtensions.cs new file mode 100644 index 00000000000000..daa3b79db8abc4 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsConfigurationServiceCollectionExtensions.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed partial class ConfigurationBindingGenerator + { + private sealed partial class Emitter + { + private void EmitBindingExtensions_IServiceCollection() + { + if (!ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Any)) + { + return; + } + + EmitBindingExtStartRegion(Identifier.IServiceCollection); + EmitConfigureMethods(); + EmitBindingExtEndRegion(); + } + + private void EmitConfigureMethods() + { + const string defaultNameExpr = "string.Empty"; + string configParam = $"{Identifier.IConfiguration} {Identifier.config}"; + + if (ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Configure_T)) + { + EmitStartMethod(MethodsToGen.ServiceCollectionExt_Configure_T, configParam); + _writer.WriteLine($"return {Identifier.Configure}<{Identifier.TOptions}>({Identifier.services}, {defaultNameExpr}, {Identifier.config}, {Identifier.configureOptions}: null);"); + EmitEndBlock(); + } + + if (ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Configure_T_name)) + { + EmitStartMethod( + MethodsToGen.ServiceCollectionExt_Configure_T_name, + paramList: $"string? {Identifier.name}, " + configParam); + _writer.WriteLine($"return {Identifier.Configure}<{Identifier.TOptions}>({Identifier.services}, {Identifier.name}, {Identifier.config}, {Identifier.configureOptions}: null);"); + EmitEndBlock(); + } + + if (ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Configure_T_BinderOptions)) + { + EmitStartMethod( + MethodsToGen.ServiceCollectionExt_Configure_T_BinderOptions, + paramList: configParam + $", {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}"); + _writer.WriteLine($"return {Identifier.Configure}<{Identifier.TOptions}>({Identifier.services}, {defaultNameExpr}, {Identifier.config}, {Identifier.configureOptions});"); + EmitEndBlock(); + } + + // Core Configure method that the other overloads call. + // Like the others, it is public API that could be called directly by users. + // So, it is always generated whenever a Configure overload is called. + EmitStartMethod(MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions, paramList: $"string? {Identifier.name}, " + configParam + $", {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}"); + EmitCheckForNullArgument_WithBlankLine(Identifier.services); + EmitCheckForNullArgument_WithBlankLine(Identifier.config); + _writer.WriteLine($$""" + OptionsServiceCollectionExtensions.AddOptions({{Identifier.services}}); + {{Identifier.services}}.{{Identifier.AddSingleton}}<{{Identifier.IOptionsChangeTokenSource}}<{{Identifier.TOptions}}>>(new {{Identifier.ConfigurationChangeTokenSource}}<{{Identifier.TOptions}}>({{Identifier.name}}, {{Identifier.config}})); + return {{Identifier.services}}.{{Identifier.AddSingleton}}>(new ConfigureNamedOptions<{{Identifier.TOptions}}>({{Identifier.name}}, {{Identifier.instance}} => {{nameof(MethodsToGen_CoreBindingHelper.BindCoreMain)}}({{Identifier.config}}, {{Identifier.instance}}, typeof({{Identifier.TOptions}}), {{Identifier.configureOptions}}))); + """); + EmitEndBlock(); + } + + private void EmitStartMethod(MethodsToGen overload, string paramList) + { + paramList = $"this {Identifier.IServiceCollection} {Identifier.services}, {paramList}"; + + EmitBlankLineIfRequired(); + _writer.WriteLine("/// Registers a configuration instance which TOptions will bind against."); + EmitInterceptsLocationAnnotations(overload); + EmitStartBlock($"public static {Identifier.IServiceCollection} {Identifier.Configure}<{Identifier.TOptions}>({paramList}) where {Identifier.TOptions} : class"); + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/ConfigurationBinder.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/ConfigurationBinder.cs deleted file mode 100644 index c10e607df75d0a..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/ConfigurationBinder.cs +++ /dev/null @@ -1,182 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - public sealed partial class ConfigurationBindingGenerator - { - private sealed partial class Emitter - { - private bool ShouldEmitMethods(MethodsToGen_ConfigurationBinder methods) => (_sourceGenSpec.MethodsToGen_ConfigurationBinder & methods) != 0; - - private void EmitBinder_Extensions_IConfiguration() - { - Debug.Assert(_sourceGenSpec.TypesForGen_ConfigurationBinder_BindMethods.Count <= 3 && - !_sourceGenSpec.TypesForGen_ConfigurationBinder_BindMethods.Keys.Any(overload => (overload & MethodsToGen_ConfigurationBinder.Bind) is 0)); - - if (!ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Any)) - { - return; - } - - _emitBlankLineBeforeNextStatement = false; - EmitRootBindingClassStartBlock(Identifier.GeneratedConfigurationBinder); - - EmitGetMethods(); - EmitGetValueMethods(); - EmitBindMethods_ConfigurationBinder(); - - EmitEndBlock(); - _emitBlankLineBeforeNextStatement = true; - } - - private void EmitGetMethods() - { - const string expressionForGetCore = $"{FullyQualifiedDisplayString.CoreBindingHelper}.{nameof(MethodsToGen_CoreBindingHelper.GetCore)}"; - const string documentation = "Attempts to bind the configuration instance to a new instance of type T."; - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_T)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static T? {Identifier.Get}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}) => " + - $"(T?)({expressionForGetCore}({Identifier.configuration}, typeof(T), {Identifier.configureOptions}: null) ?? default(T));"); - } - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_T_BinderOptions)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static T? {Identifier.Get}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, {FullyQualifiedDisplayString.ActionOfBinderOptions}? {Identifier.configureOptions}) => " + - $"(T?)({expressionForGetCore}({Identifier.configuration}, typeof(T), {Identifier.configureOptions}) ?? default(T));"); - } - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_TypeOf)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static object? {Identifier.Get}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, {FullyQualifiedDisplayString.Type} {Identifier.type}) => " + - $"{expressionForGetCore}({Identifier.configuration}, {Identifier.type}, {Identifier.configureOptions}: null);"); - } - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_TypeOf_BinderOptions)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static object? {Identifier.Get}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, {FullyQualifiedDisplayString.Type} {Identifier.type}, {FullyQualifiedDisplayString.ActionOfBinderOptions}? {Identifier.configureOptions}) => " + - $"{expressionForGetCore}({Identifier.configuration}, {Identifier.type}, {Identifier.configureOptions});"); - } - } - - private void EmitGetValueMethods() - { - const string expressionForGetValueCore = $"{FullyQualifiedDisplayString.CoreBindingHelper}.{nameof(MethodsToGen_CoreBindingHelper.GetValueCore)}"; - const string documentation = "Extracts the value with the specified key and converts it to the specified type."; - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_T_key)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static T? {Identifier.GetValue}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, string {Identifier.key}) => " + - $"(T?)({expressionForGetValueCore}({Identifier.configuration}, typeof(T), {Identifier.key}) ?? default(T));"); - } - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_T_key_defaultValue)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static T? {Identifier.GetValue}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, string {Identifier.key}, T {Identifier.defaultValue}) => " + - $"(T?)({expressionForGetValueCore}({Identifier.configuration}, typeof(T), {Identifier.key}) ?? {Identifier.defaultValue});"); - } - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static object? {Identifier.GetValue}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, {FullyQualifiedDisplayString.Type} {Identifier.type}, string {Identifier.key}) => " + - $"{expressionForGetValueCore}({Identifier.configuration}, {Identifier.type}, {Identifier.key});"); - } - - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key_defaultValue)) - { - StartMethodDefinition(documentation); - _writer.WriteLine($"public static object? {Identifier.GetValue}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, {FullyQualifiedDisplayString.Type} {Identifier.type}, string {Identifier.key}, object? {Identifier.defaultValue}) => " + - $"{expressionForGetValueCore}({Identifier.configuration}, {Identifier.type}, {Identifier.key}) ?? {Identifier.defaultValue};"); - } - } - - private void EmitBindMethods_ConfigurationBinder() - { - if (!ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Bind)) - { - return; - } - - Dictionary> types = _sourceGenSpec.TypesForGen_ConfigurationBinder_BindMethods; - - if (types.TryGetValue(MethodsToGen_ConfigurationBinder.Bind_instance, out HashSet? typeSpecs)) - { - foreach (TypeSpec type in typeSpecs) - { - EmitMethodImplementation( - type, - additionalParams: GetObjParameter(type), - configExpression: Identifier.configuration, - configureOptions: false); - } - } - - if (types.TryGetValue(MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions, out typeSpecs)) - { - foreach (TypeSpec type in typeSpecs) - { - EmitMethodImplementation( - type, - additionalParams: $"{GetObjParameter(type)}, {FullyQualifiedDisplayString.ActionOfBinderOptions}? {Identifier.configureOptions}", - configExpression: Identifier.configuration, - configureOptions: true); - } - } - - if (types.TryGetValue(MethodsToGen_ConfigurationBinder.Bind_key_instance, out typeSpecs)) - { - foreach (TypeSpec type in typeSpecs) - { - EmitMethodImplementation( - type, - additionalParams: $"string {Identifier.key}, {GetObjParameter(type)}", - configExpression: $"{Expression.configurationGetSection}({Identifier.key})", - configureOptions: false); - } - } - - void EmitMethodImplementation(TypeSpec type, string additionalParams, string configExpression, bool configureOptions) - { - string binderOptionsArg = configureOptions ? $"{Expression.GetBinderOptions}({Identifier.configureOptions})" : $"{Identifier.binderOptions}: null"; - - string returnExpression; - if (type.CanInitialize) - { - returnExpression = type.NeedsMemberBinding - ? $"{FullyQualifiedDisplayString.CoreBindingHelper}.{nameof(MethodsToGen_CoreBindingHelper.BindCore)}({configExpression}, ref {Identifier.obj}, {binderOptionsArg})" - : "{ }"; - } - else - { - returnExpression = GetInitException(type.InitExceptionMessage); - } - - StartMethodDefinition("Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively."); - _writer.WriteLine($"public static void {Identifier.Bind}(this {FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}, {additionalParams}) => " - + $"{returnExpression};"); - } - - string GetObjParameter(TypeSpec type) => $"{type.FullyQualifiedDisplayString} {Identifier.obj}"; - } - - private void StartMethodDefinition(string documentation) - { - EmitBlankLineIfRequired(); - _writer.WriteLine($"/// {documentation}"); - } - } - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/CoreBindingHelper.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/CoreBindingHelper.cs deleted file mode 100644 index a7b42b1329804e..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/CoreBindingHelper.cs +++ /dev/null @@ -1,955 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; -using System.Linq; -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - public sealed partial class ConfigurationBindingGenerator - { - private sealed partial class Emitter - { - private bool ShouldEmitMethods(MethodsToGen_CoreBindingHelper methods) => (_sourceGenSpec.MethodsToGen_CoreBindingHelper & methods) != 0; - - private void Emit_CoreBindingHelper() - { - Debug.Assert(_emitBlankLineBeforeNextStatement); - _writer.WriteLine(); - _emitBlankLineBeforeNextStatement = false; - - EmitStartBlock($"namespace {ProjectName}"); - EmitHelperUsingStatements(); - - _writer.WriteLine(); - - EmitStartBlock($$""" - /// Provide core binding logic. - {{GetGeneratedCodeAttributeSrc()}} - file static class {{Identifier.CoreBindingHelper}} - """); - - EmitConfigurationKeyCaches(); - EmitGetCoreMethod(); - EmitGetValueCoreMethod(); - EmitBindCoreUntypedMethod(); - EmitBindCoreMethods(); - EmitInitializeMethods(); - EmitHelperMethods(); - - EmitEndBlock(); // End helper class. - EmitEndBlock(); // End namespace. - } - - private void EmitHelperUsingStatements() - { - foreach (string @namespace in _sourceGenSpec.TypeNamespaces.ToImmutableSortedSet()) - { - _writer.WriteLine($"using {@namespace};"); - } - } - - private void EmitConfigurationKeyCaches() - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.BindCore, out HashSet targetTypes)) - { - return; - } - - foreach (TypeSpec type in targetTypes) - { - if (type is not ObjectSpec objectType) - { - continue; - } - - HashSet keys = new(objectType.ConstructorParameters.Select(m => GetCacheElement(m))); - keys.UnionWith(objectType.Properties.Values.Select(m => GetCacheElement(m))); - static string GetCacheElement(MemberSpec member) => $@"""{member.ConfigurationKeyName}"""; - - string configKeysSource = string.Join(", ", keys); - string fieldName = GetConfigKeyCacheFieldName(objectType); - _writer.WriteLine($@"private readonly static Lazy<{MinimalDisplayString.HashSetOfString}> {fieldName} = new(() => new {MinimalDisplayString.HashSetOfString}(StringComparer.OrdinalIgnoreCase) {{ {configKeysSource} }});"); - } - - _emitBlankLineBeforeNextStatement = true; - } - - private void EmitGetCoreMethod() - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.GetCore, out HashSet? types)) - { - return; - } - - EmitBlankLineIfRequired(); - EmitStartBlock($"public static object? {nameof(MethodsToGen_CoreBindingHelper.GetCore)}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, Action<{Identifier.BinderOptions}>? {Identifier.configureOptions})"); - - EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); - - _writer.WriteLine($"{Identifier.BinderOptions}? {Identifier.binderOptions} = {Identifier.GetBinderOptions}({Identifier.configureOptions});"); - _writer.WriteLine(); - - EmitIConfigurationHasValueOrChildrenCheck(voidReturn: false); - - foreach (TypeSpec type in types) - { - TypeSpec effectiveType = type.EffectiveType; - TypeSpecKind kind = effectiveType.SpecKind; - - EmitStartBlock($"if (type == typeof({type.MinimalDisplayString}))"); - - if (effectiveType is ParsableFromStringSpec stringParsableType) - { - EmitCastToIConfigurationSection(); - EmitBindLogicFromString( - stringParsableType, - Expression.sectionValue, - Expression.sectionPath, - writeOnSuccess: parsedValueExpr => _writer.WriteLine($"return {parsedValueExpr};"), - checkForNullSectionValue: stringParsableType.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue, - useIncrementalStringValueIdentifier: false); - } - else if (!EmitInitException(effectiveType)) - { - EmitBindCoreCall(effectiveType, Identifier.obj, Identifier.configuration, InitializationKind.Declaration); - _writer.WriteLine($"return {Identifier.obj};"); - } - - EmitEndBlock(); - _writer.WriteLine(); - } - - Emit_NotSupportedException_TypeNotDetectedAsInput(); - EmitEndBlock(); - _emitBlankLineBeforeNextStatement = true; - } - - private void EmitGetValueCoreMethod() - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.GetValueCore, out HashSet? targetTypes)) - { - return; - } - - EmitBlankLineIfRequired(); - EmitStartBlock($"public static object? {nameof(MethodsToGen_CoreBindingHelper.GetValueCore)}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, string {Identifier.key})"); - - EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); - _writer.WriteLine($@"{Identifier.IConfigurationSection} {Identifier.section} = {GetSectionFromConfigurationExpression(Identifier.key, addQuotes: false)};"); - _writer.WriteLine(); - - _writer.WriteLine($$""" - if ({{Expression.sectionValue}} is not string {{Identifier.value}}) - { - return null; - } - """); - - _writer.WriteLine(); - - foreach (TypeSpec type in targetTypes) - { - EmitStartBlock($"if ({Identifier.type} == typeof({type.MinimalDisplayString}))"); - - EmitBindLogicFromString( - (ParsableFromStringSpec)type.EffectiveType, - Identifier.value, - Expression.sectionPath, - writeOnSuccess: (parsedValueExpr) => _writer.WriteLine($"return {parsedValueExpr};"), - checkForNullSectionValue: false, - useIncrementalStringValueIdentifier: false); - - EmitEndBlock(); - _writer.WriteLine(); - } - - _writer.WriteLine("return null;"); - EmitEndBlock(); - _emitBlankLineBeforeNextStatement = true; - } - - private void EmitBindCoreUntypedMethod() - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.BindCoreUntyped, out HashSet? targetTypes)) - { - return; - } - - EmitBlankLineIfRequired(); - - EmitStartBlock($"public static void {nameof(MethodsToGen_CoreBindingHelper.BindCoreUntyped)}(this {Identifier.IConfiguration} {Identifier.configuration}, object {Identifier.obj}, Type {Identifier.type}, {MinimalDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions})"); - - EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); - - _writer.WriteLine($"{Identifier.BinderOptions}? {Identifier.binderOptions} = {Identifier.GetBinderOptions}({Identifier.configureOptions});"); - _writer.WriteLine(); - - EmitIConfigurationHasValueOrChildrenCheck(voidReturn: true); - - foreach (TypeSpec type in targetTypes) - { - EmitStartBlock($"if (type == typeof({type.MinimalDisplayString}))"); - - TypeSpec effectiveType = type.EffectiveType; - if (!EmitInitException(effectiveType)) - { - _writer.WriteLine($"var {Identifier.temp} = ({effectiveType.MinimalDisplayString}){Identifier.obj};"); - EmitBindCoreCall(type, Identifier.temp, Identifier.configuration, InitializationKind.None); - _writer.WriteLine($"return;"); - } - - EmitEndBlock(); - _writer.WriteLine(); - } - - Emit_NotSupportedException_TypeNotDetectedAsInput(); - EmitEndBlock(); - _emitBlankLineBeforeNextStatement = true; - } - - private void EmitBindCoreMethods() - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.BindCore, out HashSet? targetTypes)) - { - return; - } - - foreach (TypeSpec type in targetTypes) - { - Debug.Assert(type.NeedsMemberBinding); - EmitBlankLineIfRequired(); - EmitBindCoreMethod(type); - } - } - - private void EmitBindCoreMethod(TypeSpec type) - { - Debug.Assert(type.CanInitialize); - - string objParameterExpression = $"ref {type.MinimalDisplayString} {Identifier.obj}"; - EmitStartBlock(@$"public static void {nameof(MethodsToGen_CoreBindingHelper.BindCore)}({Identifier.IConfiguration} {Identifier.configuration}, {objParameterExpression}, {Identifier.BinderOptions}? {Identifier.binderOptions})"); - - EmitCheckForNullArgument_WithBlankLine_IfRequired(type.IsValueType); - - TypeSpec effectiveType = type.EffectiveType; - if (effectiveType is EnumerableSpec enumerable) - { - if (effectiveType.InitializationStrategy is InitializationStrategy.Array) - { - Debug.Assert(type == effectiveType); - EmitPopulationImplForArray((EnumerableSpec)type); - } - else - { - EmitPopulationImplForEnumerableWithAdd(enumerable); - } - } - else if (effectiveType is DictionarySpec dictionary) - { - EmitBindCoreImplForDictionary(dictionary); - } - else - { - EmitBindCoreImplForObject((ObjectSpec)effectiveType); - } - - EmitEndBlock(); - } - - private void EmitInitializeMethods() - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.Initialize, out HashSet? targetTypes)) - { - return; - } - - foreach (ObjectSpec type in targetTypes) - { - EmitBlankLineIfRequired(); - EmitInitializeMethod(type); - } - } - - private void EmitInitializeMethod(ObjectSpec type) - { - Debug.Assert(type.CanInitialize); - List ctorParams = type.ConstructorParameters; - IEnumerable initOnlyProps = type.Properties.Values.Where(prop => prop is { SetOnInit: true }); - List ctorArgList = new(); - string displayString = type.MinimalDisplayString; - - EmitStartBlock($"public static {type.MinimalDisplayString} {GetInitalizeMethodDisplayString(type)}({Identifier.IConfiguration} {Identifier.configuration}, {Identifier.BinderOptions}? {Identifier.binderOptions})"); - _emitBlankLineBeforeNextStatement = false; - - foreach (ParameterSpec parameter in ctorParams) - { - string name = parameter.Name; - string argExpr = parameter.RefKind switch - { - RefKind.None => name, - RefKind.Ref => $"ref {name}", - RefKind.Out => "out _", - RefKind.In => $"in {name}", - _ => throw new InvalidOperationException() - }; - - ctorArgList.Add(argExpr); - EmitBindImplForMember(parameter); - } - - foreach (PropertySpec property in initOnlyProps) - { - if (property.ShouldBind() && property.MatchingCtorParam is null) - { - EmitBindImplForMember(property); - } - } - - string returnExpression = $"return new {displayString}({string.Join(", ", ctorArgList)})"; - if (!initOnlyProps.Any()) - { - _writer.WriteLine($"{returnExpression};"); - } - else - { - EmitStartBlock(returnExpression); - foreach (PropertySpec property in initOnlyProps) - { - string propertyName = property.Name; - _writer.WriteLine($@"{propertyName} = {propertyName},"); - } - EmitEndBlock(endBraceTrailingSource: ";"); - } - - // End method. - EmitEndBlock(); - _emitBlankLineBeforeNextStatement = true; - - void EmitBindImplForMember(MemberSpec member) - { - TypeSpec memberType = member.Type; - bool errorOnFailedBinding = member.ErrorOnFailedBinding; - - string parsedMemberIdentifierDeclarationPrefix = $"{memberType.MinimalDisplayString} {member.Name}"; - string parsedMemberIdentifier; - - if (memberType is ParsableFromStringSpec { StringParsableTypeKind: StringParsableTypeKind.AssignFromSectionValue }) - { - parsedMemberIdentifier = parsedMemberIdentifierDeclarationPrefix; - - if (errorOnFailedBinding) - { - string condition = $@" if ({Identifier.configuration}[""{member.ConfigurationKeyName}""] is not {memberType.MinimalDisplayString} {member.Name})"; - EmitThrowBlock(condition); - _writer.WriteLine(); - return; - } - } - else - { - parsedMemberIdentifier = member.Name; - - string declarationSuffix; - if (errorOnFailedBinding) - { - declarationSuffix = ";"; - } - else - { - string bangExpr = memberType.IsValueType ? string.Empty : "!"; - declarationSuffix = memberType.CanInitialize - ? $" = {member.DefaultValueExpr}{bangExpr};" - : ";"; - } - - string parsedMemberIdentifierDeclaration = $"{parsedMemberIdentifierDeclarationPrefix}{declarationSuffix}"; - _writer.WriteLine(parsedMemberIdentifierDeclaration); - _emitBlankLineBeforeNextStatement = false; - } - - bool canBindToMember = this.EmitBindImplForMember( - member, - parsedMemberIdentifier, - sectionPathExpr: GetSectionPathFromConfigurationExpression(member.ConfigurationKeyName), - canSet: true); - - if (canBindToMember) - { - if (errorOnFailedBinding) - { - // Add exception logic for parameter ctors; must be present in configuration object. - EmitThrowBlock(condition: "else"); - } - - _writer.WriteLine(); - } - - void EmitThrowBlock(string condition) => - _writer.WriteLine($$""" - {{condition}} - { - throw new {{GetInvalidOperationDisplayName()}}("{{string.Format(ExceptionMessages.ParameterHasNoMatchingConfig, type.Name, member.Name)}}"); - } - """); - } - } - private void EmitHelperMethods() - { - if (ShouldEmitMethods(MethodsToGen_CoreBindingHelper.BindCore)) - { - EmitValidateConfigurationKeysMethod(); - } - - if (ShouldEmitMethods(MethodsToGen_CoreBindingHelper.BindCoreUntyped | MethodsToGen_CoreBindingHelper.GetCore)) - { - _writer.WriteLine(); - EmitHasValueOrChildrenMethod(); - _writer.WriteLine(); - EmitAsConfigWithChildrenMethod(); - _emitBlankLineBeforeNextStatement = true; - } - else if (ShouldEmitMethods(MethodsToGen_CoreBindingHelper.AsConfigWithChildren)) - { - _writer.WriteLine(); - EmitAsConfigWithChildrenMethod(); - _emitBlankLineBeforeNextStatement = true; - } - - if (ShouldEmitMethods( - MethodsToGen_CoreBindingHelper.BindCoreUntyped | MethodsToGen_CoreBindingHelper.GetCore) || - ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions)) - { - _writer.WriteLine(); - EmitGetBinderOptionsHelper(); - _emitBlankLineBeforeNextStatement = true; - } - - bool enumTypeExists = false; - - foreach (ParsableFromStringSpec type in _sourceGenSpec.PrimitivesForHelperGen) - { - EmitBlankLineIfRequired(); - - if (type.StringParsableTypeKind == StringParsableTypeKind.Enum) - { - if (!enumTypeExists) - { - EmitEnumParseMethod(); - enumTypeExists = true; - } - } - else - { - EmitPrimitiveParseMethod(type); - } - } - } - - private void EmitValidateConfigurationKeysMethod() - { - const string keysIdentifier = "keys"; - string exceptionMessage = string.Format(ExceptionMessages.MissingConfig, Identifier.ErrorOnUnknownConfiguration, Identifier.BinderOptions, $"{{{Identifier.type}}}", $@"{{string.Join("", "", {Identifier.temp})}}"); - - EmitBlankLineIfRequired(); - _writer.WriteLine($$""" - /// If required by the binder options, validates that there are no unknown keys in the input configuration object. - public static void {{Identifier.ValidateConfigurationKeys}}(Type {{Identifier.type}}, {{MinimalDisplayString.LazyHashSetOfString}} {{keysIdentifier}}, {{Identifier.IConfiguration}} {{Identifier.configuration}}, {{Identifier.BinderOptions}}? {{Identifier.binderOptions}}) - { - if ({{Identifier.binderOptions}}?.{{Identifier.ErrorOnUnknownConfiguration}} is true) - { - {{MinimalDisplayString.ListOfString}}? {{Identifier.temp}} = null; - - foreach ({{Identifier.IConfigurationSection}} {{Identifier.section}} in {{Identifier.configuration}}.{{Identifier.GetChildren}}()) - { - if (!{{keysIdentifier}}.Value.Contains({{Expression.sectionKey}})) - { - ({{Identifier.temp}} ??= new {{MinimalDisplayString.ListOfString}}()).Add($"'{{{Expression.sectionKey}}}'"); - } - } - - if ({{Identifier.temp}} is not null) - { - throw new InvalidOperationException($"{{exceptionMessage}}"); - } - } - } - """); - } - - private void EmitHasValueOrChildrenMethod() - { - _writer.WriteLine($$""" - public static bool {{Identifier.HasValueOrChildren}}({{Identifier.IConfiguration}} {{Identifier.configuration}}) - { - if (({{Identifier.configuration}} as {{Identifier.IConfigurationSection}})?.{{Identifier.Value}} is not null) - { - return true; - } - return {{Identifier.AsConfigWithChildren}}({{Identifier.configuration}}) is not null; - } - """); - } - - private void EmitAsConfigWithChildrenMethod() - { - _writer.WriteLine($$""" - public static {{Identifier.IConfiguration}}? {{Identifier.AsConfigWithChildren}}({{Identifier.IConfiguration}} {{Identifier.configuration}}) - { - foreach ({{Identifier.IConfigurationSection}} _ in {{Identifier.configuration}}.{{Identifier.GetChildren}}()) - { - return {{Identifier.configuration}}; - } - return null; - } - """); - } - - private void EmitGetBinderOptionsHelper() - { - _writer.WriteLine($$""" - public static {{Identifier.BinderOptions}}? {{Identifier.GetBinderOptions}}({{MinimalDisplayString.NullableActionOfBinderOptions}} {{Identifier.configureOptions}}) - { - if ({{Identifier.configureOptions}} is null) - { - return null; - } - - {{Identifier.BinderOptions}} {{Identifier.binderOptions}} = new(); - {{Identifier.configureOptions}}({{Identifier.binderOptions}}); - - if ({{Identifier.binderOptions}}.BindNonPublicProperties) - { - throw new global::System.NotSupportedException($"{{string.Format(ExceptionMessages.CannotSpecifyBindNonPublicProperties)}}"); - } - - return {{Identifier.binderOptions}}; - } - """); - } - - private void EmitEnumParseMethod() - { - string innerExceptionTypeDisplayString = _useFullyQualifiedNames ? "global::System.Exception" : "Exception"; - string exceptionArg1 = string.Format(ExceptionMessages.FailedBinding, $"{{{Identifier.getPath}()}}", $"{{typeof(T)}}"); - - _writer.WriteLine($$""" - public static T ParseEnum(string value, Func getPath) where T : struct - { - try - { - #if NETFRAMEWORK || NETSTANDARD2_0 - return (T)Enum.Parse(typeof(T), value, ignoreCase: true); - #else - return Enum.Parse(value, ignoreCase: true); - #endif - } - catch ({{innerExceptionTypeDisplayString}} {{Identifier.exception}}) - { - throw new {{GetInvalidOperationDisplayName()}}($"{{exceptionArg1}}", {{Identifier.exception}}); - } - } - """); - } - - private void EmitPrimitiveParseMethod(ParsableFromStringSpec type) - { - string innerExceptionTypeDisplayString; - string cultureInfoTypeDisplayString; - string numberStylesTypeDisplayString; - - if (_useFullyQualifiedNames) - { - innerExceptionTypeDisplayString = "global::System.Exception"; - cultureInfoTypeDisplayString = "global::System.Globalization.CultureInfo"; - numberStylesTypeDisplayString = "global::System.Globalization.NumberStyles"; - } - else - { - innerExceptionTypeDisplayString = "Exception"; - cultureInfoTypeDisplayString = "CultureInfo"; - numberStylesTypeDisplayString = "NumberStyles"; - } - - StringParsableTypeKind typeKind = type.StringParsableTypeKind; - string typeDisplayString = type.MinimalDisplayString; - - string invariantCultureExpression = $"{cultureInfoTypeDisplayString}.InvariantCulture"; - - string parsedValueExpr; - switch (typeKind) - { - case StringParsableTypeKind.Enum: - return; - case StringParsableTypeKind.ByteArray: - { - parsedValueExpr = $"Convert.FromBase64String({Identifier.value})"; - } - break; - case StringParsableTypeKind.Integer: - { - parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value}, {numberStylesTypeDisplayString}.Integer, {invariantCultureExpression})"; - } - break; - case StringParsableTypeKind.Float: - { - parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value}, {numberStylesTypeDisplayString}.Float, {invariantCultureExpression})"; - } - break; - case StringParsableTypeKind.Parse: - { - parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value})"; - } - break; - case StringParsableTypeKind.ParseInvariant: - { - parsedValueExpr = $"{typeDisplayString}.{Identifier.Parse}({Identifier.value}, {invariantCultureExpression})"; ; - } - break; - case StringParsableTypeKind.CultureInfo: - { - parsedValueExpr = $"{cultureInfoTypeDisplayString}.GetCultureInfo({Identifier.value})"; - } - break; - case StringParsableTypeKind.Uri: - { - parsedValueExpr = $"new Uri({Identifier.value}, UriKind.RelativeOrAbsolute)"; - } - break; - default: - { - Debug.Fail($"Invalid string parsable kind: {typeKind}"); - return; - } - } - - string exceptionArg1 = string.Format(ExceptionMessages.FailedBinding, $"{{{Identifier.getPath}()}}", $"{{typeof({typeDisplayString})}}"); - - EmitStartBlock($"public static {typeDisplayString} {type.ParseMethodName}(string {Identifier.value}, Func {Identifier.getPath})"); - EmitEndBlock($$""" - try - { - return {{parsedValueExpr}}; - } - catch ({{innerExceptionTypeDisplayString}} {{Identifier.exception}}) - { - throw new {{GetInvalidOperationDisplayName()}}($"{{exceptionArg1}}", {{Identifier.exception}}); - } - """); - } - - private void EmitPopulationImplForArray(EnumerableSpec type) - { - EnumerableSpec concreteType = (EnumerableSpec)type.ConcreteType; - - // Create list and bind elements. - string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); - EmitBindCoreCall(concreteType, tempIdentifier, Identifier.configuration, InitializationKind.Declaration); - - // Resize array and add binded elements. - _writer.WriteLine($$""" - {{Identifier.Int32}} {{Identifier.originalCount}} = {{Identifier.obj}}.{{Identifier.Length}}; - {{Identifier.Array}}.{{Identifier.Resize}}(ref {{Identifier.obj}}, {{Identifier.originalCount}} + {{tempIdentifier}}.{{Identifier.Count}}); - {{tempIdentifier}}.{{Identifier.CopyTo}}({{Identifier.obj}}, {{Identifier.originalCount}}); - """); - } - - private void EmitPopulationImplForEnumerableWithAdd(EnumerableSpec type) - { - EmitCollectionCastIfRequired(type, out string objIdentifier); - - Emit_Foreach_Section_In_ConfigChildren_StartBlock(); - - TypeSpec elementType = type.ElementType; - - if (elementType is ParsableFromStringSpec stringParsableType) - { - EmitBindLogicFromString( - stringParsableType, - Expression.sectionValue, - Expression.sectionPath, - (parsedValueExpr) => _writer.WriteLine($"{objIdentifier}.{Identifier.Add}({parsedValueExpr});"), - checkForNullSectionValue: true, - useIncrementalStringValueIdentifier: false); - } - else - { - EmitBindCoreCall(elementType, Identifier.value, Identifier.section, InitializationKind.Declaration); - _writer.WriteLine($"{objIdentifier}.{Identifier.Add}({Identifier.value});"); - } - - EmitEndBlock(); - } - - private void EmitBindCoreImplForDictionary(DictionarySpec type) - { - EmitCollectionCastIfRequired(type, out string objIdentifier); - - Emit_Foreach_Section_In_ConfigChildren_StartBlock(); - - ParsableFromStringSpec keyType = type.KeyType; - TypeSpec elementType = type.ElementType; - - // Parse key - EmitBindLogicFromString( - keyType, - Expression.sectionKey, - Expression.sectionPath, - Emit_BindAndAddLogic_ForElement, - checkForNullSectionValue: false, - useIncrementalStringValueIdentifier: false); - - void Emit_BindAndAddLogic_ForElement(string parsedKeyExpr) - { - if (elementType is ParsableFromStringSpec stringParsableElementType) - { - EmitBindLogicFromString( - stringParsableElementType, - Expression.sectionValue, - Expression.sectionPath, - writeOnSuccess: parsedValueExpr => _writer.WriteLine($"{objIdentifier}[{parsedKeyExpr}] = {parsedValueExpr};"), - checkForNullSectionValue: true, - useIncrementalStringValueIdentifier: false); - } - else // For complex types: - { - Debug.Assert(elementType.CanInitialize); - - if (keyType.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue) - { - // Save value to local to avoid parsing twice - during look-up and during add. - _writer.WriteLine($"{keyType.MinimalDisplayString} {Identifier.key} = {parsedKeyExpr};"); - parsedKeyExpr = Identifier.key; - } - - bool isValueType = elementType.IsValueType; - string expressionForElementIsNotNull = $"{Identifier.element} is not null"; - string elementTypeDisplayString = elementType.MinimalDisplayString + (elementType.IsValueType ? string.Empty : "?"); - - string expressionForElementExists = $"{objIdentifier}.{Identifier.TryGetValue}({parsedKeyExpr}, out {elementTypeDisplayString} {Identifier.element})"; - string conditionToUseExistingElement = expressionForElementExists; - - // If key already exists, bind to existing element instance if not null (for ref types). - if (!isValueType) - { - conditionToUseExistingElement += $" && {expressionForElementIsNotNull}"; - } - - EmitStartBlock($"if (!({conditionToUseExistingElement}))"); - EmitObjectInit(elementType, Identifier.element, InitializationKind.SimpleAssignment, Identifier.section); - EmitEndBlock(); - - if (elementType is CollectionSpec { InitializationStrategy: InitializationStrategy.ParameterizedConstructor or InitializationStrategy.ToEnumerableMethod } collectionSpec) - { - // This is a read-only collection. If the element exists and is not null, - // we need to copy its contents into a new instance & then append/bind to that. - - string initExpression = collectionSpec.InitializationStrategy is InitializationStrategy.ParameterizedConstructor - ? $"new {collectionSpec.ConcreteType.MinimalDisplayString}({Identifier.element})" - : $"{Identifier.element}.{collectionSpec.ToEnumerableMethodCall!}"; - - _writer.WriteLine($$""" - else - { - {{Identifier.element}} = {{initExpression}}; - } - """); - } - - EmitBindCoreCall(elementType, Identifier.element, Identifier.section, InitializationKind.None); - _writer.WriteLine($"{objIdentifier}[{parsedKeyExpr}] = {Identifier.element};"); - } - } - - EmitEndBlock(); - } - - private void EmitBindCoreImplForObject(ObjectSpec type) - { - Debug.Assert(type.NeedsMemberBinding); - - string keyCacheFieldName = GetConfigKeyCacheFieldName(type); - string validateMethodCallExpr = $"{Identifier.ValidateConfigurationKeys}(typeof({type.MinimalDisplayString}), {keyCacheFieldName}, {Identifier.configuration}, {Identifier.binderOptions});"; - _writer.WriteLine(validateMethodCallExpr); - - foreach (PropertySpec property in type.Properties.Values) - { - bool noSetter_And_IsReadonly = !property.CanSet && property.Type is CollectionSpec { InitializationStrategy: InitializationStrategy.ParameterizedConstructor }; - if (property.ShouldBind() && !noSetter_And_IsReadonly) - { - string containingTypeRef = property.IsStatic ? type.MinimalDisplayString : Identifier.obj; - EmitBindImplForMember( - property, - memberAccessExpr: $"{containingTypeRef}.{property.Name}", - GetSectionPathFromConfigurationExpression(property.ConfigurationKeyName), - canSet: property.CanSet); - } - } - } - - private bool EmitBindImplForMember( - MemberSpec member, - string memberAccessExpr, - string sectionPathExpr, - bool canSet) - { - TypeSpec effectiveMemberType = member.Type.EffectiveType; - - if (effectiveMemberType is ParsableFromStringSpec stringParsableType) - { - if (canSet) - { - bool checkForNullSectionValue = member is ParameterSpec - ? true - : stringParsableType.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue; - - string nullBangExpr = checkForNullSectionValue ? string.Empty : "!"; - - EmitBlankLineIfRequired(); - EmitBindLogicFromString( - stringParsableType, - $@"{Identifier.configuration}[""{member.ConfigurationKeyName}""]", - sectionPathExpr, - writeOnSuccess: parsedValueExpr => _writer.WriteLine($"{memberAccessExpr} = {parsedValueExpr}{nullBangExpr};"), - checkForNullSectionValue, - useIncrementalStringValueIdentifier: true); - } - - return true; - } - - string sectionParseExpr = GetSectionFromConfigurationExpression(member.ConfigurationKeyName); - - EmitBlankLineIfRequired(); - - if (effectiveMemberType.SpecKind is TypeSpecKind.IConfigurationSection) - { - _writer.WriteLine($"{memberAccessExpr} = {sectionParseExpr};"); - return true; - } - - string sectionValidationCall = $"{Identifier.AsConfigWithChildren}({sectionParseExpr})"; - string sectionIdentifier = GetIncrementalIdentifier(Identifier.section); - - EmitStartBlock($"if ({sectionValidationCall} is {Identifier.IConfigurationSection} {sectionIdentifier})"); - - bool canInit = !EmitInitException(effectiveMemberType); - if (canInit) - { - EmitBindCoreCallForMember(member, memberAccessExpr, sectionIdentifier, canSet); - } - - EmitEndBlock(); - return canInit; - } - - private void EmitBindCoreCallForMember( - MemberSpec member, - string memberAccessExpr, - string configArgExpr, - bool canSet) - { - - TypeSpec memberType = member.Type; - TypeSpec effectiveMemberType = memberType.EffectiveType; - - string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); - InitializationKind initKind; - string targetObjAccessExpr; - - if (effectiveMemberType.IsValueType) - { - if (!canSet) - { - return; - } - - Debug.Assert(canSet); - string effectiveMemberTypeDisplayString = effectiveMemberType.MinimalDisplayString; - initKind = InitializationKind.None; - - if (memberType.SpecKind is TypeSpecKind.Nullable) - { - string nullableTempIdentifier = GetIncrementalIdentifier(Identifier.temp); - - _writer.WriteLine($"{memberType.MinimalDisplayString} {nullableTempIdentifier} = {memberAccessExpr};"); - - _writer.WriteLine( - $"{effectiveMemberTypeDisplayString} {tempIdentifier} = {nullableTempIdentifier}.{Identifier.HasValue} ? {nullableTempIdentifier}.{Identifier.Value} : new {effectiveMemberTypeDisplayString}();"); - } - else - { - _writer.WriteLine($"{effectiveMemberTypeDisplayString} {tempIdentifier} = {memberAccessExpr};"); - } - - targetObjAccessExpr = tempIdentifier; - } - else if (member.CanGet) - { - targetObjAccessExpr = memberAccessExpr; - initKind = InitializationKind.AssignmentWithNullCheck; - } - else - { - targetObjAccessExpr = memberAccessExpr; - initKind = InitializationKind.SimpleAssignment; - } - - Action? writeOnSuccess = !canSet - ? null - : bindedValueIdentifier => - { - if (memberAccessExpr != bindedValueIdentifier) - { - _writer.WriteLine($"{memberAccessExpr} = {bindedValueIdentifier};"); - } - }; - - EmitBindCoreCall( - effectiveMemberType, - targetObjAccessExpr, - configArgExpr, - initKind, - writeOnSuccess); - } - - private void EmitCollectionCastIfRequired(CollectionSpec type, out string objIdentifier) - { - objIdentifier = Identifier.obj; - if (type.PopulationStrategy is CollectionPopulationStrategy.Cast_Then_Add) - { - objIdentifier = Identifier.temp; - _writer.WriteLine($$""" - if ({{Identifier.obj}} is not {{type.PopulationCastType!.MinimalDisplayString}} {{objIdentifier}}) - { - return; - } - """); - _writer.WriteLine(); - } - } - - private void Emit_Foreach_Section_In_ConfigChildren_StartBlock() => - EmitStartBlock($"foreach ({Identifier.IConfigurationSection} {Identifier.section} in {Identifier.configuration}.{Identifier.GetChildren}())"); - - private static string GetSectionPathFromConfigurationExpression(string configurationKeyName) - => $@"{GetSectionFromConfigurationExpression(configurationKeyName)}.{Identifier.Path}"; - - private static string GetSectionFromConfigurationExpression(string configurationKeyName, bool addQuotes = true) - { - string argExpr = addQuotes ? $@"""{configurationKeyName}""" : configurationKeyName; - return $@"{Expression.configurationGetSection}({argExpr})"; - } - - private static string GetConfigKeyCacheFieldName(ObjectSpec type) => - $"s_configKeys_{type.DisplayStringWithoutSpecialCharacters}"; - - private void Emit_NotSupportedException_TypeNotDetectedAsInput() => - _writer.WriteLine(@$"throw new global::System.NotSupportedException($""{string.Format(ExceptionMessages.TypeNotDetectedAsInput, "{type}")}"");"); - } - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/OptionsBuilderConfigurationExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/OptionsBuilderConfigurationExtensions.cs deleted file mode 100644 index 71d0b6989dd970..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/OptionsBuilderConfigurationExtensions.cs +++ /dev/null @@ -1,103 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - public sealed partial class ConfigurationBindingGenerator - { - private sealed partial class Emitter - { - private bool ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder methods) => (_sourceGenSpec.MethodsToGen_OptionsBuilderExt & methods) != 0; - - private void EmitBinder_Extensions_OptionsBuilder() - { - if (!ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Any)) - { - return; - } - - EmitRootBindingClassStartBlock(Identifier.GeneratedOptionsBuilderBinder); - - EmitBindMethods_Extensions_OptionsBuilder(); - EmitBindConfigurationMethod(); - - EmitEndBlock(); - } - - private void EmitBindMethods_Extensions_OptionsBuilder() - { - if (!ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Bind)) - { - return; - } - - const string documentation = @"/// Registers a configuration instance which will bind against."; - const string paramList = $"{FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}"; - - if (ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Bind_T)) - { - EmitMethodStartBlock("Bind", paramList, documentation); - _writer.WriteLine($"return global::{Identifier.GeneratedOptionsBuilderBinder}.Bind({Identifier.optionsBuilder}, {Identifier.configuration}, {Identifier.configureOptions}: null);"); - EmitEndBlock(); - } - - EmitMethodStartBlock( - "Bind", - paramList + $", {FullyQualifiedDisplayString.ActionOfBinderOptions}? {Identifier.configureOptions}", - documentation); - - EmitCheckForNullArgument_WithBlankLine(Identifier.optionsBuilder); - - _writer.WriteLine($$""" - global::{{Identifier.GeneratedServiceCollectionBinder}}.{{Identifier.Configure}}<{{Identifier.TOptions}}>({{Identifier.optionsBuilder}}.{{Identifier.Services}}, {{Identifier.optionsBuilder}}.Name, {{Identifier.configuration}}, {{Identifier.configureOptions}}); - return {{Identifier.optionsBuilder}}; - """); - - EmitEndBlock(); - } - - private void EmitBindConfigurationMethod() - { - if (!ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.BindConfiguration_T_path_BinderOptions)) - { - return; - } - - const string documentation = $@"/// Registers the dependency injection container to bind against the obtained from the DI service provider."; - string paramList = $"string {Identifier.configSectionPath}, {FullyQualifiedDisplayString.ActionOfBinderOptions}? {Identifier.configureOptions} = null"; - - EmitMethodStartBlock("BindConfiguration", paramList, documentation); - - EmitCheckForNullArgument_WithBlankLine(Identifier.optionsBuilder); - EmitCheckForNullArgument_WithBlankLine(Identifier.configSectionPath); - - EmitStartBlock($"{Identifier.optionsBuilder}.{Identifier.Configure}<{FullyQualifiedDisplayString.IConfiguration}>(({Identifier.obj}, {Identifier.configuration}) =>"); - - _writer.WriteLine($$""" - {{FullyQualifiedDisplayString.IConfiguration}} {{Identifier.section}} = string.Equals(string.Empty, {{Identifier.configSectionPath}}, global::System.StringComparison.OrdinalIgnoreCase) ? {{Identifier.configuration}} : {{Identifier.configuration}}.{{Identifier.GetSection}}({{Identifier.configSectionPath}}); - {{FullyQualifiedDisplayString.CoreBindingHelper}}.{{nameof(MethodsToGen_CoreBindingHelper.BindCoreUntyped)}}({{Identifier.section}}, {{Identifier.obj}}, typeof({{Identifier.TOptions}}), {{Identifier.configureOptions}}); - """); - - EmitEndBlock(endBraceTrailingSource: ");"); - - _writer.WriteLine(); - - _writer.WriteLine($$""" - {{FullyQualifiedDisplayString.AddSingleton}}<{{FullyQualifiedDisplayString.IOptionsChangeTokenSource}}<{{Identifier.TOptions}}>, {{FullyQualifiedDisplayString.ConfigurationChangeTokenSource}}<{{Identifier.TOptions}}>>({{Identifier.optionsBuilder}}.{{Identifier.Services}}); - return {{Identifier.optionsBuilder}}; - """); - - EmitEndBlock(); - } - - private void EmitMethodStartBlock(string methodName, string paramList, string documentation) - { - paramList = $"this {FullyQualifiedDisplayString.OptionsBuilderOfTOptions} {Identifier.optionsBuilder}, {paramList}"; - - EmitBlankLineIfRequired(); - _writer.WriteLine(documentation); - EmitStartBlock($"public static {FullyQualifiedDisplayString.OptionsBuilderOfTOptions} {methodName}<{Identifier.TOptions}>({paramList}) where {Identifier.TOptions} : class"); - } - } - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/OptionsConfigurationServiceCollectionExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/OptionsConfigurationServiceCollectionExtensions.cs deleted file mode 100644 index f4cd4800df1230..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Emitter/OptionsConfigurationServiceCollectionExtensions.cs +++ /dev/null @@ -1,80 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - public sealed partial class ConfigurationBindingGenerator - { - private sealed partial class Emitter - { - private bool ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection methods) => (_sourceGenSpec.MethodsToGen_ServiceCollectionExt & methods) != 0; - - private void EmitBinder_Extensions_IServiceCollection() - { - if (!ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Any)) - { - return; - } - - EmitRootBindingClassStartBlock(Identifier.GeneratedServiceCollectionBinder); - - const string defaultNameExpr = "string.Empty"; - const string configureMethodString = $"global::{Identifier.GeneratedServiceCollectionBinder}.{Identifier.Configure}"; - string configParam = $"{FullyQualifiedDisplayString.IConfiguration} {Identifier.configuration}"; - - if (ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Configure_T)) - { - EmitStartMethod(configParam); - _writer.WriteLine($"return {configureMethodString}<{Identifier.TOptions}>({Identifier.services}, {defaultNameExpr}, {Identifier.configuration}, {Identifier.configureOptions}: null);"); - EmitEndBlock(); - } - - if (ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Configure_T_name)) - { - EmitStartMethod( - paramList: $"string? {Identifier.name}, " + configParam); - _writer.WriteLine($"return {configureMethodString}<{Identifier.TOptions}>({Identifier.services}, {Identifier.name}, {Identifier.configuration}, {Identifier.configureOptions}: null);"); - EmitEndBlock(); - } - - if (ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Configure_T_BinderOptions)) - { - EmitStartMethod( - paramList: configParam + $", {FullyQualifiedDisplayString.ActionOfBinderOptions}? {Identifier.configureOptions}"); - _writer.WriteLine($"return {configureMethodString}<{Identifier.TOptions}>({Identifier.services}, {defaultNameExpr}, {Identifier.configuration}, {Identifier.configureOptions});"); - EmitEndBlock(); - } - - // Core Configure method that the other overloads call. - // Like the others, it is public API that could be called directly by users. - // So, it is always generated whenever a Configure overload is called. - string optionsNamespaceName = "global::Microsoft.Extensions.Options"; - string bindCoreUntypedDisplayString = GetHelperMethodDisplayString(nameof(MethodsToGen_CoreBindingHelper.BindCoreUntyped)); - - EmitStartMethod(paramList: $"string? {Identifier.name}, " + configParam + $", {FullyQualifiedDisplayString.ActionOfBinderOptions}? {Identifier.configureOptions}"); - EmitCheckForNullArgument_WithBlankLine(Identifier.services); - EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); - _writer.WriteLine($$""" - global::Microsoft.Extensions.DependencyInjection.OptionsServiceCollectionExtensions.AddOptions({{Identifier.services}}); - {{FullyQualifiedDisplayString.AddSingleton}}<{{FullyQualifiedDisplayString.IOptionsChangeTokenSource}}<{{Identifier.TOptions}}>>({{Identifier.services}}, new {{FullyQualifiedDisplayString.ConfigurationChangeTokenSource}}<{{Identifier.TOptions}}>({{Identifier.name}}, {{Identifier.configuration}})); - return {{FullyQualifiedDisplayString.AddSingleton}}<{{optionsNamespaceName}}.IConfigureOptions<{{Identifier.TOptions}}>>({{Identifier.services}}, new {{optionsNamespaceName}}.ConfigureNamedOptions<{{Identifier.TOptions}}>({{Identifier.name}}, {{Identifier.obj}} => {{bindCoreUntypedDisplayString}}({{Identifier.configuration}}, {{Identifier.obj}}, typeof({{Identifier.TOptions}}), {{Identifier.configureOptions}}))); - """); - EmitEndBlock(); - - EmitEndBlock(); - _emitBlankLineBeforeNextStatement = true; - } - - private void EmitStartMethod(string paramList) - { - paramList = $"this {FullyQualifiedDisplayString.IServiceCollection} {Identifier.services}, {paramList}"; - - EmitBlankLineIfRequired(); - EmitStartBlock($$""" - /// Registers a configuration instance which TOptions will bind against. - public static {{FullyQualifiedDisplayString.IServiceCollection}} {{Identifier.Configure}}<{{Identifier.TOptions}}>({{paramList}}) where {{Identifier.TOptions}} : class - """); - } - } - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/MethodsToGen.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/MethodsToGen.cs deleted file mode 100644 index 7b40f198e08f23..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/MethodsToGen.cs +++ /dev/null @@ -1,151 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - [Flags] - public enum MethodsToGen_CoreBindingHelper - { - None = 0x0, - BindCore = 0x1, - BindCoreUntyped = 0x2, - GetCore = 0x4, - GetValueCore = 0x8, - Initialize = 0x10, - AsConfigWithChildren = 0x20, - } - - /// - /// Methods on Microsoft.Extensions.Configuration.ConfigurationBinder - /// - [Flags] - internal enum MethodsToGen_ConfigurationBinder - { - None = 0x0, - - /// - /// Bind(IConfiguration, object). - /// - Bind_instance = 0x1, - - /// - /// Bind(IConfiguration, object, Action). - /// - Bind_instance_BinderOptions = 0x2, - - /// - /// Bind(IConfiguration, string, object). - /// - Bind_key_instance = 0x4, - - /// - /// Get(IConfiguration). - /// - Get_T = 0x8, - - /// - /// Get(IConfiguration, Action). - /// - Get_T_BinderOptions = 0x10, - - /// - /// Get(IConfiguration, Type). - /// - Get_TypeOf = 0x20, - - /// - /// Get(IConfiguration, Type, Action). - /// - Get_TypeOf_BinderOptions = 0x40, - - /// - /// GetValue(IConfiguration, Type, string). - /// - GetValue_TypeOf_key = 0x80, - - /// - /// GetValue(IConfiguration, Type, object). - /// - GetValue_TypeOf_key_defaultValue = 0x100, - - /// - /// GetValue(IConfiguration, string). - /// - GetValue_T_key = 0x200, - - /// - /// GetValue(IConfiguration, string, T). - /// - GetValue_T_key_defaultValue = 0x400, - - // Method groups - Bind = Bind_instance | Bind_instance_BinderOptions | Bind_key_instance, - Get = Get_T | Get_T_BinderOptions | Get_TypeOf | Get_TypeOf_BinderOptions, - GetValue = GetValue_T_key | GetValue_T_key_defaultValue | GetValue_TypeOf_key | GetValue_TypeOf_key_defaultValue, - - Any = Bind | Get | GetValue, - } - - [Flags] - internal enum MethodsToGen_Extensions_OptionsBuilder - { - None = 0x0, - - /// - /// Bind(OptionsBuilder, IConfiguration). - /// - Bind_T = 0x1, - - /// - /// Bind(OptionsBuilder, IConfiguration, Action?). - /// - Bind_T_BinderOptions = 0x2, - - /// - /// BindConfiguration(OptionsBuilder, string, Action?). - /// - BindConfiguration_T_path_BinderOptions = 0x4, - - // Method group. BindConfiguration_T is its own method group. - Bind = Bind_T | Bind_T_BinderOptions, - - BindConfiguration = BindConfiguration_T_path_BinderOptions, - - Any = Bind | BindConfiguration, - } - - /// - /// Methods on Microsoft.Extensions.DependencyInjection.OptionsConfigurationServiceCollectionExtensions - /// - [Flags] - public enum MethodsToGen_Extensions_ServiceCollection - { - None = 0x0, - - /// - /// Configure(IServiceCollection, IConfiguration). - /// - Configure_T = 0x1, - - /// - /// Configure(IServiceCollection, string, IConfiguration). - /// - Configure_T_name = 0x2, - - /// - /// Configure(IServiceCollection, IConfiguration, Action?). - /// - Configure_T_BinderOptions = 0x4, - - /// - /// Configure(IServiceCollection, string, IConfiguration, Action?). - /// - Configure_T_name_BinderOptions = 0x8, - - Configure = Configure_T | Configure_T_name | Configure_T_BinderOptions | Configure_T_name_BinderOptions, - - Any = Configure, - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/BinderInvocation.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/BinderInvocation.cs deleted file mode 100644 index 3029a8de34f9e8..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/BinderInvocation.cs +++ /dev/null @@ -1,99 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Threading; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record BinderInvocation - { - public IInvocationOperation Operation { get; private set; } - public Location? Location { get; private set; } - - public static BinderInvocation? Create(GeneratorSyntaxContext context, CancellationToken cancellationToken) - { - if (!IsCandidateInvocationExpressionSyntax(context.Node, out InvocationExpressionSyntax? invocationSyntax) || - context.SemanticModel.GetOperation(invocationSyntax, cancellationToken) is not IInvocationOperation operation || - !IsCandidateInvocation(operation)) - { - return null; - } - - return new BinderInvocation() - { - Operation = operation, - Location = invocationSyntax.GetLocation() - }; - } - - private static bool IsCandidateInvocationExpressionSyntax(SyntaxNode node, out InvocationExpressionSyntax? invocationSyntax) - { - if (node is InvocationExpressionSyntax - { - Expression: MemberAccessExpressionSyntax - { - Name.Identifier.ValueText: string memberName - } - } syntax && IsCandidateBindingMethodName(memberName)) - { - invocationSyntax = syntax; - return true; - } - - invocationSyntax = null; - return false; - - static bool IsCandidateBindingMethodName(string name) => - IsCandidateMethodName_ConfigurationBinder(name) || - IsCandidateMethodName_OptionsBuilderConfigurationExtensions(name) || - IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(name); - } - - private static bool IsCandidateInvocation(IInvocationOperation operation) - { - if (operation.TargetMethod is not IMethodSymbol - { - IsExtensionMethod: true, - Name: string methodName, - ContainingType: ITypeSymbol - { - Name: string containingTypeName, - ContainingNamespace: INamespaceSymbol { } containingNamespace, - } containingType - } method || - containingNamespace.ToDisplayString() is not string containingNamespaceName) - { - return false; - } - - return (containingTypeName) switch - { - "ConfigurationBinder" => - containingNamespaceName is "Microsoft.Extensions.Configuration" && - IsCandidateMethodName_ConfigurationBinder(methodName), - "OptionsBuilderConfigurationExtensions" => - containingNamespaceName is "Microsoft.Extensions.DependencyInjection" && - IsCandidateMethodName_OptionsBuilderConfigurationExtensions(methodName), - "OptionsConfigurationServiceCollectionExtensions" => - containingNamespaceName is "Microsoft.Extensions.DependencyInjection" && - IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(methodName), - _ => false, - }; - } - - private static bool IsCandidateMethodName_ConfigurationBinder(string name) => name is - nameof(MethodsToGen_ConfigurationBinder.Bind) or - nameof(MethodsToGen_ConfigurationBinder.Get) or - nameof(MethodsToGen_ConfigurationBinder.GetValue); - - private static bool IsCandidateMethodName_OptionsBuilderConfigurationExtensions(string name) => name is - nameof(MethodsToGen_Extensions_OptionsBuilder.Bind) or - nameof(MethodsToGen_Extensions_OptionsBuilder.BindConfiguration); - - private static bool IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(string name) => name is - nameof(MethodsToGen_Extensions_ServiceCollection.Configure); - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/OptionsBuilderConfigurationExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/OptionsBuilderConfigurationExtensions.cs deleted file mode 100644 index d01e80d708ca11..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/OptionsBuilderConfigurationExtensions.cs +++ /dev/null @@ -1,96 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Immutable; -using System.Diagnostics; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Operations; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - public sealed partial class ConfigurationBindingGenerator - { - private sealed partial class Parser - { - private void RegisterMethodInvocation_OptionsBuilderExt(BinderInvocation invocation) - { - IMethodSymbol targetMethod = invocation.Operation.TargetMethod; - ImmutableArray @params = targetMethod.Parameters; - - if (!targetMethod.IsGenericMethod || - @params.Length < 2 || - @params[0].Type is not INamedTypeSymbol { IsGenericType: true } genericType || - !SymbolEqualityComparer.Default.Equals(_typeSymbols.OptionsBuilderOfT_Unbound, genericType.ConstructUnboundGenericType())) - { - return; - } - - TypeSpec typeSpec = GetTargetTypeForRootInvocation( - type: targetMethod.TypeArguments[0].WithNullableAnnotation(NullableAnnotation.None), - invocation.Location); - - if (typeSpec is null) - { - return; - } - - // We are going to emit calls to APIs on IServiceCollection. - _sourceGenSpec.TypeNamespaces.Add("Microsoft.Extensions.DependencyInjection"); - - if (targetMethod.Name is "Bind") - { - RegisterBindInvocation(invocation, typeSpec); - } - else if (targetMethod.Name is "BindConfiguration") - { - ParseBindConfigurationInvocation(invocation, typeSpec); - } - } - - private void RegisterBindInvocation(BinderInvocation invocation, TypeSpec typeSpec) - { - IInvocationOperation operation = invocation.Operation!; - IMethodSymbol targetMethod = operation.TargetMethod; - ImmutableArray @params = targetMethod.Parameters; - int paramCount = @params.Length; - - Debug.Assert(paramCount >= 2); - - if (!SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, @params[1].Type)) - { - return; - } - - if (paramCount is 2) - { - _sourceGenSpec.MethodsToGen_OptionsBuilderExt |= MethodsToGen_Extensions_OptionsBuilder.Bind_T; - } - else if (paramCount is 3 && SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[2].Type)) - { - _sourceGenSpec.MethodsToGen_OptionsBuilderExt |= MethodsToGen_Extensions_OptionsBuilder.Bind_T_BinderOptions; - } - else - { - return; - } - - RegisterTypeForMethodGen(MethodsToGen_Extensions_ServiceCollection.Configure_T_name_BinderOptions, typeSpec); - } - - private void ParseBindConfigurationInvocation(BinderInvocation invocation, TypeSpec typeSpec) - { - IMethodSymbol targetMethod = invocation.Operation.TargetMethod; - ImmutableArray @params = targetMethod.Parameters; - - int paramCount = @params.Length; - Debug.Assert(paramCount >= 2); - - if (paramCount is 3 && @params[1].Type.SpecialType is SpecialType.System_String && SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[2].Type)) - { - _sourceGenSpec.MethodsToGen_OptionsBuilderExt |= MethodsToGen_Extensions_OptionsBuilder.BindConfiguration_T_path_BinderOptions; - RegisterTypeForBindCoreUntypedGen(typeSpec); - } - } - } - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj index 785a18c5c0978e..764682b43daa86 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj @@ -9,46 +9,60 @@ $(DefineConstants);LAUNCH_DEBUGGER + + + $(NetCoreAppToolCurrent);netstandard2.0 + + + + - - - + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/CollectionSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/CollectionSpec.cs deleted file mode 100644 index 280ecc4c536482..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/CollectionSpec.cs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal abstract record CollectionSpec : TypeSpec - { - public CollectionSpec(ITypeSymbol type) : base(type) { } - - public required TypeSpec ElementType { get; init; } - - public CollectionSpec? ConcreteType { get; set; } - - public CollectionSpec? PopulationCastType { get; set; } - - public required CollectionPopulationStrategy PopulationStrategy { get; init; } - - public override bool CanInitialize => ConcreteType?.CanInitialize ?? CanInitComplexObject(); - - public override required InitializationStrategy InitializationStrategy { get; set; } - - public required string? ToEnumerableMethodCall { get; init; } - - public sealed override bool NeedsMemberBinding => true; - } - - internal sealed record EnumerableSpec : CollectionSpec - { - public EnumerableSpec(ITypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.Enumerable; - } - - internal sealed record DictionarySpec : CollectionSpec - { - public DictionarySpec(INamedTypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.Dictionary; - - public required ParsableFromStringSpec KeyType { get; init; } - } - - internal enum CollectionPopulationStrategy - { - Unknown = 0, - Add = 1, - Cast_Then_Add = 2, - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ConfigurationSectionSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ConfigurationSectionSpec.cs deleted file mode 100644 index ed1fcac7636ba7..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ConfigurationSectionSpec.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record ConfigurationSectionSpec : TypeSpec - { - public ConfigurationSectionSpec(ITypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.IConfigurationSection; - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/InitializationStrategy.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/InitializationStrategy.cs deleted file mode 100644 index 866dd254e0181e..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/InitializationStrategy.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal enum InitializationStrategy - { - None = 0, - ParameterlessConstructor = 1, - ParameterizedConstructor = 2, - ToEnumerableMethod = 3, - Array = 4, - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/NullableSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/NullableSpec.cs deleted file mode 100644 index 9dcca27596e7ee..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/NullableSpec.cs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record NullableSpec : TypeSpec - { - private readonly TypeSpec _underlyingType; - - public NullableSpec(ITypeSymbol type, TypeSpec underlyingType) : base(type) - { - _underlyingType = underlyingType; - } - - public override TypeSpecKind SpecKind => TypeSpecKind.Nullable; - - public override TypeSpec EffectiveType => _underlyingType; - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ObjectSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ObjectSpec.cs deleted file mode 100644 index 1696ee099fe469..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ObjectSpec.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record ObjectSpec : TypeSpec - { - public ObjectSpec(INamedTypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.Object; - - public override InitializationStrategy InitializationStrategy { get; set; } - - public override bool CanInitialize => CanInitComplexObject(); - - public Dictionary Properties { get; } = new(StringComparer.OrdinalIgnoreCase); - - public List ConstructorParameters { get; } = new(); - - private string _displayStringWithoutSpecialCharacters; - public string DisplayStringWithoutSpecialCharacters => - _displayStringWithoutSpecialCharacters ??= $"{MinimalDisplayString.Replace(".", string.Empty).Replace("<", string.Empty).Replace(">", string.Empty)}"; - - public override bool NeedsMemberBinding => CanInitialize && - Properties.Values.Count > 0 && - Properties.Values.Any(p => p.ShouldBind()); - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ParsableFromStringSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ParsableFromStringSpec.cs deleted file mode 100644 index 6b5bb5b61ea371..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ParsableFromStringSpec.cs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record ParsableFromStringSpec : TypeSpec - { - public ParsableFromStringSpec(ITypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.ParsableFromString; - - public required StringParsableTypeKind StringParsableTypeKind { get; init; } - - private string? _parseMethodName; - public string ParseMethodName - { - get - { - Debug.Assert(StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue); - - _parseMethodName ??= StringParsableTypeKind is StringParsableTypeKind.ByteArray - ? "ParseByteArray" - // MinimalDisplayString.Length is certainly > 2. - : $"Parse{(char.ToUpper(MinimalDisplayString[0]) + MinimalDisplayString.Substring(1)).Replace(".", "")}"; - - return _parseMethodName; - } - } - } - - internal enum StringParsableTypeKind - { - None = 0, - - /// - /// Declared types that can be assigned directly from IConfigurationSection.Value, i.e. string and tyepof(object). - /// - AssignFromSectionValue = 1, - Enum = 2, - ByteArray = 3, - Integer = 4, - Float = 5, - Parse = 6, - ParseInvariant = 7, - CultureInfo = 8, - Uri = 9, - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/SourceGenerationSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/SourceGenerationSpec.cs deleted file mode 100644 index 88c4b24f57a5ea..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/SourceGenerationSpec.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record SourceGenerationSpec - { - public Dictionary> TypesForGen_CoreBindingHelper_Methods { get; } = new(); - public Dictionary> TypesForGen_ConfigurationBinder_BindMethods { get; } = new(); - - public HashSet PrimitivesForHelperGen { get; } = new(); - public HashSet TypeNamespaces { get; } = new() - { - "System", - "System.CodeDom.Compiler", - "System.Globalization", - "Microsoft.Extensions.Configuration", - }; - - public MethodsToGen_CoreBindingHelper MethodsToGen_CoreBindingHelper { get; set; } - public MethodsToGen_ConfigurationBinder MethodsToGen_ConfigurationBinder { get; set; } - public MethodsToGen_Extensions_OptionsBuilder MethodsToGen_OptionsBuilderExt { get; set; } - public MethodsToGen_Extensions_ServiceCollection MethodsToGen_ServiceCollectionExt { get; set; } - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/TypeSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/TypeSpec.cs deleted file mode 100644 index 6a6292b7ebd0b4..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/TypeSpec.cs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal abstract record TypeSpec - { - private static readonly SymbolDisplayFormat s_minimalDisplayFormat = new SymbolDisplayFormat( - globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted, - typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes, - genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters, - miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes); - - public TypeSpec(ITypeSymbol type) - { - IsValueType = type.IsValueType; - Namespace = type.ContainingNamespace?.ToDisplayString(); - FullyQualifiedDisplayString = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - MinimalDisplayString = type.ToDisplayString(s_minimalDisplayFormat); - Name = Namespace + "." + MinimalDisplayString.Replace(".", "+"); - IsInterface = type.TypeKind is TypeKind.Interface; - } - - public string Name { get; } - - public string FullyQualifiedDisplayString { get; } - - public string MinimalDisplayString { get; } - - public string? Namespace { get; } - - public bool IsValueType { get; } - - public abstract TypeSpecKind SpecKind { get; } - - public virtual InitializationStrategy InitializationStrategy { get; set; } - - public virtual string? InitExceptionMessage { get; set; } - - public virtual bool CanInitialize => true; - - public virtual bool NeedsMemberBinding { get; } - - public virtual TypeSpec EffectiveType => this; - - public bool IsInterface { get; } - - protected bool CanInitComplexObject() => InitializationStrategy is not InitializationStrategy.None && InitExceptionMessage is null; - } - - internal enum TypeSpecKind - { - Unknown = 0, - ParsableFromString = 1, - Object = 2, - Enumerable = 3, - Dictionary = 4, - IConfigurationSection = 5, - Nullable = 6, - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/BinderInvocation.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/BinderInvocation.cs new file mode 100644 index 00000000000000..b1cf51acb3b4a6 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/BinderInvocation.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Threading; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + internal sealed class BinderInvocation + { + private BinderInvocation(IInvocationOperation operation, Location location) + { + Operation = operation; + Location = location; + } + + public IInvocationOperation Operation { get; } + public Location Location { get; } + + public static BinderInvocation? Create(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + Debug.Assert(IsCandidateSyntaxNode(context.Node)); + InvocationExpressionSyntax invocationSyntax = (InvocationExpressionSyntax)context.Node; + + return context.SemanticModel.GetOperation(invocationSyntax, cancellationToken) is IInvocationOperation operation && + IsBindingOperation(operation) + ? new BinderInvocation(operation, invocationSyntax.GetLocation()) + : null; + } + + public static bool IsCandidateSyntaxNode(SyntaxNode node) + { + return node is InvocationExpressionSyntax + { + // TODO: drill further into this evaluation for a declaring-type name check. + // https://github.com/dotnet/runtime/issues/90687. + Expression: MemberAccessExpressionSyntax + { + Name.Identifier.ValueText: string memberName, + } + } && IsCandidateBindingMethodName(memberName); + + static bool IsCandidateBindingMethodName(string name) => + IsValidMethodName_ConfigurationBinder(name) || + IsValidMethodName_OptionsBuilderConfigurationExtensions(name) || + IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(name); + } + + public static bool IsBindingOperation(IInvocationOperation operation) + { + if (operation.TargetMethod is not IMethodSymbol + { + IsExtensionMethod: true, + Name: string methodName, + ContainingType: INamedTypeSymbol + { + Name: string containingTypeName, + ContainingNamespace: INamespaceSymbol containingNamespace, + } + }) + { + return false; + } + + string containingNamespaceName = containingNamespace.ToDisplayString(); + + return (containingTypeName) switch + { + "ConfigurationBinder" => + containingNamespaceName is "Microsoft.Extensions.Configuration" && + IsValidMethodName_ConfigurationBinder(methodName), + "OptionsBuilderConfigurationExtensions" => + containingNamespaceName is "Microsoft.Extensions.DependencyInjection" && + IsValidMethodName_OptionsBuilderConfigurationExtensions(methodName), + "OptionsConfigurationServiceCollectionExtensions" => + containingNamespaceName is "Microsoft.Extensions.DependencyInjection" && + IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(methodName), + _ => false, + }; + } + + private static bool IsValidMethodName_ConfigurationBinder(string name) => name is "Bind" or "Get" or "GetValue"; + + private static bool IsValidMethodName_OptionsBuilderConfigurationExtensions(string name) => name is "Bind" or "BindConfiguration"; + + private static bool IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(string name) => name is "Configure"; + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/ConfigurationBinder.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/ConfigurationBinder.cs similarity index 54% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/ConfigurationBinder.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/ConfigurationBinder.cs index a663c441c55ce3..645786e35c1c55 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/ConfigurationBinder.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/ConfigurationBinder.cs @@ -2,43 +2,41 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using Microsoft.CodeAnalysis.Operations; using Microsoft.CodeAnalysis; +using System.Diagnostics; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator { - private sealed partial class Parser + internal sealed partial class Parser { - private void RegisterMethodInvocation_ConfigurationBinder(BinderInvocation invocation) + private void ParseInvocation_ConfigurationBinder(BinderInvocation invocation) { switch (invocation.Operation.TargetMethod.Name) { - case nameof(MethodsToGen_ConfigurationBinder.Bind): + case "Bind": { - RegisterBindInvocation(invocation); + ParseBindInvocation_ConfigurationBinder(invocation); } break; - case nameof(MethodsToGen_ConfigurationBinder.Get): + case "Get": { - RegisterGetInvocation(invocation); + ParseGetInvocation(invocation); } break; - case nameof(MethodsToGen_ConfigurationBinder.GetValue): + case "GetValue": { - RegisterGetValueInvocation(invocation); + ParseGetValueInvocation(invocation); } break; - default: - return; } } - private void RegisterBindInvocation(BinderInvocation invocation) + private void ParseBindInvocation_ConfigurationBinder(BinderInvocation invocation) { IInvocationOperation operation = invocation.Operation!; ImmutableArray @params = operation.TargetMethod.Parameters; @@ -49,69 +47,58 @@ private void RegisterBindInvocation(BinderInvocation invocation) return; } - MethodsToGen_ConfigurationBinder overload = MethodsToGen_ConfigurationBinder.None; + MethodsToGen overload = MethodsToGen.None; if (paramCount is 2) { - overload = MethodsToGen_ConfigurationBinder.Bind_instance; + overload = MethodsToGen.ConfigBinder_Bind_instance; } else if (paramCount is 3) { if (@params[1].Type.SpecialType is SpecialType.System_String) { - overload = MethodsToGen_ConfigurationBinder.Bind_key_instance; + overload = MethodsToGen.ConfigBinder_Bind_key_instance; } else if (SymbolEqualityComparer.Default.Equals(@params[2].Type, _typeSymbols.ActionOfBinderOptions)) { - overload = MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions; + overload = MethodsToGen.ConfigBinder_Bind_instance_BinderOptions; } } - if (overload is MethodsToGen_ConfigurationBinder.None) + if (overload is MethodsToGen.None) { return; } - int objectIndex = overload switch + int instanceIndex = overload switch { - MethodsToGen_ConfigurationBinder.Bind_instance => 1, - MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions => 1, - MethodsToGen_ConfigurationBinder.Bind_key_instance => 2, + MethodsToGen.ConfigBinder_Bind_instance => 1, + MethodsToGen.ConfigBinder_Bind_instance_BinderOptions => 1, + MethodsToGen.ConfigBinder_Bind_key_instance => 2, _ => throw new InvalidOperationException() }; - IArgumentOperation objectArg = operation.Arguments[objectIndex]; - if (objectArg.Parameter.Type.SpecialType != SpecialType.System_Object) + IArgumentOperation instanceArg = GetArgumentForParameterAtIndex(operation.Arguments, instanceIndex); + if (instanceArg.Parameter?.Type.SpecialType is not SpecialType.System_Object) { return; } - ITypeSymbol? type = ResolveType(objectArg.Value)?.WithNullableAnnotation(NullableAnnotation.None); + ITypeSymbol? type = ResolveType(instanceArg.Value)?.WithNullableAnnotation(NullableAnnotation.None); if (!IsValidRootConfigType(type)) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.CouldNotDetermineTypeInfo, invocation.Location)); + RecordDiagnostic(DiagnosticDescriptors.CouldNotDetermineTypeInfo, invocation.Location); return; } - if (type!.IsValueType) + if (type.IsValueType) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.ValueTypesInvalidForBind, invocation.Location, type)); + RecordDiagnostic(DiagnosticDescriptors.ValueTypesInvalidForBind, invocation.Location, messageArgs: new object[] { type }); return; } - if (GetTargetTypeForRootInvocationCore(type, invocation.Location) is TypeSpec typeSpec) - { - Dictionary> types = _sourceGenSpec.TypesForGen_ConfigurationBinder_BindMethods; - - if (!types.TryGetValue(overload, out HashSet? typeSpecs)) - { - types[overload] = typeSpecs = new HashSet(); - } - - _sourceGenSpec.MethodsToGen_ConfigurationBinder |= overload; - typeSpecs.Add(typeSpec); - } + EnqueueTargetTypeForRootInvocation(type, overload, invocation); static ITypeSymbol? ResolveType(IOperation conversionOperation) => conversionOperation switch @@ -130,7 +117,20 @@ private void RegisterBindInvocation(BinderInvocation invocation) }; } - private void RegisterGetInvocation(BinderInvocation invocation) + private static IArgumentOperation GetArgumentForParameterAtIndex(ImmutableArray arguments, int parameterIndex) + { + foreach (var argument in arguments) + { + if (argument.Parameter?.Ordinal == parameterIndex) + { + return argument; + } + } + + throw new InvalidOperationException(); + } + + private void ParseGetInvocation(BinderInvocation invocation) { IInvocationOperation operation = invocation.Operation!; IMethodSymbol targetMethod = operation.TargetMethod; @@ -142,7 +142,7 @@ private void RegisterGetInvocation(BinderInvocation invocation) return; } - MethodsToGen_ConfigurationBinder overload = MethodsToGen_ConfigurationBinder.None; + MethodsToGen overload = MethodsToGen.None; ITypeSymbol? type; if (targetMethod.IsGenericMethod) @@ -156,11 +156,11 @@ private void RegisterGetInvocation(BinderInvocation invocation) if (paramCount is 1) { - overload = MethodsToGen_ConfigurationBinder.Get_T; + overload = MethodsToGen.ConfigBinder_Get_T; } else if (paramCount is 2 && SymbolEqualityComparer.Default.Equals(@params[1].Type, _typeSymbols.ActionOfBinderOptions)) { - overload = MethodsToGen_ConfigurationBinder.Get_T_BinderOptions; + overload = MethodsToGen.ConfigBinder_Get_T_BinderOptions; } } else if (paramCount > 3) @@ -169,35 +169,30 @@ private void RegisterGetInvocation(BinderInvocation invocation) } else { - ITypeOfOperation? typeOfOperation = operation.Arguments[1].ChildOperations.FirstOrDefault() as ITypeOfOperation; + ITypeOfOperation? typeOfOperation = GetArgumentForParameterAtIndex(operation.Arguments, 1).ChildOperations.FirstOrDefault() as ITypeOfOperation; type = typeOfOperation?.TypeOperand; if (paramCount is 2) { - overload = MethodsToGen_ConfigurationBinder.Get_TypeOf; + overload = MethodsToGen.ConfigBinder_Get_TypeOf; } else if (paramCount is 3 && SymbolEqualityComparer.Default.Equals(@params[2].Type, _typeSymbols.ActionOfBinderOptions)) { - overload = MethodsToGen_ConfigurationBinder.Get_TypeOf_BinderOptions; + overload = MethodsToGen.ConfigBinder_Get_TypeOf_BinderOptions; } } - if (GetTargetTypeForRootInvocation(type, invocation.Location) is TypeSpec typeSpec) - { - _sourceGenSpec.MethodsToGen_ConfigurationBinder |= overload; - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetCore, typeSpec); - } - + EnqueueTargetTypeForRootInvocation(type, overload, invocation); } - private void RegisterGetValueInvocation(BinderInvocation invocation) + private void ParseGetValueInvocation(BinderInvocation invocation) { IInvocationOperation operation = invocation.Operation!; IMethodSymbol targetMethod = operation.TargetMethod; ImmutableArray @params = targetMethod.Parameters; int paramCount = @params.Length; - MethodsToGen_ConfigurationBinder overload = MethodsToGen_ConfigurationBinder.None; + MethodsToGen overload = MethodsToGen.None; ITypeSymbol? type; if (targetMethod.IsGenericMethod) @@ -211,11 +206,11 @@ private void RegisterGetValueInvocation(BinderInvocation invocation) if (paramCount is 2) { - overload = MethodsToGen_ConfigurationBinder.GetValue_T_key; + overload = MethodsToGen.ConfigBinder_GetValue_T_key; } else if (paramCount is 3 && SymbolEqualityComparer.Default.Equals(@params[2].Type, type)) { - overload = MethodsToGen_ConfigurationBinder.GetValue_T_key_defaultValue; + overload = MethodsToGen.ConfigBinder_GetValue_T_key_defaultValue; } } else if (paramCount > 4) @@ -229,32 +224,60 @@ private void RegisterGetValueInvocation(BinderInvocation invocation) return; } - ITypeOfOperation? typeOfOperation = operation.Arguments[1].ChildOperations.FirstOrDefault() as ITypeOfOperation; + ITypeOfOperation? typeOfOperation = GetArgumentForParameterAtIndex(operation.Arguments, 1).ChildOperations.FirstOrDefault() as ITypeOfOperation; type = typeOfOperation?.TypeOperand; if (paramCount is 3) { - overload = MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key; + overload = MethodsToGen.ConfigBinder_GetValue_TypeOf_key; } else if (paramCount is 4 && @params[3].Type.SpecialType is SpecialType.System_Object) { - overload = MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key_defaultValue; + overload = MethodsToGen.ConfigBinder_GetValue_TypeOf_key_defaultValue; } } - ITypeSymbol effectiveType = (IsNullable(type, out ITypeSymbol? underlyingType) ? underlyingType : type)!; - if (!IsValidRootConfigType(type)) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.CouldNotDetermineTypeInfo, invocation.Location)); + RecordDiagnostic(DiagnosticDescriptors.CouldNotDetermineTypeInfo, invocation.Location); return; } - if (IsParsableFromString(effectiveType, out _) && - GetTargetTypeForRootInvocationCore(type, invocation.Location) is TypeSpec typeSpec) + ITypeSymbol effectiveType = IsNullable(type, out ITypeSymbol? underlyingType) ? underlyingType : type; + + if (IsParsableFromString(effectiveType, out _)) { - _sourceGenSpec.MethodsToGen_ConfigurationBinder |= overload; - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetValueCore, typeSpec); + EnqueueTargetTypeForRootInvocation(type, overload, invocation); + } + } + + private void RegisterInterceptor_ConfigurationBinder(TypeParseInfo typeParseInfo, TypeSpec typeSpec) + { + MethodsToGen overload = typeParseInfo.BindingOverload; + IInvocationOperation invocationOperation = typeParseInfo.BinderInvocation!.Operation; + Debug.Assert((MethodsToGen.ConfigBinder_Any & overload) is not 0); + + if ((MethodsToGen.ConfigBinder_Bind & overload) is not 0) + { + if (typeSpec is ComplexTypeSpec complexTypeSpec && + _helperInfoBuilder!.TryRegisterTransitiveTypesForMethodGen(complexTypeSpec.TypeRef)) + { + _interceptorInfoBuilder.RegisterInterceptor_ConfigBinder_Bind(overload, complexTypeSpec, invocationOperation); + } + } + else + { + Debug.Assert((MethodsToGen.ConfigBinder_Get & overload) is not 0 || + (MethodsToGen.ConfigBinder_GetValue & overload) is not 0); + + bool registered = (MethodsToGen.ConfigBinder_Get & overload) is not 0 + ? _helperInfoBuilder!.TryRegisterTypeForGetGen(typeSpec) + : _helperInfoBuilder!.TryRegisterTypeForGetValueGen(typeSpec); + + if (registered) + { + _interceptorInfoBuilder.RegisterInterceptor(overload, invocationOperation); + } } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/Diagnostics.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/DiagnosticDescriptors.cs similarity index 82% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/Diagnostics.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/DiagnosticDescriptors.cs index d6d816545bcd0a..3f694c78be8309 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/Diagnostics.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/DiagnosticDescriptors.cs @@ -9,9 +9,9 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator { - private sealed partial class Parser + internal sealed partial class Parser { - internal static class Diagnostics + private static class DiagnosticDescriptors { public static DiagnosticDescriptor TypeNotSupported { get; } = CreateTypeNotSupportedDescriptor(nameof(SR.TypeNotSupported)); public static DiagnosticDescriptor MissingPublicInstanceConstructor { get; } = CreateTypeNotSupportedDescriptor(nameof(SR.MissingPublicInstanceConstructor)); @@ -62,6 +62,20 @@ private static DiagnosticDescriptor CreateTypeNotSupportedDescriptor(string name category: ProjectName, defaultSeverity: DiagnosticSeverity.Warning, isEnabledByDefault: true); + + public static DiagnosticDescriptor GetNotSupportedDescriptor(NotSupportedReason reason) => + reason switch + { + NotSupportedReason.UnknownType => TypeNotSupported, + NotSupportedReason.MissingPublicInstanceConstructor => MissingPublicInstanceConstructor, + NotSupportedReason.CollectionNotSupported => CollectionNotSupported, + NotSupportedReason.DictionaryKeyNotSupported => DictionaryKeyNotSupported, + NotSupportedReason.ElementTypeNotSupported => ElementTypeNotSupported, + NotSupportedReason.MultipleParameterizedConstructors => MultipleParameterizedConstructors, + NotSupportedReason.MultiDimArraysNotSupported => MultiDimArraysNotSupported, + NotSupportedReason.NullableUnderlyingTypeNotSupported => NullableUnderlyingTypeNotSupported, + _ => throw new InvalidOperationException() + }; } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Extensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Extensions.cs new file mode 100644 index 00000000000000..f685842639966a --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Extensions.cs @@ -0,0 +1,126 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed partial class ConfigurationBindingGenerator + { + internal sealed partial class Parser + { + private readonly struct TypeParseInfo + { + public ITypeSymbol TypeSymbol { get; private init; } + public string TypeName { get; private init; } + public MethodsToGen BindingOverload { get; private init; } + public BinderInvocation BinderInvocation { get; private init; } + public ContainingTypeDiagnosticInfo? ContainingTypeDiagnosticInfo { get; private init; } + + public static TypeParseInfo Create(ITypeSymbol typeSymbol, MethodsToGen overload, BinderInvocation invocation, ContainingTypeDiagnosticInfo? containingTypeDiagInfo = null) => + new TypeParseInfo + { + TypeSymbol = typeSymbol, + TypeName = typeSymbol.GetName(), + BindingOverload = overload, + BinderInvocation = invocation, + ContainingTypeDiagnosticInfo = containingTypeDiagInfo, + }; + + public TypeParseInfo ToTransitiveTypeParseInfo(ITypeSymbol memberType, DiagnosticDescriptor? diagDescriptor = null, string? memberName = null) + { + ContainingTypeDiagnosticInfo? diagnosticInfo = diagDescriptor is null + ? null + : new() + { + TypeName = TypeName, + Descriptor = diagDescriptor, + MemberName = memberName, + ContainingTypeInfo = ContainingTypeDiagnosticInfo, + }; + + return Create(memberType, BindingOverload, BinderInvocation, diagnosticInfo); + } + } + + private sealed class ContainingTypeDiagnosticInfo + { + public required string TypeName { get; init; } + public required string? MemberName { get; init; } + public required DiagnosticDescriptor Descriptor { get; init; } + public required ContainingTypeDiagnosticInfo? ContainingTypeInfo { get; init; } + } + } + } + + internal static class ParserExtensions + { + private static readonly SymbolDisplayFormat s_identifierCompatibleFormat = new SymbolDisplayFormat( + globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted, + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes, + genericsOptions: SymbolDisplayGenericsOptions.None, + miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes); + + private static readonly SymbolDisplayFormat s_minimalDisplayFormat = new SymbolDisplayFormat( + globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted, + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes, + genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters, + miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes); + + public static void RegisterCacheEntry(this Dictionary cache, TKey key, TEntry entry) + where TKey : notnull + where TValue : ICollection, new() + { + if (!cache.TryGetValue(key, out TValue? entryCollection)) + { + cache[key] = entryCollection = new TValue(); + } + + entryCollection.Add(entry); + } + + public static string ToIdentifierCompatibleSubstring(this ITypeSymbol type) + { + if (type is IArrayTypeSymbol arrayType) + { + int rank = arrayType.Rank; + string suffix = rank == 1 ? "Array" : $"Array{rank}D"; // Array, Array2D, Array3D, ... + return ToIdentifierCompatibleSubstring(arrayType.ElementType) + suffix; + } + + string displayString = type.ContainingType is null + ? type.Name + : type.ToDisplayString(s_identifierCompatibleFormat).Replace(".", string.Empty); + + if (type is not INamedTypeSymbol { IsGenericType: true } namedType) + { + return displayString; + } + + StringBuilder sb = new(displayString); + + if (namedType.GetAllTypeArgumentsInScope() is List typeArgsInScope) + { + foreach (ITypeSymbol genericArg in typeArgsInScope) + { + sb.Append(ToIdentifierCompatibleSubstring(genericArg)); + } + } + + return sb.ToString(); + } + + public static (string? Namespace, string DisplayString, string Name) GetTypeName(this ITypeSymbol type) + { + string? @namespace = type.ContainingNamespace is { IsGlobalNamespace: false } containingNamespace ? containingNamespace.ToDisplayString() : null; + string displayString = type.ToDisplayString(s_minimalDisplayFormat); + string name = (@namespace is null ? string.Empty : @namespace + ".") + displayString.Replace(".", "+"); + return (@namespace, displayString, name); + } + + public static string GetName(this ITypeSymbol type) => GetTypeName(type).Name; + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/KnownTypeSymbols.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/KnownTypeSymbols.cs similarity index 94% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/KnownTypeSymbols.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/KnownTypeSymbols.cs index e3a4f67ed396b4..07dae8689782e4 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/KnownTypeSymbols.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/KnownTypeSymbols.cs @@ -11,8 +11,10 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record KnownTypeSymbols + internal sealed class KnownTypeSymbols { + public CSharpCompilation Compilation { get; } + public INamedTypeSymbol String { get; } public INamedTypeSymbol? CultureInfo { get; } public INamedTypeSymbol? DateOnly { get; } @@ -35,7 +37,7 @@ internal sealed record KnownTypeSymbols public INamedTypeSymbol? OptionsConfigurationServiceCollectionExtensions { get; } public INamedTypeSymbol GenericIList_Unbound { get; } - public INamedTypeSymbol GenericICollection_Unbound { get; } + public INamedTypeSymbol? GenericICollection_Unbound { get; } public INamedTypeSymbol GenericICollection { get; } public INamedTypeSymbol GenericIEnumerable_Unbound { get; } public INamedTypeSymbol IEnumerable { get; } @@ -57,7 +59,10 @@ internal sealed record KnownTypeSymbols public KnownTypeSymbols(CSharpCompilation compilation) { - // Primitives (needed because they are Microsoft.CodeAnalysis.SpecialType.None) + Compilation = compilation; + + // Primitives + String = compilation.GetSpecialType(SpecialType.System_String); CultureInfo = compilation.GetBestTypeByMetadataName(typeof(CultureInfo)); DateOnly = compilation.GetBestTypeByMetadataName("System.DateOnly"); DateTimeOffset = compilation.GetBestTypeByMetadataName(typeof(DateTimeOffset)); @@ -99,7 +104,7 @@ public KnownTypeSymbols(CSharpCompilation compilation) // Used for type equivalency checks for unbound generics. The parameters of the types // retured by the Roslyn Get*Type* APIs are not unbound, so we construct unbound // generics to equal those corresponding to generic types in the input type graphs. - GenericICollection_Unbound = GenericICollection?.ConstructUnboundGenericType(); + GenericICollection_Unbound = GenericICollection.ConstructUnboundGenericType(); GenericIDictionary_Unbound = GenericIDictionary?.ConstructUnboundGenericType(); GenericIEnumerable_Unbound = compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T).ConstructUnboundGenericType(); GenericIList_Unbound = compilation.GetSpecialType(SpecialType.System_Collections_Generic_IList_T).ConstructUnboundGenericType(); diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsBuilderConfigurationExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsBuilderConfigurationExtensions.cs new file mode 100644 index 00000000000000..eb0ab086bcd588 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsBuilderConfigurationExtensions.cs @@ -0,0 +1,118 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed partial class ConfigurationBindingGenerator + { + internal sealed partial class Parser + { + private void ParseInvocation_OptionsBuilderExt(BinderInvocation invocation) + { + IMethodSymbol targetMethod = invocation.Operation.TargetMethod; + ImmutableArray @params = targetMethod.Parameters; + + if (!targetMethod.IsGenericMethod || + @params.Length < 2 || + @params[0].Type is not INamedTypeSymbol { IsGenericType: true } genericType || + !SymbolEqualityComparer.Default.Equals(_typeSymbols.OptionsBuilderOfT_Unbound, genericType.ConstructUnboundGenericType())) + { + return; + } + + ITypeSymbol? typeSymbol = targetMethod.TypeArguments[0].WithNullableAnnotation(NullableAnnotation.None); + // This would violate generic type constraint; any such invocation could not have been included in the initial parser. + Debug.Assert(typeSymbol?.IsValueType is not true); + + if (targetMethod.Name is "Bind") + { + ParseBindInvocation_OptionsBuilderExt(invocation, typeSymbol); + } + else if (targetMethod.Name is "BindConfiguration") + { + ParseBindConfigurationInvocation(invocation, typeSymbol); + } + } + + private void ParseBindInvocation_OptionsBuilderExt(BinderInvocation invocation, ITypeSymbol? type) + { + IInvocationOperation operation = invocation.Operation!; + IMethodSymbol targetMethod = operation.TargetMethod; + ImmutableArray @params = targetMethod.Parameters; + int paramCount = @params.Length; + + Debug.Assert(paramCount >= 2); + + if (!SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, @params[1].Type)) + { + return; + } + + MethodsToGen overload = paramCount switch + { + 2 => MethodsToGen.OptionsBuilderExt_Bind_T, + 3 when SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[2].Type) => + MethodsToGen.OptionsBuilderExt_Bind_T_BinderOptions, + _ => MethodsToGen.None + }; + + if (overload is not MethodsToGen.None) + { + EnqueueTargetTypeForRootInvocation(type, overload, invocation); + } + } + + private void ParseBindConfigurationInvocation(BinderInvocation invocation, ITypeSymbol? type) + { + IMethodSymbol targetMethod = invocation.Operation.TargetMethod; + ImmutableArray @params = targetMethod.Parameters; + + int paramCount = @params.Length; + Debug.Assert(paramCount >= 2); + + if (paramCount is 3 && + @params[1].Type.SpecialType is SpecialType.System_String && + SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[2].Type)) + { + EnqueueTargetTypeForRootInvocation(type, MethodsToGen.OptionsBuilderExt_BindConfiguration_T_path_BinderOptions, invocation); + } + } + + private void RegisterInterceptor_OptionsBuilderExt(TypeParseInfo typeParseInfo, TypeSpec typeSpec) + { + MethodsToGen overload = typeParseInfo.BindingOverload; + Debug.Assert((MethodsToGen.OptionsBuilderExt_Any & overload) is not 0); + + if (typeSpec is not ComplexTypeSpec complexTypeSpec) + { + return; + } + + if ((MethodsToGen.OptionsBuilderExt_Bind & overload) is not 0) + { + if (!TryRegisterTypeForOverloadGen_ServiceCollectionExt(MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions, complexTypeSpec)) + { + return; + } + } + else if (!_helperInfoBuilder!.TryRegisterTypeForBindCoreMainGen(complexTypeSpec)) + { + return; + } + + _interceptorInfoBuilder.RegisterInterceptor(typeParseInfo.BindingOverload, typeParseInfo.BinderInvocation.Operation); + + // Emitting refs to IOptionsChangeTokenSource, ConfigurationChangeTokenSource. + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.Options"); + + // Emitting refs to OptionsBuilder. + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.DependencyInjection"); + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/OptionsConfigurationServiceCollectionExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsConfigurationServiceCollectionExtensions.cs similarity index 54% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/OptionsConfigurationServiceCollectionExtensions.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsConfigurationServiceCollectionExtensions.cs index 02c75b4ab653b3..1ccef24bc6b71f 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Helpers/Parser/OptionsConfigurationServiceCollectionExtensions.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsConfigurationServiceCollectionExtensions.cs @@ -10,9 +10,9 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator { - private sealed partial class Parser + internal sealed partial class Parser { - private void RegisterMethodInvocation_ServiceCollectionExt(BinderInvocation invocation) + private void ParseInvocation_ServiceCollectionExt(BinderInvocation invocation) { IInvocationOperation operation = invocation.Operation!; IMethodSymbol targetMethod = operation.TargetMethod; @@ -30,11 +30,11 @@ private void RegisterMethodInvocation_ServiceCollectionExt(BinderInvocation invo return; } - MethodsToGen_Extensions_ServiceCollection overload; + MethodsToGen overload; if (paramCount is 2 && SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, @params[1].Type)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T; + overload = MethodsToGen.ServiceCollectionExt_Configure_T; } else if (paramCount is 3) { @@ -44,12 +44,12 @@ private void RegisterMethodInvocation_ServiceCollectionExt(BinderInvocation invo if (secondParamType.SpecialType is SpecialType.System_String && SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, thirdParamType)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T_name; + overload = MethodsToGen.ServiceCollectionExt_Configure_T_name; } else if (SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, secondParamType) && SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, thirdParamType)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T_BinderOptions; + overload = MethodsToGen.ServiceCollectionExt_Configure_T_BinderOptions; } else { @@ -61,7 +61,7 @@ @params[1].Type.SpecialType is SpecialType.System_String && SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, @params[2].Type) && SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[3].Type)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T_name_BinderOptions; + overload = MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions; } else { @@ -69,22 +69,38 @@ @params[1].Type.SpecialType is SpecialType.System_String && return; } - TypeSpec typeSpec = GetTargetTypeForRootInvocation( - type: targetMethod.TypeArguments[0].WithNullableAnnotation(NullableAnnotation.None), - invocation.Location); + ITypeSymbol? typeSymbol = targetMethod.TypeArguments[0].WithNullableAnnotation(NullableAnnotation.None); + // This would violate generic type constraint; any such invocation could not have been included in the initial parser. + Debug.Assert(typeSymbol?.IsValueType is not true); - if (typeSpec is null) + EnqueueTargetTypeForRootInvocation(typeSymbol, overload, invocation); + } + + private void RegisterInterceptor_ServiceCollectionExt(TypeParseInfo typeParseInfo, TypeSpec typeSpec) + { + MethodsToGen overload = typeParseInfo.BindingOverload; + + if (typeSpec is ComplexTypeSpec complexTypeSpec && + TryRegisterTypeForOverloadGen_ServiceCollectionExt(overload, complexTypeSpec)) { - return; + _interceptorInfoBuilder.RegisterInterceptor(overload, typeParseInfo.BinderInvocation.Operation); } - - RegisterTypeForMethodGen(overload, typeSpec); } - private void RegisterTypeForMethodGen(MethodsToGen_Extensions_ServiceCollection overload, TypeSpec typeSpec) + private bool TryRegisterTypeForOverloadGen_ServiceCollectionExt(MethodsToGen overload, ComplexTypeSpec typeSpec) { - _sourceGenSpec.MethodsToGen_ServiceCollectionExt |= overload; - RegisterTypeForBindCoreUntypedGen(typeSpec); + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & overload) is not 0); + + if (!_helperInfoBuilder!.TryRegisterTypeForBindCoreMainGen(typeSpec)) + { + return false; + } + + _interceptorInfoBuilder.MethodsToGen |= overload; + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.DependencyInjection"); + // Emitting refs to IOptionsChangeTokenSource, ConfigurationChangeTokenSource, IConfigureOptions<>, ConfigureNamedOptions<>. + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.Options"); + return true; } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/Strings.resx index 301913987d7c7e..be66da59c6b5a7 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/Strings.resx @@ -121,7 +121,7 @@ The collection type is not supported: '{0}'. - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. The target type for a binder call could not be determined @@ -133,10 +133,10 @@ The collection element type is not supported: '{0}'. - The project's language version has to be at least 'C# 11'. + The project's language version has to be at least 'C# 12'. - Language version is required to be at least C# 11 + Language version is required to be at least C# 12 Cannot create instance of type '{0}' because it is missing a public instance constructor. diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.cs.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.cs.xlf index e248c54626865f..8af274c8cb6562 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.cs.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - Pro volání vazače se nevygenerovala logika vazby. Nepodporované vstupní vzory zahrnují obecná volání a předávání zabalených objektů. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + Pro volání vazby se nevygenerovala logika vazby. Mezi nepodporované vstupní vzory patří obecná volání, předávání zabalených objektů a předávání typů, které nejsou public ani internal. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - Jazyková verze projektu musí být alespoň C# 11 + The project's language version has to be at least 'C# 12'. + Jazyková verze projektu musí být alespoň C# 12 - Language version is required to be at least C# 11 - Verze jazyka musí být alespoň C# 11 + Language version is required to be at least C# 12 + Verze jazyka musí být alespoň C# 12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.de.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.de.xlf index 1fa847592bd02e..5f98b4d4a396f0 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.de.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.de.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - Für einen Binderaufruf wurde keine Bindungslogik generiert. Nicht unterstützte Eingabemuster umfassen generische Aufrufe und übergeben geschachtelte Objekte. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + Für einen Binderaufruf wurde keine Bindungslogik generiert. Nicht unterstützte Eingabemuster umfassen generische Aufrufe, vorübergehende geschachtelte Objekte und das Übergeben von Typen, die nicht "public" oder "internal" sind. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - Die Sprachversion des Projekts muss mindestens „C# 11“ sein + The project's language version has to be at least 'C# 12'. + Die Sprachversion des Projekts muss mindestens "C# 12" sein. - Language version is required to be at least C# 11 - Die Sprachversion muss mindestens C# 11 sein + Language version is required to be at least C# 12 + Die Sprachversion muss mindestens C# 12 sein. diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.es.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.es.xlf index c52b2317ceaded..446d2f0f7edb8f 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.es.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.es.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - No se generó la lógica de enlace para una llamada de enlazador. Los patrones de entrada no admitidos incluyen llamadas genéricas y pasar objetos en cuadros. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + No se ha generado la lógica de enlace para una llamada de enlace. Entre los patrones de entrada no admitidos se incluyen las llamadas genéricas, el paso de objetos en caja y el paso de tipos que no son "públicos" o "internos". @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - La versión del lenguaje del proyecto debe ser al menos "C# 11". + The project's language version has to be at least 'C# 12'. + La versión del lenguaje del proyecto debe ser al menos "C# 12". - Language version is required to be at least C# 11 - La versión del lenguaje debe ser al menos C# 11 + Language version is required to be at least C# 12 + La versión del lenguaje debe ser al menos C# 12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.fr.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.fr.xlf index 19362d7336208f..517d8ffad0e5cc 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.fr.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - La logique de liaison n’a pas été générée pour un appel de classeur. Les modèles d’entrée non pris en charge incluent les appels génériques et les objets boxed de passage. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + La logique de liaison n’a pas été générée pour un appel de classeur. Les modèles d’entrée non pris en charge incluent les appels génériques, le passage d’objets box et les types de passage qui ne sont pas 'public' ou 'internal'. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - La version de langage du projet doit être au moins « C# 11 ». + The project's language version has to be at least 'C# 12'. + La version de langage du projet doit être au moins 'C# 12'. - Language version is required to be at least C# 11 - La version du langage doit être au moins C# 11 + Language version is required to be at least C# 12 + La version du langage doit être au moins C# 12. diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.it.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.it.xlf index f418a83d0d422e..787d46e9e7339f 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.it.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.it.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - La logica di binding non è stata generata per una chiamata binder. I modelli di input non supportati includono chiamate generiche e il passaggio di oggetti in caselle. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + La logica di associazione non è stata generata per una chiamata del gestore di associazione. I modelli di input non supportati includono chiamate generici, il passaggio di oggetti boxed e il passaggio di tipi non 'public' o 'internal'. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - La versione del linguaggio del progetto deve essere almeno 'C# 11'. + The project's language version has to be at least 'C# 12'. + La versione del linguaggio del progetto deve essere almeno 'C# 12'. - Language version is required to be at least C# 11 - La versione del linguaggio deve essere almeno C# 11 + Language version is required to be at least C# 12 + La versione del linguaggio deve essere almeno C# 12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ja.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ja.xlf index ba59cfba40a89b..57308cabd0da29 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ja.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - バインダー呼び出しのバインド ロジックが生成されませんでした。サポートされていない入力パターンとしては、ジェネリック呼び出し、ボックス化されたオブジェクトの受け渡しなどがあります。 + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + バインダー呼び出しのバインド ロジックが生成されませんでした。サポートされていない入力パターンには、ジェネリック呼び出し、ボックス化されたオブジェクトの受け渡し、および 'public' または 'internal' ではない型の受け渡しが含まれます。 @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - プロジェクトの言語バージョンは少なくとも 'C# 11' である必要があります。 + The project's language version has to be at least 'C# 12'. + プロジェクトの言語バージョンは少なくとも 'C# 12' である必要があります。 - Language version is required to be at least C# 11 - 言語バージョンは少なくとも C# 11 である必要があります + Language version is required to be at least C# 12 + 言語バージョンは少なくとも C# 12 である必要があります diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ko.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ko.xlf index 10b9b107c4aade..e97e99e1ac7cd0 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ko.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - 바인더 호출에 대한 바인딩 논리가 생성되지 않았습니다. 지원되지 않는 입력 패턴에는 제네릭 호출 및 boxed 개체 전달이 포함됩니다. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + 바인더 호출에 대한 바인딩 논리가 생성되지 않았습니다. 지원되지 않는 입력 패턴에는 제네릭 호출, boxed 개체 전달 및 'public' 또는 'internal'이 아닌 전달 형식이 포함됩니다. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - 프로젝트의 언어 버전은 'C# 11' 이상이어야 합니다. + The project's language version has to be at least 'C# 12'. + 프로젝트의 언어 버전은 'C# 12' 이상이어야 합니다. - Language version is required to be at least C# 11 - 언어 버전은 C# 11 이상이어야 합니다. + Language version is required to be at least C# 12 + 언어 버전은 C# 12 이상이어야 합니다. diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pl.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pl.xlf index 2b558c588ebfb9..d0a4985215c6aa 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pl.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - Nie wygenerowano logiki powiązania dla wywołania integratora. Nieobsługiwane wzorce wejściowe obejmują wywołania ogólne i przekazywanie obiektów w ramce. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + Logika powiązania nie została wygenerowana dla wywołania integratora. Nieobsługiwane wzorce wejściowe obejmują ogólne wywołania, przekazywanie obiektów w pudełkach i przekazywanie typów, które nie są „publiczne” lub „wewnętrzne”. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - Wersja językowa projektu musi mieć wartość co najmniej „C# 11”. + The project's language version has to be at least 'C# 12'. + Wersja językowa projektu musi mieć wartość co najmniej „C# 12”. - Language version is required to be at least C# 11 - Wymagana jest wersja językowa co najmniej C# 11 + Language version is required to be at least C# 12 + Wymagana jest wersja językowa o wartości co najmniej C# 12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pt-BR.xlf index 9d2a51c6aa9c96..965a1181aacd34 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.pt-BR.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - A lógica de associação não foi gerada para uma chamada de associador. Os padrões de entrada sem suporte incluem chamadas genéricas e passagem de objetos em caixa. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + A lógica de associação não foi gerada para uma chamada de fichário. Padrões de entrada sem suporte incluem chamadas genéricas, passagem de objetos em caixa e passagem de tipos que não são 'públicos' ou 'internos'. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - A versão do idioma do projeto deve ser no mínimo 'C# 11'. + The project's language version has to be at least 'C# 12'. + A versão da linguagem do projeto deve ser no mínimo 'C# 12'. - Language version is required to be at least C# 11 - A versão do idioma deve ser pelo menos C# 11 + Language version is required to be at least C# 12 + A versão do idioma deve ser pelo menos C# 12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ru.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ru.xlf index 1ed03c55891a9e..c230a1edd07d1d 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.ru.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - Логика привязки не была создана для вызова модуля привязки. К неподдерживаемым шаблонам ввода относятся универсальные вызовы и передача упакованных объектов. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + Логика привязки не создана для вызова привязки. Неподдерживаемые шаблоны ввода включают общие вызовы, передачу упакованных объектов и передачу типов, не являющихся "общедоступными" или "внутренними". @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - Версия языка проекта должна быть не ниже "C# 11". + The project's language version has to be at least 'C# 12'. + Версия языка проекта должна быть не ниже "C# 12". - Language version is required to be at least C# 11 - Версия языка должна быть не ниже C# 11 + Language version is required to be at least C# 12 + Версия языка должна быть не ниже C# 12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.tr.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.tr.xlf index 8a6dbf76bab7f7..1ffbaa22a96a85 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.tr.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - Bir bağlayıcı çağrısı için bağlama mantığı oluşturulmadı. Desteklenmeyen giriş desenleri genel çağrılar ve geçici kutulu nesneler içeriyor. + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + Bağlama mantığı bir bağlayıcı çağrısı için oluşturulmadı. Desteklenmeyen giriş düzenleri şunları içerir: genel çağrılar, geçirilen kutulu nesneler ve ‘genel’ veya ‘iç’ olmayan geçirme türleri. @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - Projenin dil sürümü en az 'C# 11' olmalıdır. + The project's language version has to be at least 'C# 12'. + Projenin dil sürümü en az 'C# 12' olmalıdır. - Language version is required to be at least C# 11 - Dil sürümünün en az C# 11 olması gerekir + Language version is required to be at least C# 12 + Dil sürümünün en az C# 12 olması gerekir diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hans.xlf index 9d0c0eb3a5d6dd..dd89336534060a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hans.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - 未为联编程序调用生成绑定逻辑。不支持的输入模式包括泛型调用和传递装箱对象。 + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + 没有为绑定器调用生成绑定逻辑。不支持的输入模式包括泛型调用、传递装箱对象和传递不是“public”或“internal”的类型。 @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - 项目的语言版本必须至少为 "C# 11"。 + The project's language version has to be at least 'C# 12'. + 项目的语言版本必须至少为 "C# 12"。 - Language version is required to be at least C# 11 - 语言版本必须至少为 C# 11 + Language version is required to be at least C# 12 + 语言版本必须至少为 C# 12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hant.xlf index dc6ded618c8e94..a30c193b7f0115 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Resources/xlf/Strings.zh-Hant.xlf @@ -8,8 +8,8 @@ - Binding logic was not generated for a binder call. Unsupported input patterns include generic calls and passing boxed objects. - 未產生文件夾呼叫的繫結邏輯。不支援的輸入模式包括一般呼叫和傳遞方塊物件。 + Binding logic was not generated for a binder call. Unsupported input patterns include generic calls, passing boxed objects, and passing types that are not 'public' or 'internal'. + 未產生文件夾呼叫的繫結邏輯。不支援的輸入模式包括一般呼叫、傳遞方塊物件,以及非 'public' 或 'internal' 的傳遞類型。。 @@ -28,13 +28,13 @@ - The project's language version has to be at least 'C# 11'. - 專案的語言版本必須至少為 'C# 11'。 + The project's language version has to be at least 'C# 12'. + 專案的語言版本必須至少為 'C# 12'。 - Language version is required to be at least C# 11 - 語言版本要求至少為 C#11 + Language version is required to be at least C# 12 + 語言版本要求至少為 C#12 diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/BindingHelperInfo.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/BindingHelperInfo.cs new file mode 100644 index 00000000000000..096c8410717ae7 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/BindingHelperInfo.cs @@ -0,0 +1,237 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed record BindingHelperInfo + { + public required ImmutableEquatableArray Namespaces { get; init; } + public required bool EmitConfigurationKeyCaches { get; init; } + + public required MethodsToGen_CoreBindingHelper MethodsToGen { get; init; } + public required ImmutableEquatableArray? TypesForGen_BindCoreMain { get; init; } + public required ImmutableEquatableArray? TypesForGen_GetCore { get; init; } + public required ImmutableEquatableArray? TypesForGen_GetValueCore { get; init; } + public required ImmutableEquatableArray? TypesForGen_BindCore { get; init; } + public required ImmutableEquatableArray? TypesForGen_Initialize { get; init; } + public required ImmutableEquatableArray? TypesForGen_ParsePrimitive { get; init; } + + internal sealed class Builder(TypeIndex _typeIndex) + { + private readonly Dictionary _seenTransitiveTypes = new(); + + private MethodsToGen_CoreBindingHelper _methodsToGen; + private bool _emitConfigurationKeyCaches; + + private readonly Dictionary> _typesForGen = new(); + + private readonly SortedSet _namespaces = new() + { + "System", + "System.CodeDom.Compiler", + "System.Globalization", + "System.Runtime.CompilerServices", + "Microsoft.Extensions.Configuration", + }; + + public BindingHelperInfo ToIncrementalValue() + { + return new BindingHelperInfo + { + Namespaces = _namespaces.ToImmutableEquatableArray(), + EmitConfigurationKeyCaches = _emitConfigurationKeyCaches, + + MethodsToGen = _methodsToGen, + TypesForGen_GetCore = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.GetCore), + TypesForGen_BindCoreMain = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.BindCoreMain), + TypesForGen_GetValueCore = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.GetValueCore), + TypesForGen_BindCore = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.BindCore), + TypesForGen_Initialize = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.Initialize), + TypesForGen_ParsePrimitive = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.ParsePrimitive) + }; + + ImmutableEquatableArray? GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper overload) + where TSpec : TypeSpec, IEquatable + { + _typesForGen.TryGetValue(overload, out HashSet? typesAsBase); + + if (typesAsBase is null) + { + return null; + } + + IEnumerable types = typeof(TSpec) == typeof(TypeSpec) + ? (HashSet)(object)typesAsBase + : typesAsBase.Select(t => (TSpec)t); + + return GetTypesForGen(types); + } + + static ImmutableEquatableArray GetTypesForGen(IEnumerable types) + where TSpec : TypeSpec, IEquatable => + types.ToImmutableEquatableArray(); + } + + public bool TryRegisterTypeForGetGen(TypeSpec type) + { + if (TryRegisterTransitiveTypesForMethodGen(type.TypeRef)) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetCore, type); + RegisterForGen_AsConfigWithChildrenHelper(); + return true; + } + + return false; + } + + public bool TryRegisterTypeForGetValueGen(TypeSpec typeSpec) + { + ParsableFromStringSpec effectiveType = (ParsableFromStringSpec)_typeIndex.GetEffectiveTypeSpec(typeSpec); + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetValueCore, typeSpec); + RegisterStringParsableTypeIfApplicable(effectiveType); + return true; + } + + public bool TryRegisterTypeForBindCoreMainGen(ComplexTypeSpec type) + { + if (TryRegisterTransitiveTypesForMethodGen(type.TypeRef)) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCoreMain, type); + RegisterForGen_AsConfigWithChildrenHelper(); + return true; + } + + return false; + } + + public bool TryRegisterTransitiveTypesForMethodGen(TypeRef typeRef) + { + return _seenTransitiveTypes.TryGetValue(typeRef, out bool isValid) + ? isValid + : (_seenTransitiveTypes[typeRef] = TryRegisterCore()); + + bool TryRegisterCore() + { + switch (_typeIndex.GetTypeSpec(typeRef)) + { + case NullableSpec nullableSpec: + { + return TryRegisterTransitiveTypesForMethodGen(nullableSpec.EffectiveTypeRef); + } + case ParsableFromStringSpec stringParsableSpec: + { + RegisterStringParsableTypeIfApplicable(stringParsableSpec); + return true; + } + case DictionarySpec dictionarySpec: + { + bool shouldRegister = _typeIndex.CanBindTo(typeRef) && + TryRegisterTransitiveTypesForMethodGen(dictionarySpec.KeyTypeRef) && + TryRegisterTransitiveTypesForMethodGen(dictionarySpec.ElementTypeRef) && + TryRegisterTypeForBindCoreGen(dictionarySpec); + + if (shouldRegister && dictionarySpec.InstantiationStrategy is CollectionInstantiationStrategy.LinqToDictionary) + { + _namespaces.Add("System.Linq"); + } + + return shouldRegister; + } + case CollectionSpec collectionSpec: + { + return TryRegisterTransitiveTypesForMethodGen(collectionSpec.ElementTypeRef) && + TryRegisterTypeForBindCoreGen(collectionSpec); + } + case ObjectSpec objectSpec: + { + // Base case to avoid stack overflow for recursive object graphs. + // Register all object types for gen; we need to throw runtime exceptions in some cases. + bool shouldRegister = true; + _seenTransitiveTypes.Add(typeRef, shouldRegister); + + // List is used in generated code as a temp holder for formatting + // an error for config properties that don't map to object properties. + _namespaces.Add("System.Collections.Generic"); + + if (_typeIndex.HasBindableMembers(objectSpec)) + { + foreach (PropertySpec property in objectSpec.Properties!) + { + TryRegisterTransitiveTypesForMethodGen(property.TypeRef); + + if (_typeIndex.GetTypeSpec(property.TypeRef) is ComplexTypeSpec) + { + RegisterForGen_AsConfigWithChildrenHelper(); + } + } + + bool registeredForBindCore = TryRegisterTypeForBindCoreGen(objectSpec); + Debug.Assert(registeredForBindCore); + + if (objectSpec is { InstantiationStrategy: ObjectInstantiationStrategy.ParameterizedConstructor, InitExceptionMessage: null }) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.Initialize, objectSpec); + } + } + + return true; + } + default: + { + return true; + } + } + } + } + + public void RegisterNamespace(string @namespace) => _namespaces.Add(@namespace); + + private bool TryRegisterTypeForBindCoreGen(ComplexTypeSpec type) + { + if (_typeIndex.HasBindableMembers(type)) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, type); + _emitConfigurationKeyCaches = true; + return true; + } + + return false; + } + + private void RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper method, TypeSpec type) + { + if (!_typesForGen.TryGetValue(method, out HashSet? types)) + { + _typesForGen[method] = types = new HashSet(); + } + + if (types.Add(type)) + { + _methodsToGen |= method; + + if (type is { Namespace: string @namespace }) + { + _namespaces.Add(@namespace); + } + } + } + + private void RegisterStringParsableTypeIfApplicable(ParsableFromStringSpec type) + { + if (type.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue) + { + _methodsToGen |= MethodsToGen_CoreBindingHelper.ParsePrimitive; + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.ParsePrimitive, type); + } + } + + private void RegisterForGen_AsConfigWithChildrenHelper() => _methodsToGen |= MethodsToGen_CoreBindingHelper.AsConfigWithChildren; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorInfo.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorInfo.cs new file mode 100644 index 00000000000000..999ed6514f99d7 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorInfo.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using Microsoft.CodeAnalysis.Text; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed record InterceptorInfo + { + public required MethodsToGen MethodsToGen { get; init; } + + public required ImmutableEquatableArray? ConfigBinder_Bind_instance { get; init; } + public required ImmutableEquatableArray? ConfigBinder_Bind_instance_BinderOptions { get; init; } + public required ImmutableEquatableArray? ConfigBinder_Bind_key_instance { get; init; } + + + public required ImmutableEquatableArray? ConfigBinder { get; init; } + public required ImmutableEquatableArray? OptionsBuilderExt { get; init; } + public required ImmutableEquatableArray? ServiceCollectionExt { get; init; } + + public IEnumerable? GetInfo(MethodsToGen interceptor) + { + Debug.Assert((MethodsToGen.ConfigBinder_Bind & interceptor) is 0); + + ImmutableEquatableArray? infoList; + if ((MethodsToGen.ConfigBinder_Any ^ MethodsToGen.ConfigBinder_Bind & interceptor) is not 0) + { + infoList = ConfigBinder; + } + else if ((MethodsToGen.OptionsBuilderExt_Any & interceptor) is not 0) + { + infoList = OptionsBuilderExt; + } + else + { + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & interceptor) is not 0); + infoList = ServiceCollectionExt; + } + + return infoList?.Where(i => i.Interceptor == interceptor); + } + + internal sealed class Builder + { + private TypedInterceptorInfoBuildler? _configBinder_InfoBuilder_Bind_instance; + private TypedInterceptorInfoBuildler? _configBinder_InfoBuilder_Bind_instance_BinderOptions; + private TypedInterceptorInfoBuildler? _configBinder_InfoBuilder_Bind_key_instance; + + private List? _interceptors_configBinder; + private List? _interceptors_OptionsBuilderExt; + private List? _interceptors_serviceCollectionExt; + + public MethodsToGen MethodsToGen { get; set; } + + public void RegisterInterceptor_ConfigBinder_Bind(MethodsToGen overload, ComplexTypeSpec type, IInvocationOperation invocation) + { + Debug.Assert((MethodsToGen.ConfigBinder_Bind & overload) is not 0); + + switch (overload) + { + case MethodsToGen.ConfigBinder_Bind_instance: + RegisterInterceptor(ref _configBinder_InfoBuilder_Bind_instance); + break; + case MethodsToGen.ConfigBinder_Bind_instance_BinderOptions: + RegisterInterceptor(ref _configBinder_InfoBuilder_Bind_instance_BinderOptions); + break; + case MethodsToGen.ConfigBinder_Bind_key_instance: + RegisterInterceptor(ref _configBinder_InfoBuilder_Bind_key_instance); + break; + } + + MethodsToGen |= overload; + + void RegisterInterceptor(ref TypedInterceptorInfoBuildler? infoBuilder) + { + infoBuilder ??= new TypedInterceptorInfoBuildler(); + infoBuilder.RegisterInterceptor(overload, type, invocation); + } + } + + public void RegisterInterceptor(MethodsToGen overload, IInvocationOperation operation) + { + Debug.Assert((MethodsToGen.ConfigBinder_Bind & overload) is 0); + + if ((MethodsToGen.ConfigBinder_Any ^ MethodsToGen.ConfigBinder_Bind & overload) is not 0) + { + RegisterInterceptor(ref _interceptors_configBinder); + } + else if ((MethodsToGen.OptionsBuilderExt_Any & overload) is not 0) + { + RegisterInterceptor(ref _interceptors_OptionsBuilderExt); + } + else + { + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & overload) is not 0); + RegisterInterceptor(ref _interceptors_serviceCollectionExt); + } + + MethodsToGen |= overload; + + void RegisterInterceptor(ref List? infoList) + { + infoList ??= new List(); + infoList.Add(new InvocationLocationInfo(overload, operation)); + } + } + + public InterceptorInfo ToIncrementalValue() => + new InterceptorInfo + { + MethodsToGen = MethodsToGen, + + ConfigBinder = _interceptors_configBinder?.ToImmutableEquatableArray(), + OptionsBuilderExt = _interceptors_OptionsBuilderExt?.ToImmutableEquatableArray(), + ServiceCollectionExt = _interceptors_serviceCollectionExt?.ToImmutableEquatableArray(), + + ConfigBinder_Bind_instance = _configBinder_InfoBuilder_Bind_instance?.ToIncrementalValue(), + ConfigBinder_Bind_instance_BinderOptions = _configBinder_InfoBuilder_Bind_instance_BinderOptions?.ToIncrementalValue(), + ConfigBinder_Bind_key_instance = _configBinder_InfoBuilder_Bind_key_instance?.ToIncrementalValue(), + }; + } + } + + internal sealed class TypedInterceptorInfoBuildler + { + private readonly Dictionary _invocationInfoBuilderCache = new(); + + public void RegisterInterceptor(MethodsToGen overload, ComplexTypeSpec type, IInvocationOperation invocation) + { + if (!_invocationInfoBuilderCache.TryGetValue(type, out TypedInterceptorInvocationInfo.Builder? invocationInfoBuilder)) + { + _invocationInfoBuilderCache[type] = invocationInfoBuilder = new TypedInterceptorInvocationInfo.Builder(overload, type); + } + + invocationInfoBuilder.RegisterInvocation(invocation); + } + + public ImmutableEquatableArray? ToIncrementalValue() => + _invocationInfoBuilderCache.Values + .Select(b => b.ToIncrementalValue()) + .ToImmutableEquatableArray(); + } + + public sealed record TypedInterceptorInvocationInfo(ComplexTypeSpec TargetType, ImmutableEquatableArray Locations) + { + public sealed class Builder(MethodsToGen Overload, ComplexTypeSpec TargetType) + { + private readonly List _infoList = new(); + + public void RegisterInvocation(IInvocationOperation invocation) => + _infoList.Add(new InvocationLocationInfo(Overload, invocation)); + + public TypedInterceptorInvocationInfo ToIncrementalValue() => new( + TargetType, + Locations: _infoList.ToImmutableEquatableArray()); + } + } + + public sealed record InvocationLocationInfo + { + public InvocationLocationInfo(MethodsToGen interceptor, IInvocationOperation invocation) + { + Debug.Assert(BinderInvocation.IsBindingOperation(invocation)); + + if (invocation.Syntax is not InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax memberAccessExprSyntax }) + { + const string InvalidInvocationErrMsg = "The invocation should have been validated upstream when selecting invocations to emit interceptors for."; + throw new ArgumentException(InvalidInvocationErrMsg, nameof(invocation)); + } + + SyntaxTree operationSyntaxTree = invocation.Syntax.SyntaxTree; + TextSpan memberNameSpan = memberAccessExprSyntax.Name.Span; + FileLinePositionSpan linePosSpan = operationSyntaxTree.GetLineSpan(memberNameSpan); + + Interceptor = interceptor; + LineNumber = linePosSpan.StartLinePosition.Line + 1; + CharacterNumber = linePosSpan.StartLinePosition.Character + 1; + FilePath = GetInterceptorFilePath(); + + // Use the same logic used by the interceptors API for resolving the source mapped value of a path. + // https://github.com/dotnet/roslyn/blob/f290437fcc75dad50a38c09e0977cce13a64f5ba/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs#L1063-L1064 + string GetInterceptorFilePath() + { + SourceReferenceResolver? sourceReferenceResolver = invocation.SemanticModel?.Compilation.Options.SourceReferenceResolver; + return sourceReferenceResolver?.NormalizePath(operationSyntaxTree.FilePath, baseFilePath: null) ?? operationSyntaxTree.FilePath; + } + } + + public MethodsToGen Interceptor { get; } + public string FilePath { get; } + public int LineNumber { get; } + public int CharacterNumber { get; } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/MemberSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/MemberSpec.cs similarity index 82% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/MemberSpec.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/MemberSpec.cs index 4bf674f597502a..dc5b03087ac87a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/MemberSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/MemberSpec.cs @@ -3,10 +3,11 @@ using System.Diagnostics; using Microsoft.CodeAnalysis; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal abstract record MemberSpec + public abstract record MemberSpec { public MemberSpec(ISymbol member) { @@ -16,10 +17,9 @@ public MemberSpec(ISymbol member) } public string Name { get; } - public bool ErrorOnFailedBinding { get; protected set; } public string DefaultValueExpr { get; protected set; } - public required TypeSpec Type { get; init; } + public required TypeRef TypeRef { get; init; } public required string ConfigurationKeyName { get; init; } public abstract bool CanGet { get; } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ParameterSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/ParameterSpec.cs similarity index 82% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ParameterSpec.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/ParameterSpec.cs index 9b5e4360c11169..62c781e1f1631f 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/ParameterSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/ParameterSpec.cs @@ -6,7 +6,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record ParameterSpec : MemberSpec + public sealed record ParameterSpec : MemberSpec { public ParameterSpec(IParameterSymbol parameter) : base(parameter) { @@ -14,7 +14,7 @@ public ParameterSpec(IParameterSymbol parameter) : base(parameter) if (parameter.HasExplicitDefaultValue) { - string formatted = SymbolDisplay.FormatPrimitive(parameter.ExplicitDefaultValue, quoteStrings: true, useHexadecimalNumbers: false); + string formatted = SymbolDisplay.FormatPrimitive(parameter.ExplicitDefaultValue!, quoteStrings: true, useHexadecimalNumbers: false); if (formatted is not "null") { DefaultValueExpr = formatted; @@ -26,6 +26,8 @@ public ParameterSpec(IParameterSymbol parameter) : base(parameter) } } + public bool ErrorOnFailedBinding { get; private set; } + public RefKind RefKind { get; } public override bool CanGet => false; diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/PropertySpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/PropertySpec.cs similarity index 90% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/PropertySpec.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/PropertySpec.cs index 584e8d570b8a9f..443e39d32e4933 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Model/PropertySpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/PropertySpec.cs @@ -5,7 +5,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record PropertySpec : MemberSpec + public sealed record PropertySpec : MemberSpec { public PropertySpec(IPropertySymbol property) : base(property) { @@ -28,7 +28,5 @@ public PropertySpec(IPropertySymbol property) : base(property) public override bool CanGet { get; } public override bool CanSet { get; } - - public bool ShouldBind() => CanGet || CanSet; } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/MethodsToGen.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/MethodsToGen.cs new file mode 100644 index 00000000000000..af2a33fa6c2f80 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/MethodsToGen.cs @@ -0,0 +1,145 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + [Flags] + public enum MethodsToGen_CoreBindingHelper + { + None = 0x0, + BindCore = 0x1, + GetCore = 0x2, + GetValueCore = 0x4, + BindCoreMain = 0x8, + Initialize = 0x10, + HasValueOrChildren = 0x20, + AsConfigWithChildren = 0x40, + ParsePrimitive = 0x80, + } + + /// + /// Methods on Microsoft.Extensions.Configuration.ConfigurationBinder + /// + [Flags] + public enum MethodsToGen + { + None = 0x0, + Any = ConfigBinder_Any | OptionsBuilderExt_Any | ServiceCollectionExt_Any, + + #region IConfiguration ext. method overloads: 0x1 - 0x400 + /// + /// Bind(IConfiguration, object?). + /// + ConfigBinder_Bind_instance = 0x1, + + /// + /// Bind(IConfiguration, object?, Action?). + /// + ConfigBinder_Bind_instance_BinderOptions = 0x2, + + /// + /// Bind(IConfiguration, string, object?). + /// + ConfigBinder_Bind_key_instance = 0x4, + + /// + /// Get(IConfiguration). + /// + ConfigBinder_Get_T = 0x8, + + /// + /// Get(IConfiguration, Action?). + /// + ConfigBinder_Get_T_BinderOptions = 0x10, + + /// + /// Get(IConfiguration, Type). + /// + ConfigBinder_Get_TypeOf = 0x20, + + /// + /// Get(IConfiguration, Type, Action?). + /// + ConfigBinder_Get_TypeOf_BinderOptions = 0x40, + + /// + /// GetValue(IConfiguration, Type, string). + /// + ConfigBinder_GetValue_TypeOf_key = 0x80, + + /// + /// GetValue(IConfiguration, Type, object?). + /// + ConfigBinder_GetValue_TypeOf_key_defaultValue = 0x100, + + /// + /// GetValue(IConfiguration, string). + /// + ConfigBinder_GetValue_T_key = 0x200, + + /// + /// GetValue(IConfiguration, string, T). + /// + ConfigBinder_GetValue_T_key_defaultValue = 0x400, + + // Method groups + ConfigBinder_Bind = ConfigBinder_Bind_instance | ConfigBinder_Bind_instance_BinderOptions | ConfigBinder_Bind_key_instance, + ConfigBinder_Get = ConfigBinder_Get_T | ConfigBinder_Get_T_BinderOptions | ConfigBinder_Get_TypeOf | ConfigBinder_Get_TypeOf_BinderOptions, + ConfigBinder_GetValue = ConfigBinder_GetValue_T_key | ConfigBinder_GetValue_T_key_defaultValue | ConfigBinder_GetValue_TypeOf_key | ConfigBinder_GetValue_TypeOf_key_defaultValue, + + ConfigBinder_Any = ConfigBinder_Bind | ConfigBinder_Get | ConfigBinder_GetValue, + #endregion ConfigurationBinder ext. method overloads. + + #region OptionsBuilder ext. method overloads: 0x800 - 0x2000 + /// + /// Bind(OptionsBuilder, IConfiguration). + /// + OptionsBuilderExt_Bind_T = 0x800, + + /// + /// Bind(OptionsBuilder, IConfiguration, Action?). + /// + OptionsBuilderExt_Bind_T_BinderOptions = 0x1000, + + /// + /// BindConfiguration(OptionsBuilder, string, Action?). + /// + OptionsBuilderExt_BindConfiguration_T_path_BinderOptions = 0x2000, + + // Method group. BindConfiguration_T is its own method group. + OptionsBuilderExt_Bind = OptionsBuilderExt_Bind_T | OptionsBuilderExt_Bind_T_BinderOptions, + + OptionsBuilderExt_BindConfiguration = OptionsBuilderExt_BindConfiguration_T_path_BinderOptions, + + OptionsBuilderExt_Any = OptionsBuilderExt_Bind | OptionsBuilderExt_BindConfiguration, + #endregion OptionsBuilder ext. method overloads. + + #region IServiceCollection ext. method overloads: 0x4000 - 0x20000 + /// + /// Configure(IServiceCollection, IConfiguration). + /// + ServiceCollectionExt_Configure_T = 0x4000, + + /// + /// Configure(IServiceCollection, string, IConfiguration). + /// + ServiceCollectionExt_Configure_T_name = 0x8000, + + /// + /// Configure(IServiceCollection, IConfiguration, Action?). + /// + ServiceCollectionExt_Configure_T_BinderOptions = 0x10000, + + /// + /// Configure(IServiceCollection, string, IConfiguration, Action?). + /// + ServiceCollectionExt_Configure_T_name_BinderOptions = 0x20000, + + ServiceCollectionExt_Configure = ServiceCollectionExt_Configure_T | ServiceCollectionExt_Configure_T_name | ServiceCollectionExt_Configure_T_BinderOptions | ServiceCollectionExt_Configure_T_name_BinderOptions, + + ServiceCollectionExt_Any = ServiceCollectionExt_Configure, + #endregion IServiceCollection ext. method overloads: 0x4000 - 0x20000 + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/SourceGenerationSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/SourceGenerationSpec.cs new file mode 100644 index 00000000000000..4f57316429e2b1 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/SourceGenerationSpec.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed record SourceGenerationSpec + { + public required InterceptorInfo InterceptorInfo { get; init; } + public required BindingHelperInfo BindingHelperInfo { get; init; } + public required ImmutableEquatableArray ConfigTypes { get; init; } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/TypeIndex.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/TypeIndex.cs new file mode 100644 index 00000000000000..5b59577b392921 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/TypeIndex.cs @@ -0,0 +1,122 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + internal sealed class TypeIndex(IEnumerable typeSpecs) + { + private readonly Dictionary _index = typeSpecs.ToDictionary(spec => spec.TypeRef); + + public bool CanBindTo(TypeRef typeRef) => GetEffectiveTypeSpec(typeRef) switch + { + SimpleTypeSpec => true, + ComplexTypeSpec complexTypeSpec => CanInstantiate(complexTypeSpec) || HasBindableMembers(complexTypeSpec), + _ => throw new InvalidOperationException(), + }; + + public bool CanInstantiate(ComplexTypeSpec typeSpec) => typeSpec switch + { + ObjectSpec objectSpec => objectSpec is { InstantiationStrategy: not ObjectInstantiationStrategy.None, InitExceptionMessage: null }, + DictionarySpec dictionarySpec => KeyIsSupported(dictionarySpec), + CollectionSpec collectionSpec => CanBindTo(collectionSpec.ElementTypeRef), + _ => throw new InvalidOperationException(), + }; + + public bool HasBindableMembers(ComplexTypeSpec typeSpec) => + typeSpec switch + { + ObjectSpec objectSpec => objectSpec.Properties?.Any(ShouldBindTo) is true, + DictionarySpec dictSpec => KeyIsSupported(dictSpec) && CanBindTo(dictSpec.ElementTypeRef), + CollectionSpec collectionSpec => CanBindTo(collectionSpec.ElementTypeRef), + _ => throw new InvalidOperationException(), + }; + + public bool ShouldBindTo(PropertySpec property) + { + TypeSpec propTypeSpec = GetEffectiveTypeSpec(property.TypeRef); + return IsAccessible() && !IsCollectionAndCannotOverride() && !IsDictWithUnsupportedKey(); + + bool IsAccessible() => property.CanGet || property.CanSet; + + bool IsDictWithUnsupportedKey() => propTypeSpec is DictionarySpec dictionarySpec && !KeyIsSupported(dictionarySpec); + + bool IsCollectionAndCannotOverride() => !property.CanSet && + propTypeSpec is CollectionWithCtorInitSpec + { + InstantiationStrategy: CollectionInstantiationStrategy.CopyConstructor or CollectionInstantiationStrategy.LinqToDictionary + }; + } + + public TypeSpec GetEffectiveTypeSpec(TypeRef typeRef) + { + TypeSpec typeSpec = GetTypeSpec(typeRef); + return GetEffectiveTypeSpec(typeSpec); + } + + public TypeSpec GetEffectiveTypeSpec(TypeSpec typeSpec) + { + TypeRef effectiveRef = typeSpec.EffectiveTypeRef; + TypeSpec effectiveSpec = effectiveRef == typeSpec.TypeRef ? typeSpec : _index[effectiveRef]; + return effectiveSpec; + } + + public TypeSpec GetTypeSpec(TypeRef typeRef) => _index[typeRef]; + + public string GetInstantiationTypeDisplayString(CollectionWithCtorInitSpec type) + { + CollectionInstantiationConcreteType concreteType = type.InstantiationConcreteType; + return concreteType is CollectionInstantiationConcreteType.Self + ? type.DisplayString + : GetGenericTypeDisplayString(type, concreteType); + } + + public string GetPopulationCastTypeDisplayString(CollectionWithCtorInitSpec type) + { + CollectionPopulationCastType castType = type.PopulationCastType; + Debug.Assert(castType is not CollectionPopulationCastType.NotApplicable); + return GetGenericTypeDisplayString(type, castType); + } + + public string GetGenericTypeDisplayString(CollectionWithCtorInitSpec type, Enum genericProxyTypeName) + { + string proxyTypeNameStr = genericProxyTypeName.ToString(); + string elementTypeDisplayString = GetTypeSpec(type.ElementTypeRef).DisplayString; + + if (type is EnumerableSpec) + { + return $"{proxyTypeNameStr}<{elementTypeDisplayString}>"; + } + + string keyTypeDisplayString = GetTypeSpec(((DictionarySpec)type).KeyTypeRef).DisplayString; + return $"{proxyTypeNameStr}<{keyTypeDisplayString}, {elementTypeDisplayString}>"; + } + + public bool KeyIsSupported(DictionarySpec typeSpec) => + // Only types that are parsable from string are supported. + // Nullable keys not allowed; that would cause us to emit + // code that violates dictionary key notnull constraint. + GetTypeSpec(typeSpec.KeyTypeRef) is ParsableFromStringSpec; + + public static string GetConfigKeyCacheFieldName(ObjectSpec type) => $"s_configKeys_{type.IdentifierCompatibleSubstring}"; + + public static string GetParseMethodName(ParsableFromStringSpec type) + { + Debug.Assert(type.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue); + + string displayString = type.DisplayString; + + string parseMethod = type.StringParsableTypeKind is StringParsableTypeKind.ByteArray + ? "ParseByteArray" + // MinimalDisplayString.Length is certainly > 2. + : $"Parse{(char.ToUpper(displayString[0]) + displayString.Substring(1)).Replace(".", "")}"; + + return parseMethod; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/CollectionSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/CollectionSpec.cs new file mode 100644 index 00000000000000..f891328f77af7c --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/CollectionSpec.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + internal abstract record CollectionSpec : ComplexTypeSpec + { + protected CollectionSpec(ITypeSymbol type) : base(type) { } + + public required TypeRef ElementTypeRef { get; init; } + + } + + internal abstract record CollectionWithCtorInitSpec : CollectionSpec + { + protected CollectionWithCtorInitSpec(ITypeSymbol type) : base(type) { } + + public required CollectionInstantiationStrategy InstantiationStrategy { get; init; } + + public required CollectionInstantiationConcreteType InstantiationConcreteType { get; init; } + + public required CollectionPopulationCastType PopulationCastType { get; init; } + } + + internal sealed record ArraySpec : CollectionSpec + { + public ArraySpec(ITypeSymbol type) : base(type) { } + } + + internal sealed record EnumerableSpec : CollectionWithCtorInitSpec + { + public EnumerableSpec(ITypeSymbol type) : base(type) { } + } + + internal sealed record DictionarySpec : CollectionWithCtorInitSpec + { + public DictionarySpec(INamedTypeSymbol type) : base(type) { } + + public required TypeRef KeyTypeRef { get; init; } + } + + internal enum CollectionInstantiationStrategy + { + NotApplicable = 0, + ParameterlessConstructor = 1, + CopyConstructor = 2, + LinqToDictionary = 3, + } + + internal enum CollectionInstantiationConcreteType + { + Self = 0, + Dictionary = 1, + List = 2, + HashSet = 3, + } + + internal enum CollectionPopulationCastType + { + NotApplicable = 0, + IDictionary = 1, + ICollection = 2, + ISet = 3, + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ObjectSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ObjectSpec.cs new file mode 100644 index 00000000000000..abc01258d4190c --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ObjectSpec.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed record ObjectSpec : ComplexTypeSpec + { + public ObjectSpec( + INamedTypeSymbol type, + ObjectInstantiationStrategy instantiationStrategy, + ImmutableEquatableArray? properties, + ImmutableEquatableArray? constructorParameters, + string? initExceptionMessage) : base(type) + { + InstantiationStrategy = instantiationStrategy; + Properties = properties; + ConstructorParameters = constructorParameters; + InitExceptionMessage = initExceptionMessage; + } + + public ObjectInstantiationStrategy InstantiationStrategy { get; } + + public ImmutableEquatableArray? Properties { get; } + + public ImmutableEquatableArray? ConstructorParameters { get; } + + public string? InitExceptionMessage { get; } + } + + public enum ObjectInstantiationStrategy + { + None = 0, + ParameterlessConstructor = 1, + ParameterizedConstructor = 2, + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/SimpleTypeSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/SimpleTypeSpec.cs new file mode 100644 index 00000000000000..70c7a8042e0359 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/SimpleTypeSpec.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public abstract record SimpleTypeSpec : TypeSpec + { + public SimpleTypeSpec(ITypeSymbol type) : base(type) { } + } + + internal sealed record ConfigurationSectionSpec : SimpleTypeSpec + { + public ConfigurationSectionSpec(ITypeSymbol type) : base(type) { } + } + + public sealed record ParsableFromStringSpec : SimpleTypeSpec + { + public ParsableFromStringSpec(ITypeSymbol type) : base(type) { } + + public required StringParsableTypeKind StringParsableTypeKind { get; init; } + } + + public enum StringParsableTypeKind + { + None = 0, + + /// + /// Declared types that can be assigned directly from IConfigurationSection.Value, i.e. string and typeof(object). + /// + AssignFromSectionValue = 1, + Enum = 2, + ByteArray = 3, + Integer = 4, + Float = 5, + Parse = 6, + ParseInvariant = 7, + CultureInfo = 8, + Uri = 9, + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/TypeSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/TypeSpec.cs new file mode 100644 index 00000000000000..1c243ae1cdc7c1 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/TypeSpec.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + [DebuggerDisplay("Name={DisplayString}, Kind={SpecKind}")] + public abstract record TypeSpec + { + public TypeSpec(ITypeSymbol type) + { + TypeRef = new TypeRef(type); + EffectiveTypeRef = TypeRef; // Overriden by NullableSpec. + (Namespace, DisplayString, Name) = type.GetTypeName(); + IdentifierCompatibleSubstring = type.ToIdentifierCompatibleSubstring(); + IsValueType = type.IsValueType; + } + + public TypeRef TypeRef { get; } + + public TypeRef EffectiveTypeRef { get; protected init; } + + public string Name { get; } + + public string DisplayString { get; } + + public string IdentifierCompatibleSubstring { get; } + + public string? Namespace { get; } + + public bool IsValueType { get; } + } + + public abstract record ComplexTypeSpec : TypeSpec + { + protected ComplexTypeSpec(ITypeSymbol type) : base(type) { } + } + + internal sealed record NullableSpec : TypeSpec + { + public NullableSpec(ITypeSymbol type, TypeRef underlyingTypeRef) : base(type) => + EffectiveTypeRef = underlyingTypeRef; + } + + internal sealed record UnsupportedTypeSpec : TypeSpec + { + public UnsupportedTypeSpec(ITypeSymbol type) : base(type) { } + + public required NotSupportedReason NotSupportedReason { get; init; } + } + + public enum NotSupportedReason + { + UnknownType = 1, + MissingPublicInstanceConstructor = 2, + CollectionNotSupported = 3, + DictionaryKeyNotSupported = 4, + ElementTypeNotSupported = 5, + MultipleParameterizedConstructors = 6, + MultiDimArraysNotSupported = 7, + NullableUnderlyingTypeNotSupported = 8, + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs index 8651e4922e0d72..dfc35d80208553 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs @@ -91,6 +91,7 @@ public static class ConfigurationBinder Action? configureOptions) { ThrowHelper.ThrowIfNull(configuration); + ThrowHelper.ThrowIfNull(type); var options = new BinderOptions(); configureOptions?.Invoke(options); @@ -108,7 +109,10 @@ public static class ConfigurationBinder [RequiresDynamicCode(DynamicCodeWarningMessage)] [RequiresUnreferencedCode(InstanceGetTypeTrimmingWarningMessage)] public static void Bind(this IConfiguration configuration, string key, object? instance) - => configuration.GetSection(key).Bind(instance); + { + ThrowHelper.ThrowIfNull(configuration); + configuration.GetSection(key).Bind(instance); + } /// /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. @@ -200,6 +204,9 @@ public static void Bind(this IConfiguration configuration, object? instance, Act Type type, string key, object? defaultValue) { + ThrowHelper.ThrowIfNull(configuration); + ThrowHelper.ThrowIfNull(type); + IConfigurationSection section = configuration.GetSection(key); string? value = section.Value; if (value != null) @@ -291,6 +298,11 @@ private static void BindInstance( return; } + if (config is null) + { + return; + } + var section = config as IConfigurationSection; string? configValue = section?.Value; if (configValue != null && TryConvertValue(type, configValue, section?.Path, out object? convertedValue, out Exception? error)) @@ -305,123 +317,120 @@ private static void BindInstance( return; } - if (config != null) + if (config.GetChildren().Any()) { - if (config.GetChildren().Any()) + // for arrays and read-only list-like interfaces, we concatenate on to what is already there, if we can + if (type.IsArray || IsImmutableArrayCompatibleInterface(type)) { - // for arrays and read-only list-like interfaces, we concatenate on to what is already there, if we can - if (type.IsArray || IsImmutableArrayCompatibleInterface(type)) + if (!bindingPoint.IsReadOnly) { - if (!bindingPoint.IsReadOnly) - { - bindingPoint.SetValue(BindArray(type, (IEnumerable?)bindingPoint.Value, config, options)); - } - - // for getter-only collection properties that we can't add to, nothing more we can do - return; + bindingPoint.SetValue(BindArray(type, (IEnumerable?)bindingPoint.Value, config, options)); } - // ----------------------------------------------------------------------------------------------------------------------------- - // | bindingPoint | bindingPoint | - // Interface | Value | IsReadOnly | Behavior - // ----------------------------------------------------------------------------------------------------------------------------- - // ISet | not null | true/false | Use the Value instance to populate the configuration - // ISet | null | false | Create HashSet instance to populate the configuration - // ISet | null | true | nothing - // IReadOnlySet | null/not null | false | Create HashSet instance, copy over existing values, and populate the configuration - // IReadOnlySet | null/not null | true | nothing - // ----------------------------------------------------------------------------------------------------------------------------- - if (TypeIsASetInterface(type)) - { - if (!bindingPoint.IsReadOnly || bindingPoint.Value is not null) - { - object? newValue = BindSet(type, (IEnumerable?)bindingPoint.Value, config, options); - if (!bindingPoint.IsReadOnly && newValue != null) - { - bindingPoint.SetValue(newValue); - } - } - - return; - } + // for getter-only collection properties that we can't add to, nothing more we can do + return; + } - // ----------------------------------------------------------------------------------------------------------------------------- - // | bindingPoint | bindingPoint | - // Interface | Value | IsReadOnly | Behavior - // ----------------------------------------------------------------------------------------------------------------------------- - // IDictionary | not null | true/false | Use the Value instance to populate the configuration - // IDictionary | null | false | Create Dictionary instance to populate the configuration - // IDictionary | null | true | nothing - // IReadOnlyDictionary | null/not null | false | Create Dictionary instance, copy over existing values, and populate the configuration - // IReadOnlyDictionary | null/not null | true | nothing - // ----------------------------------------------------------------------------------------------------------------------------- - if (TypeIsADictionaryInterface(type)) + // ----------------------------------------------------------------------------------------------------------------------------- + // | bindingPoint | bindingPoint | + // Interface | Value | IsReadOnly | Behavior + // ----------------------------------------------------------------------------------------------------------------------------- + // ISet | not null | true/false | Use the Value instance to populate the configuration + // ISet | null | false | Create HashSet instance to populate the configuration + // ISet | null | true | nothing + // IReadOnlySet | null/not null | false | Create HashSet instance, copy over existing values, and populate the configuration + // IReadOnlySet | null/not null | true | nothing + // ----------------------------------------------------------------------------------------------------------------------------- + if (TypeIsASetInterface(type)) + { + if (!bindingPoint.IsReadOnly || bindingPoint.Value is not null) { - if (!bindingPoint.IsReadOnly || bindingPoint.Value is not null) + object? newValue = BindSet(type, (IEnumerable?)bindingPoint.Value, config, options); + if (!bindingPoint.IsReadOnly && newValue != null) { - object? newValue = BindDictionaryInterface(bindingPoint.Value, type, config, options); - if (!bindingPoint.IsReadOnly && newValue != null) - { - bindingPoint.SetValue(newValue); - } + bindingPoint.SetValue(newValue); } - - return; } - // If we don't have an instance, try to create one - if (bindingPoint.Value is null) + return; + } + + // ----------------------------------------------------------------------------------------------------------------------------- + // | bindingPoint | bindingPoint | + // Interface | Value | IsReadOnly | Behavior + // ----------------------------------------------------------------------------------------------------------------------------- + // IDictionary | not null | true/false | Use the Value instance to populate the configuration + // IDictionary | null | false | Create Dictionary instance to populate the configuration + // IDictionary | null | true | nothing + // IReadOnlyDictionary | null/not null | false | Create Dictionary instance, copy over existing values, and populate the configuration + // IReadOnlyDictionary | null/not null | true | nothing + // ----------------------------------------------------------------------------------------------------------------------------- + if (TypeIsADictionaryInterface(type)) + { + if (!bindingPoint.IsReadOnly || bindingPoint.Value is not null) { - // if the binding point doesn't let us set a new instance, there's nothing more we can do - if (bindingPoint.IsReadOnly) + object? newValue = BindDictionaryInterface(bindingPoint.Value, type, config, options); + if (!bindingPoint.IsReadOnly && newValue != null) { - return; + bindingPoint.SetValue(newValue); } + } - Type? interfaceGenericType = type.IsInterface && type.IsConstructedGenericType ? type.GetGenericTypeDefinition() : null; + return; + } - if (interfaceGenericType is not null && - (interfaceGenericType == typeof(ICollection<>) || interfaceGenericType == typeof(IList<>))) - { - // For ICollection and IList we bind them to mutable List type. - Type genericType = typeof(List<>).MakeGenericType(type.GenericTypeArguments); - bindingPoint.SetValue(Activator.CreateInstance(genericType)); - } - else - { - bindingPoint.SetValue(CreateInstance(type, config, options)); - } + // If we don't have an instance, try to create one + if (bindingPoint.Value is null) + { + // if the binding point doesn't let us set a new instance, there's nothing more we can do + if (bindingPoint.IsReadOnly) + { + return; } - Debug.Assert(bindingPoint.Value is not null); - - // At this point we know that we have a non-null bindingPoint.Value, we just have to populate the items - // using the IDictionary<> or ICollection<> interfaces, or properties using reflection. - Type? dictionaryInterface = FindOpenGenericInterface(typeof(IDictionary<,>), type); + Type? interfaceGenericType = type.IsInterface && type.IsConstructedGenericType ? type.GetGenericTypeDefinition() : null; - if (dictionaryInterface != null) + if (interfaceGenericType is not null && + (interfaceGenericType == typeof(ICollection<>) || interfaceGenericType == typeof(IList<>))) { - BindDictionary(bindingPoint.Value, dictionaryInterface, config, options); + // For ICollection and IList we bind them to mutable List type. + Type genericType = typeof(List<>).MakeGenericType(type.GenericTypeArguments); + bindingPoint.SetValue(Activator.CreateInstance(genericType)); } else { - Type? collectionInterface = FindOpenGenericInterface(typeof(ICollection<>), type); - if (collectionInterface != null) - { - BindCollection(bindingPoint.Value, collectionInterface, config, options); - } - else - { - BindProperties(bindingPoint.Value, config, options); - } + bindingPoint.SetValue(CreateInstance(type, config, options)); } } + + Debug.Assert(bindingPoint.Value is not null); + + // At this point we know that we have a non-null bindingPoint.Value, we just have to populate the items + // using the IDictionary<> or ICollection<> interfaces, or properties using reflection. + Type? dictionaryInterface = FindOpenGenericInterface(typeof(IDictionary<,>), type); + + if (dictionaryInterface != null) + { + BindDictionary(bindingPoint.Value, dictionaryInterface, config, options); + } else { - if (isParentCollection) + Type? collectionInterface = FindOpenGenericInterface(typeof(ICollection<>), type); + if (collectionInterface != null) { - bindingPoint.TrySetValue(CreateInstance(type, config, options)); + BindCollection(bindingPoint.Value, collectionInterface, config, options); } + else + { + BindProperties(bindingPoint.Value, config, options); + } + } + } + else + { + if (isParentCollection) + { + bindingPoint.TrySetValue(CreateInstance(type, config, options)); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/PACKAGE.md index ffb402bece35f0..861b0687843ea0 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/PACKAGE.md @@ -1,10 +1,23 @@ ## About + + Provides the functionality to bind an object to data in configuration providers for [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). This package enables you to represent the configuration data as strongly-typed classes defined in the application code. To bind a configuration, use the [Microsoft.Extensions.Configuration.ConfigurationBinder.Get](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.configurationbinder.get) extension method on the `IConfiguration` object. To use this package, you also need to install a package for the [configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration#configuration-providers), for example, [Microsoft.Extensions.Configuration.Json](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Json/) for the JSON provider. -For more information, see the documentation: [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration). +The types contained in this assembly use Reflection at runtime which is not friendly with linking or AOT. To better support linking and AOT as well as provide more efficient strongly-typed binding methods - this package also provides a source generator. This generator is enabled by default when a project sets `PublishAot` but can also be enabled using `true`. + +## Key Features + + + +* Configuring existing type instances from a configuration section (Bind) +* Constructing new configured type instances from a configuration section (Get & GetValue) +* Generating source to bind objects from a configuration section without a runtime reflection dependency. + +## How to Use + + -## Example The following example shows how to bind a JSON configuration section to .NET objects. ```cs @@ -42,7 +55,7 @@ class Program // Read nested objects Console.WriteLine("Endpoints: "); - + foreach (Endpoint endpoint in settings.Endpoints) { Console.WriteLine($"{endpoint.IPAddress}:{endpoint.Port}"); @@ -81,3 +94,45 @@ You can include a configuration file using a code like this in your `.csproj` fi ``` + +You can add the following property to enable the source generator. This requires a .NET 8.0 SDK or later. +```xml + + true + +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Configuration.ConfigurationBinder` +* `Microsoft.Extensions.Configuration.BinderOptions` + +## Additional Documentation + + + +* [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration) + +## Related Packages + + +* [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration) +* [Microsoft.Extensions.Configuration.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Abstractions) +* [Microsoft.Extensions.Configuration.CommandLine](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.CommandLine) +* [Microsoft.Extensions.Configuration.EnvironmentVariables](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.EnvironmentVariables) +* [Microsoft.Extensions.Configuration.FileExtensions](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.FileExtensions) +* [Microsoft.Extensions.Configuration.Ini](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Ini) +* [Microsoft.Extensions.Configuration.Json](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Json) +* [Microsoft.Extensions.Configuration.UserSecrets](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.UserSecrets) +* [Microsoft.Extensions.Configuration.Xml](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Xml) + +## Feedback & Contributing + + + +Microsoft.Extensions.Configuration.Binder is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/buildTransitive/Microsoft.Extensions.Configuration.Binder.targets b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/buildTransitive/Microsoft.Extensions.Configuration.Binder.targets index f091c7a57b23ae..fdfd48d12a75ad 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/buildTransitive/Microsoft.Extensions.Configuration.Binder.targets +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/buildTransitive/Microsoft.Extensions.Configuration.Binder.targets @@ -1,5 +1,10 @@ - + + $(InterceptorsPreviewNamespaces);Microsoft.Extensions.Configuration.Binder.SourceGeneration + + + @@ -8,8 +13,8 @@ - + <_Microsoft_Extensions_Configuration_Binder_Compatible_TargetFramework Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'netcoreapp2.0')) AND diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.Helpers.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.Helpers.cs index a1d1a72ffab20f..d6521ed86dfdec 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.Helpers.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.Helpers.cs @@ -160,5 +160,17 @@ public class CustomICollectionWithoutAnAddMethod : ICollection public int Count => _items.Count; public bool IsReadOnly => false; } + + public interface IGeolocation + { + public double Latitude { get; set; } + public double Longitude { get; set; } + } + + public sealed record GeolocationRecord : IGeolocation + { + public double Latitude { get; set; } + public double Longitude { get; set; } + } #endregion } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs index 1e537b407963d9..7d10f66c822fc0 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs @@ -3,10 +3,12 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.ComponentModel; using System.Globalization; -using System.Text.Json; using System.Linq; +using System.Net; +using System.Text.Json; using Microsoft.Extensions.Configuration; using Xunit; @@ -441,6 +443,10 @@ public record NestedConfig(string MyProp); public class OptionWithCollectionProperties { private int _otherCode; + private int _otherCodeNullable; + private string _otherCodeString = "default"; + private object _otherCodeNull; + private Uri _otherCodeUri; private ICollection blacklist = new HashSet(); public ICollection Blacklist @@ -458,12 +464,37 @@ public ICollection Blacklist // ParsedBlacklist initialized using the setter of Blacklist. public ICollection ParsedBlacklist { get; private set; } = new HashSet(); - // This property not having any match in the configuration. Still the setter need to be called during the binding. + // This does not have a match in the configuration, however the setter should be called during the binding: public int OtherCode { get => _otherCode; set => _otherCode = value == 0 ? 2 : value; } + + // These do not have any match in the configuration, and the setters should not be called during the binding: + public int? OtherCodeNullable + { + get => _otherCodeNullable; + set => _otherCodeNullable = !value.HasValue ? 3 : value.Value; + } + + public string OtherCodeString + { + get => _otherCodeString; + set => _otherCodeString = value; + } + + public object? OtherCodeNull + { + get => _otherCodeNull; + set => _otherCodeNull = value is null ? 4 : value; + } + + public Uri OtherCodeUri + { + get => _otherCodeUri; + set => _otherCodeUri = value is null ? new Uri("hello") : value; + } } public interface ISomeInterface @@ -543,6 +574,47 @@ public struct DeeplyNested } } + public struct StructWithNestedStructAndSetterLogic + { + private string _string; + private int _int32; + + public string String + { + get => _string; + // Setter should not be called for missing values. + set { _string = string.IsNullOrEmpty(value) ? "Hello" : value; } + } + + public int Int32 + { + get => _int32; + set { _int32 = value == 0 ? 42 : value; } + } + + public Nested NestedStruct; + public Nested[] NestedStructs; + + public struct Nested + { + private string _string; + private int _int32; + + public string String + { + get => _string; + // Setter should not be called for missing values. + set { _string = string.IsNullOrEmpty(value) ? "Hello2" : value; } + } + + public int Int32 + { + get => _int32; + set { _int32 = value == 0 ? 43 : value; } + } + } + } + public class BaseClassWithVirtualProperty { private string? PrivateProperty { get; set; } @@ -667,12 +739,6 @@ public struct StructWithParameterlessAndParameterizedCtor public int MyInt { get; } } - public interface IGeolocation - { - public double Latitude { get; set; } - public double Longitude { get; set; } - } - [TypeConverter(typeof(GeolocationTypeConverter))] public struct Geolocation : IGeolocation { @@ -704,12 +770,6 @@ public sealed class GeolocationClass : IGeolocation public double Longitude { get; set; } } - public sealed record GeolocationRecord : IGeolocation - { - public double Latitude { get; set; } - public double Longitude { get; set; } - } - public class GeolocationWrapper { public Geolocation Location { get; set; } @@ -742,5 +802,102 @@ public record OidcProviderOptions { public string? Authority { get; set; } } + + public class AClass + { + public EndPointCollection EndPoints { get; init; } = new EndPointCollection(); + + public bool Property { get; set; } = false; + } + + public sealed class EndPointCollection : Collection, IEnumerable + { + public EndPointCollection() { } + + public void Add(string hostAndPort) + { + EndPoint? endpoint; + + if (IPAddress.TryParse(hostAndPort, out IPAddress? address)) + { + endpoint = new IPEndPoint(address, 0); + } + else + { + endpoint = new DnsEndPoint(hostAndPort, 0); + } + + Add(endpoint); + } + } + + internal abstract class AbstractBase + { + public int Value { get; set; } + } + + internal sealed class Derived : AbstractBase { } + + internal sealed class DerivedWithAnotherProp : AbstractBase + { + public int Value2 { get; set; } + } + + internal class ClassWithAbstractProp + { + public AbstractBase AbstractProp { get; set; } + } + + internal class ClassWithAbstractCtorParam + { + public AbstractBase AbstractProp { get; } + + public ClassWithAbstractCtorParam(AbstractBase abstractProp) => AbstractProp = abstractProp; + } + + internal class ClassWithOptionalAbstractCtorParam + { + public AbstractBase AbstractProp { get; } + + public ClassWithOptionalAbstractCtorParam(AbstractBase? abstractProp = null) => AbstractProp = abstractProp; + } + + internal class ClassWith_DirectlyAssignable_CtorParams + { + public IConfigurationSection MySection { get; } + public object MyObject { get; } + public string MyString { get; } + + public ClassWith_DirectlyAssignable_CtorParams(IConfigurationSection mySection, object myObject, string myString) => + (MySection, MyObject, MyString) = (mySection, myObject, myString); + } + + public class SharedChildInstance_Class + { + public string? ConnectionString { get; set; } + } + + public class ClassThatThrowsOnSetters + { + private int _myIntProperty; + + public ClassThatThrowsOnSetters() + { + _myIntProperty = 42; + } + + public int MyIntProperty + { + get => _myIntProperty; + set => throw new InvalidOperationException("Not expected"); + } + } + + public class SimplePoco + { + public string A { get; set; } + public string B { get; set; } + } + } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs index 56dddc6f8bc83b..296cb790c22ba5 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs @@ -8,10 +8,10 @@ using System.Globalization; using System.Linq; using System.Reflection; -using System.Text; #if BUILDING_SOURCE_GENERATOR_TESTS using Microsoft.Extensions.Configuration; #endif +using Microsoft.Extensions.Configuration.Memory; using Microsoft.Extensions.Configuration.Test; using Xunit; @@ -31,7 +31,7 @@ public ConfigurationBinderTestsBase() } } - public sealed partial class ConfigurationBinderTests : ConfigurationBinderTestsBase + public partial class ConfigurationBinderTests : ConfigurationBinderTestsBase { [Fact] public void BindWithNestedTypesWithReadOnlyProperties() @@ -1581,6 +1581,53 @@ public void CanBindNestedStructProperties() Assert.True(bound.ReadWriteNestedStruct.DeeplyNested.Boolean); } + [Fact] + public void CanBindNestedStructProperties_SetterCalledWithMissingConfigEntry() + { + ConfigurationBuilder configurationBuilder = new(); + configurationBuilder.AddInMemoryCollection(new Dictionary + { + { "dmy", "dmy" }, + }); + + IConfiguration config = configurationBuilder.Build(); + + var bound = config.Get(); + Assert.Null(bound.String); + Assert.Null(bound.NestedStruct.String); + Assert.Equal(42, bound.Int32); + Assert.Equal(0, bound.NestedStruct.Int32); + } + + [Fact] + public void CanBindNestedStructProperties_SetterNotCalledWithMissingConfigSection() + { + ConfigurationBuilder configurationBuilder = new(); + configurationBuilder.AddInMemoryCollection(new Dictionary + { + // An empty value will not trigger defaulting. + }); + + IConfiguration config = configurationBuilder.Build(); + + var bound = config.Get(); + Assert.Null(bound.String); + Assert.Null(bound.NestedStruct.String); + Assert.Equal(0, bound.Int32); + Assert.Equal(0, bound.NestedStruct.Int32); + } + + [Fact] + public void CanBindNestedStructProperties_SetterCalledWithMissingConfig_Array() + { + var config = TestHelpers.GetConfigurationFromJsonString( + """{"value": [{ }]}"""); + + var bound = config.GetSection("value").Get(); + Assert.Null(bound[0].String); + Assert.Equal(0, bound[0].Int32); + } + [Fact] public void IgnoresReadOnlyNestedStructProperties() { @@ -1713,13 +1760,29 @@ public void EnsureCallingThePropertySetter() Assert.Equal(2, options.ParsedBlacklist.Count); // should be initialized when calling the options.Blacklist setter. Assert.Equal(401, options.HttpStatusCode); // exists in configuration and properly sets the property -#if BUILDING_SOURCE_GENERATOR_TESTS - // Setter not called if there's no matching configuration value. - Assert.Equal(0, options.OtherCode); -#else - // doesn't exist in configuration. the setter sets default value '2' + + // This doesn't exist in configuration but the setter should be called which defaults the to '2' from input of '0'. Assert.Equal(2, options.OtherCode); -#endif + + // These don't exist in configuration and setters are not called since they are nullable. + Assert.Equal(0, options.OtherCodeNullable); + Assert.Equal("default", options.OtherCodeString); + Assert.Null(options.OtherCodeNull); + Assert.Null(options.OtherCodeUri); + } + + [Fact] + public void EnsureNotCallingSettersWhenGivenExistingInstanceNotInConfig() + { + var builder = new ConfigurationBuilder(); + builder.AddInMemoryCollection(new KeyValuePair[] { }); + var config = builder.Build(); + + ClassThatThrowsOnSetters instance = new(); + + // The setter for MyIntProperty throws, so this verifies that the setter is not called. + config.GetSection("Dmy").Bind(instance); + Assert.Equal(42, instance.MyIntProperty); } [Fact] @@ -1941,12 +2004,17 @@ public void BindRootStructIsNoOp() } """); - StructWithNestedStructs.DeeplyNested obj = new(); #pragma warning disable SYSLIB1103 + StructWithNestedStructs.DeeplyNested obj = new(); configuration.Bind(obj); -#pragma warning restore SYSLIB1103 Assert.Equal(0, obj.Int32); Assert.False(obj.Boolean); + + StructWithNestedStructs.DeeplyNested? nullableObj = new(); + configuration.Bind(nullableObj); + Assert.Equal(0, obj.Int32); + Assert.False(obj.Boolean); +#pragma warning restore SYSLIB1103 } [Fact] @@ -2024,7 +2092,7 @@ public void ComplexObj_As_Dictionary_Element() [Fact] public void ComplexObj_As_Enumerable_Element() { - var configuration = TestHelpers.GetConfigurationFromJsonString("""{ "Enumerable": [{ "Latitude": 3, "Longitude": 4 }] }""") + var configuration = TestHelpers.GetConfigurationFromJsonString("""{ "Enumerable": [{ "Latitude": 3, "Longitude": 4 }] }""") .GetSection("Enumerable"); Geolocation obj = configuration.Get>()[0]; @@ -2037,6 +2105,7 @@ public void ComplexObj_As_Enumerable_Element() ValidateGeolocation(obj); } +#if !BUILDING_SOURCE_GENERATOR_TESTS [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] public void TraceSwitchTest() { @@ -2056,6 +2125,7 @@ public void TraceSwitchTest() Assert.Equal("Info", ts.Value); #endif // NETCOREAPP } +#endif private void ValidateGeolocation(IGeolocation location) { @@ -2097,14 +2167,9 @@ public void CanBindToObjectMembers() TestBind(options => config.GetSection("Local").Bind(RemoteAuthenticationOptions.s_NonGenericField), obj => RemoteAuthenticationOptions.s_NonGenericField); // No null refs. -#if BUILDING_SOURCE_GENERATOR_TESTS - - Assert.Throws(() => config.GetSection("Local").Bind(new RemoteAuthenticationOptions().NullGenericProp)); - Assert.Throws(() => config.GetSection("Local").Bind(RemoteAuthenticationOptions.s_NullNonGenericField)); -#else config.GetSection("Local").Bind(new RemoteAuthenticationOptions().NullGenericProp); config.GetSection("Local").Bind(RemoteAuthenticationOptions.s_NullNonGenericField); -#endif + static void TestBind(Action> configure, Func, OidcProviderOptions> getBindedProp) { var obj = new RemoteAuthenticationOptions(); @@ -2120,5 +2185,277 @@ public void BinderSupportsObjCreationInput() // No diagnostic warning SYSLIB1104. configuration.Bind(new GraphWithUnsupportedMember()); } + + [Fact] + public void TestNullHandling_Get() + { + // Null configuration. + IConfiguration? configuration = null; + + Assert.Throws(() => configuration.Get()); + Assert.Throws(() => configuration.Get(_ => { })); + Assert.Throws(() => configuration.Get()); + Assert.Throws(() => configuration.Get(_ => { })); + + // Null Type. + configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}"); +#pragma warning disable SYSLIB1104 // The target type for a binder call could not be determined + Assert.Throws(() => configuration.Get(type: null)); + Assert.Throws(() => configuration.Get(type: null, _ => { })); +#pragma warning restore SYSLIB1104 // The target type for a binder call could not be determined + } + + [Fact] + public void TestNullHandling_GetValue() + { + string key = "Longitude"; + + // Null configuration. + Test(configuration: null, key); + + // Null type. + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}"); +#pragma warning disable SYSLIB1104 // The target type for a binder call could not be determined + Assert.Throws(() => configuration.GetValue(type: null, key)); + Assert.Throws(() => configuration.GetValue(type: null, key, defaultValue: null)); +#pragma warning restore SYSLIB1104 // The target type for a binder call could not be determined + + // Null key. + Test(configuration: configuration, key: null); + + void Test(IConfiguration? configuration, string? key) + { + Assert.Throws(() => configuration.GetValue(key)); + Assert.Throws(() => configuration.GetValue(key, defaultValue: null)); + Assert.Throws(() => configuration.GetValue(key)); + Assert.Throws(() => configuration.GetValue(key, defaultValue: default)); + TestUntypedOverloads(configuration: null, key); + } + + void TestUntypedOverloads(IConfiguration? configuration, string? key) + { + Assert.Throws(() => configuration.GetValue(typeof(GeolocationClass), key)); + Assert.Throws(() => configuration.GetValue(typeof(GeolocationClass), key, defaultValue: null)); + Assert.Throws(() => configuration.GetValue(typeof(GeolocationClass), key, new GeolocationClass())); + Assert.Throws(() => configuration.GetValue(typeof(Geolocation), key)); + Assert.Throws(() => configuration.GetValue(typeof(Geolocation), key, defaultValue: null)); + Assert.Throws(() => configuration.GetValue(typeof(Geolocation), key, default(Geolocation))); + } + } + + [Fact] + public void TestNullHandling_Bind() + { + // Null configuration. + IConfiguration? configuration = null; + GeolocationClass? location = new(); + Assert.Throws(() => configuration.Bind(location)); + Assert.Throws(() => configuration.Bind(location, _ => { })); + Assert.Throws(() => configuration.Bind("", location)); + + // Null object. + configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}"); + location = null; + // Expect no exceptions. + configuration.Bind(location); + configuration.Bind(location, _ => { }); + configuration.Bind("", location); + } + + [Fact] + public void TestAbstractTypeAsNestedMemberForBinding() + { + // Regression test for https://github.com/dotnet/runtime/issues/91324. + + IConfiguration configuration = new ConfigurationBuilder().AddInMemoryCollection( + new KeyValuePair[] + { + new KeyValuePair("ConfigBindRepro:EndPoints:0", "localhost"), + new KeyValuePair("ConfigBindRepro:Property", "true") + }) + .Build(); + + AClass settings = new(); + configuration.GetSection("ConfigBindRepro").Bind(settings); + + Assert.Empty(settings.EndPoints); // Need custom binding feature to map "localhost" string into Endpoint instance. + Assert.True(settings.Property); + } + + [Fact] + public static void TestGettingAbstractType() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Value"":1}"); + Assert.Throws(() => configuration.Get()); + } + + [Fact] + public static void TestBindingAbstractInstance() + { + // Regression tests for https://github.com/dotnet/runtime/issues/90974. + // We only bind members on the declared binding type, i.e. AbstractBase, even + // though the actual instances are derived types that may have their own properties. + + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Value"":1,""Value2"":2}"); + + AbstractBase d = new Derived(); + configuration.Bind(d); + Assert.Equal(1, d.Value); + + d = new DerivedWithAnotherProp(); + configuration.Bind(d); + Assert.Equal(1, d.Value); + +#if BUILDING_SOURCE_GENERATOR_TESTS + // Divergence from reflection impl: reflection binds using instance type, + // while src-gen can only use declared type (everything has to be known AOT). + // This could change if we add an explicit API to indicate the expected runtime type(s). + Assert.Equal(0, ((DerivedWithAnotherProp)d).Value2); +#else + Assert.Equal(2, ((DerivedWithAnotherProp)d).Value2); +#endif + } + + [Fact] + public static void TestBindingAbstractMember_AsCtorParam() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{ ""AbstractProp"": {""Value"":1} }"); + Assert.Throws(configuration.Get); + Assert.Throws(configuration.Get); + } + + [Fact] + public static void TestBindingInitializedAbstractMember() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{ ""AbstractProp"": {""Value"":1} }"); + ClassWithAbstractProp c = new(); + c.AbstractProp = new Derived(); + configuration.Bind(c); + Assert.Equal(1, c.AbstractProp.Value); + } + + [Fact] + public static void TestBindingUninitializedAbstractMember() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{ ""AbstractProp"": {""Value"":1} }"); + ClassWithAbstractProp c = new(); + c.AbstractProp = null; + Assert.Throws(() => configuration.Bind(c)); + } + + [Fact] + public void GetIConfigurationSection() + { + var configuration = TestHelpers.GetConfigurationFromJsonString(""" + { + "vaLue": "MyString", + } + """); + + var obj = configuration.GetSection("value").Get(); + Assert.Equal("MyString", obj.Value); + + configuration = TestHelpers.GetConfigurationFromJsonString(""" + { + "vaLue": [ "MyString", { "nested": "value" } ], + } + """); + + var list = configuration.GetSection("value").Get>(); + ValidateList(list); + + var dict = configuration.Get>>(); + Assert.Equal(1, dict.Count); + ValidateList(dict["vaLue"]); + + static void ValidateList(List list) + { + Assert.Equal(2, list.Count); + Assert.Equal("0", list[0].Key); + Assert.Equal("MyString", list[0].Value); + + Assert.Equal("1", list[1].Key); + var nestedSection = Assert.IsAssignableFrom(list[1].GetSection("nested")); + Assert.Equal("value", nestedSection.Value); + } + } + + [Fact] + public void NullableDictKeys() + { + var configuration = TestHelpers.GetConfigurationFromJsonString("""{ "1": "MyString" }"""); + var dict = configuration.Get>(); + Assert.Empty(dict); + } + + [Fact] + public void IConfigurationSectionAsCtorParam() + { + var configuration = TestHelpers.GetConfigurationFromJsonString(""" + { + "MySection": "MySection", + "MyObject": "MyObject", + "MyString": "MyString", + } + """); + + var obj = configuration.Get(); + Assert.Equal("MySection", obj.MySection.Value); + Assert.Equal("MyObject", obj.MyObject); + Assert.Equal("MyString", obj.MyString); + } + + [Fact] + public void SharedChildInstance() + { + var builder = new ConfigurationBuilder(); + builder.AddInMemoryCollection(new KeyValuePair[] + { + new("A:B:ConnectionString", "localhost"), + }); + + var config = builder.Build(); + + SharedChildInstance_Class instance = new(); + config.GetSection("A:B").Bind(instance); + Assert.Equal("localhost", instance.ConnectionString); + + // Binding to a new section should not set the value to null. + config.GetSection("A").Bind(instance); + Assert.Equal("localhost", instance.ConnectionString); + } + + [Fact] + public void CanBindToMockConfigurationSection() + { + const string expectedA = "hello"; + + var configSource = new MemoryConfigurationSource() + { + InitialData = new Dictionary() + { + [$":{nameof(SimplePoco.A)}"] = expectedA, + } + }; + var configRoot = new MockConfigurationRoot(new[] { configSource.Build(null) }); + var configSection = new ConfigurationSection(configRoot, string.Empty); + + SimplePoco result = new(); + configSection.Bind(result); + + Assert.Equal(expectedA, result.A); + Assert.Equal(default(string), result.B); + } + + // a mock configuration root that will return null for undefined Sections, + // as is common when Configuration interfaces are mocked + class MockConfigurationRoot : ConfigurationRoot, IConfigurationRoot + { + public MockConfigurationRoot(IList providers) : base(providers) + { } + + IConfigurationSection IConfiguration.GetSection(string key) => + this[key] is null ? null : new ConfigurationSection(this, key); + } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt index 2149dcaaa07141..ea4fba79cbc465 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the configuration instance to a new instance of type T. - public static T? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -18,12 +25,19 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.Collections.Generic; using System.Globalization; using System.Linq; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClassWithCustomCollections = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "CustomDictionary", "CustomList", "IReadOnlyList", "IReadOnlyDictionary" }); + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 12, 17)] + public static T? Get(this IConfiguration configuration) => (T?)(GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); + #endregion IConfiguration extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClassWithCustomCollections = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "CustomDictionary", "CustomList", "ICustomDictionary", "ICustomCollection", "IReadOnlyList", "UnsupportedIReadOnlyDictionaryUnsupported", "IReadOnlyDictionary" }); public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) { @@ -41,86 +55,39 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (type == typeof(Program.MyClassWithCustomCollections)) { - var obj = new Program.MyClassWithCustomCollections(); - BindCore(configuration, ref obj, binderOptions); - return obj; + var instance = new Program.MyClassWithCustomCollections(); + BindCore(configuration, ref instance, defaultValueIfNotFound: true, binderOptions); + return instance; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref Program.CustomDictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.CustomDictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = ParseInt(value, () => section.Path); + instance[section.Key] = ParseInt(value, () => section.Path); } } } - public static void BindCore(IConfiguration configuration, ref Program.CustomList obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.CustomList instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(value); + instance.Add(value); } } } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref IReadOnlyList instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - obj.Add(ParseInt(value, () => section.Path)); - } - } - } - - public static void BindCore(IConfiguration configuration, ref ICollection obj, BinderOptions? binderOptions) - { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - obj.Add(ParseInt(value, () => section.Path)); - } - } - } - - public static void BindCore(IConfiguration configuration, ref IReadOnlyList obj, BinderOptions? binderOptions) - { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - - if (obj is not ICollection temp) + if (instance is not ICollection temp) { return; } @@ -134,46 +101,9 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) - { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - obj[section.Key] = ParseInt(value, () => section.Path); - } - } - } - - public static void BindCore(IConfiguration configuration, ref IDictionary obj, BinderOptions? binderOptions) - { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - obj[section.Key] = ParseInt(value, () => section.Path); - } - } - } - - public static void BindCore(IConfiguration configuration, ref IReadOnlyDictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref IReadOnlyDictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - - if (obj is not IDictionary temp) + if (instance is not IDictionary temp) { return; } @@ -187,48 +117,44 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration } } - public static void BindCore(IConfiguration configuration, ref Program.MyClassWithCustomCollections obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClassWithCustomCollections instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClassWithCustomCollections), s_configKeys_ProgramMyClassWithCustomCollections, configuration, binderOptions); if (AsConfigWithChildren(configuration.GetSection("CustomDictionary")) is IConfigurationSection section1) { - Program.CustomDictionary? temp3 = obj.CustomDictionary; + Program.CustomDictionary? temp3 = instance.CustomDictionary; temp3 ??= new Program.CustomDictionary(); - BindCore(section1, ref temp3, binderOptions); - obj.CustomDictionary = temp3; + BindCore(section1, ref temp3, defaultValueIfNotFound: false, binderOptions); + instance.CustomDictionary = temp3; } if (AsConfigWithChildren(configuration.GetSection("CustomList")) is IConfigurationSection section4) { - Program.CustomList? temp6 = obj.CustomList; + Program.CustomList? temp6 = instance.CustomList; temp6 ??= new Program.CustomList(); - BindCore(section4, ref temp6, binderOptions); - obj.CustomList = temp6; + BindCore(section4, ref temp6, defaultValueIfNotFound: false, binderOptions); + instance.CustomList = temp6; } if (AsConfigWithChildren(configuration.GetSection("IReadOnlyList")) is IConfigurationSection section7) { - IReadOnlyList? temp9 = obj.IReadOnlyList; - temp9 = temp9 is null ? new List() : new List(temp9); - BindCore(section7, ref temp9, binderOptions); - obj.IReadOnlyList = temp9; + IReadOnlyList? temp9 = instance.IReadOnlyList; + temp9 = temp9 is null ? (IReadOnlyList)new List() : (IReadOnlyList)new List(temp9); + BindCore(section7, ref temp9, defaultValueIfNotFound: false, binderOptions); + instance.IReadOnlyList = temp9; } if (AsConfigWithChildren(configuration.GetSection("IReadOnlyDictionary")) is IConfigurationSection section10) { - IReadOnlyDictionary? temp12 = obj.IReadOnlyDictionary; - temp12 = temp12 is null ? new Dictionary() : temp12.ToDictionary(pair => pair.Key, pair => pair.Value); - BindCore(section10, ref temp12, binderOptions); - obj.IReadOnlyDictionary = temp12; + IReadOnlyDictionary? temp12 = instance.IReadOnlyDictionary; + temp12 = temp12 is null ? (IReadOnlyDictionary)new Dictionary() : (IReadOnlyDictionary)temp12.ToDictionary(pair => pair.Key, pair => pair.Value); + BindCore(section10, ref temp12, defaultValueIfNotFound: false, binderOptions); + instance.IReadOnlyDictionary = temp12; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -281,7 +207,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -298,5 +224,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind.generated.txt index 406e8db6716777..fc0dda4b5b3ae9 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind.generated.txt @@ -2,18 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. - public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::Program.MyClass obj) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCore(configuration, ref obj, binderOptions: null); - - /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. - public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::Program.MyClass obj, global::System.Action? configureOptions) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCore(configuration, ref obj, global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetBinderOptions(configureOptions)); + using System; + using System.CodeDom.Compiler; - /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. - public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, string key, global::Program.MyClass obj) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCore(configuration.GetSection(key), ref obj, binderOptions: null); + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -23,103 +24,148 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); + #region IConfiguration extensions. + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 12, 14)] + public static void Bind_ProgramMyClass(this IConfiguration configuration, object? instance) + { + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + if (instance is null) + { + return; + } + + var typedObj = (Program.MyClass)instance; + BindCore(configuration, ref typedObj, defaultValueIfNotFound: false, binderOptions: null); + } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 13, 20)] + public static void Bind_ProgramMyClass(this IConfiguration configuration, object? instance, Action? configureOptions) { - if (obj is null) + if (configuration is null) { - throw new ArgumentNullException(nameof(obj)); + throw new ArgumentNullException(nameof(configuration)); } - foreach (IConfigurationSection section in configuration.GetChildren()) + if (instance is null) { - if (section.Value is string value) - { - obj.Add(ParseInt(value, () => section.Path)); - } + return; } + + var typedObj = (Program.MyClass)instance; + BindCore(configuration, ref typedObj, defaultValueIfNotFound: false, GetBinderOptions(configureOptions)); } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 14, 20)] + public static void Bind_ProgramMyClass(this IConfiguration configuration, string key, object? instance) { - if (obj is null) + if (configuration is null) { - throw new ArgumentNullException(nameof(obj)); + throw new ArgumentNullException(nameof(configuration)); } + if (instance is null) + { + return; + } + + var typedObj = (Program.MyClass)instance; + BindCore(configuration.GetSection(key), ref typedObj, defaultValueIfNotFound: false, binderOptions: null); + } + #endregion IConfiguration extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); + + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) + foreach (IConfigurationSection section in configuration.GetChildren()) { - throw new ArgumentNullException(nameof(obj)); + if (section.Value is string value) + { + instance[section.Key] = value; + } } + } + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { - if (!(obj.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) + if (!(instance.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) { element = new Program.MyClass2(); } - obj[section.Key] = element; + instance[section.Key] = element; } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value0) + { + instance.MyString = value0; + } if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section2) { - List? temp4 = obj.MyList; + List? temp4 = instance.MyList; temp4 ??= new List(); - BindCore(section2, ref temp4, binderOptions); - obj.MyList = temp4; + BindCore(section2, ref temp4, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp4; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section5) { - Dictionary? temp7 = obj.MyDictionary; + Dictionary? temp7 = instance.MyDictionary; temp7 ??= new Dictionary(); - BindCore(section5, ref temp7, binderOptions); - obj.MyDictionary = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyComplexDictionary")) is IConfigurationSection section8) { - Dictionary? temp10 = obj.MyComplexDictionary; + Dictionary? temp10 = instance.MyComplexDictionary; temp10 ??= new Dictionary(); - BindCore(section8, ref temp10, binderOptions); - obj.MyComplexDictionary = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyComplexDictionary = temp10; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -163,7 +209,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -180,5 +226,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance.generated.txt index 106e01795369e2..9cbc0a22c5c84d 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. - public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::Program.MyClass obj) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCore(configuration, ref obj, binderOptions: null); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,103 +24,112 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); - - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + #region IConfiguration extensions. + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 12, 20)] + public static void Bind_ProgramMyClass(this IConfiguration configuration, object? instance) { - if (obj is null) + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + if (instance is null) { - throw new ArgumentNullException(nameof(obj)); + return; } + var typedObj = (Program.MyClass)instance; + BindCore(configuration, ref typedObj, defaultValueIfNotFound: false, binderOptions: null); + } + #endregion IConfiguration extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); + + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { - if (!(obj.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) + if (!(instance.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) { element = new Program.MyClass2(); } - obj[section.Key] = element; + instance[section.Key] = element; } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value0) + { + instance.MyString = value0; + } if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section2) { - List? temp4 = obj.MyList; + List? temp4 = instance.MyList; temp4 ??= new List(); - BindCore(section2, ref temp4, binderOptions); - obj.MyList = temp4; + BindCore(section2, ref temp4, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp4; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section5) { - Dictionary? temp7 = obj.MyDictionary; + Dictionary? temp7 = instance.MyDictionary; temp7 ??= new Dictionary(); - BindCore(section5, ref temp7, binderOptions); - obj.MyDictionary = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyComplexDictionary")) is IConfigurationSection section8) { - Dictionary? temp10 = obj.MyComplexDictionary; + Dictionary? temp10 = instance.MyComplexDictionary; temp10 ??= new Dictionary(); - BindCore(section8, ref temp10, binderOptions); - obj.MyComplexDictionary = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyComplexDictionary = temp10; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -156,5 +172,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance_BinderOptions.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance_BinderOptions.generated.txt index a1cb7d6b93b5d0..f3efa07fc0c5c2 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance_BinderOptions.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Instance_BinderOptions.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. - public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::Program.MyClass obj, global::System.Action? configureOptions) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCore(configuration, ref obj, global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetBinderOptions(configureOptions)); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,103 +24,112 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); - - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + #region IConfiguration extensions. + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 12, 20)] + public static void Bind_ProgramMyClass(this IConfiguration configuration, object? instance, Action? configureOptions) { - if (obj is null) + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + if (instance is null) { - throw new ArgumentNullException(nameof(obj)); + return; } + var typedObj = (Program.MyClass)instance; + BindCore(configuration, ref typedObj, defaultValueIfNotFound: false, GetBinderOptions(configureOptions)); + } + #endregion IConfiguration extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); + + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { - if (!(obj.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) + if (!(instance.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) { element = new Program.MyClass2(); } - obj[section.Key] = element; + instance[section.Key] = element; } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value0) + { + instance.MyString = value0; + } if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section2) { - List? temp4 = obj.MyList; + List? temp4 = instance.MyList; temp4 ??= new List(); - BindCore(section2, ref temp4, binderOptions); - obj.MyList = temp4; + BindCore(section2, ref temp4, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp4; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section5) { - Dictionary? temp7 = obj.MyDictionary; + Dictionary? temp7 = instance.MyDictionary; temp7 ??= new Dictionary(); - BindCore(section5, ref temp7, binderOptions); - obj.MyDictionary = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyComplexDictionary")) is IConfigurationSection section8) { - Dictionary? temp10 = obj.MyComplexDictionary; + Dictionary? temp10 = instance.MyComplexDictionary; temp10 ??= new Dictionary(); - BindCore(section8, ref temp10, binderOptions); - obj.MyComplexDictionary = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyComplexDictionary = temp10; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -157,7 +173,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -174,5 +190,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Key_Instance.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Key_Instance.generated.txt index f3ee8a9ff43840..89b82d31bc19ae 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Key_Instance.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_Key_Instance.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. - public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, string key, global::Program.MyClass obj) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCore(configuration.GetSection(key), ref obj, binderOptions: null); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,103 +24,112 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); - - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + #region IConfiguration extensions. + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 12, 20)] + public static void Bind_ProgramMyClass(this IConfiguration configuration, string key, object? instance) { - if (obj is null) + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + if (instance is null) { - throw new ArgumentNullException(nameof(obj)); + return; } + var typedObj = (Program.MyClass)instance; + BindCore(configuration.GetSection(key), ref typedObj, defaultValueIfNotFound: false, binderOptions: null); + } + #endregion IConfiguration extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyDictionary", "MyComplexDictionary" }); + + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { - if (!(obj.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) + if (!(instance.TryGetValue(section.Key, out Program.MyClass2? element) && element is not null)) { element = new Program.MyClass2(); } - obj[section.Key] = element; + instance[section.Key] = element; } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value0) + { + instance.MyString = value0; + } if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section2) { - List? temp4 = obj.MyList; + List? temp4 = instance.MyList; temp4 ??= new List(); - BindCore(section2, ref temp4, binderOptions); - obj.MyList = temp4; + BindCore(section2, ref temp4, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp4; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section5) { - Dictionary? temp7 = obj.MyDictionary; + Dictionary? temp7 = instance.MyDictionary; temp7 ??= new Dictionary(); - BindCore(section5, ref temp7, binderOptions); - obj.MyDictionary = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyComplexDictionary")) is IConfigurationSection section8) { - Dictionary? temp10 = obj.MyComplexDictionary; + Dictionary? temp10 = instance.MyComplexDictionary; temp10 ??= new Dictionary(); - BindCore(section8, ref temp10, binderOptions); - obj.MyComplexDictionary = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyComplexDictionary = temp10; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -156,5 +172,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_ParseTypeFromMethodParam.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_ParseTypeFromMethodParam.generated.txt new file mode 100644 index 00000000000000..14753a4a1f8e4b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Bind_ParseTypeFromMethodParam.generated.txt @@ -0,0 +1,72 @@ +// +#nullable enable +#pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. + +namespace System.Runtime.CompilerServices +{ + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } +} + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + using Microsoft.Extensions.Configuration; + using System; + using System.CodeDom.Compiler; + using System.Collections.Generic; + using System.Globalization; + using System.Runtime.CompilerServices; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + file static class BindingExtensions + { + #region IConfiguration extensions. + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 18, 16)] + public static void Bind_ProgramMyClass0(this IConfiguration configuration, object? instance) + { + } + + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 23, 16)] + public static void Bind_ProgramMyClass1(this IConfiguration configuration, object? instance, Action? configureOptions) + { + } + + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 28, 16)] + public static void Bind_ProgramMyClass2(this IConfiguration configuration, string key, object? instance) + { + } + #endregion IConfiguration extensions. + + #region Core binding extensions. + public static BinderOptions? GetBinderOptions(Action? configureOptions) + { + if (configureOptions is null) + { + return null; + } + + BinderOptions binderOptions = new(); + configureOptions(binderOptions); + + if (binderOptions.BindNonPublicProperties) + { + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + } + + return binderOptions; + } + #endregion Core binding extensions. + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt index fb71c70b4dd3bd..b6fb659d544d42 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt @@ -2,21 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the configuration instance to a new instance of type T. - public static T? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); - - /// Attempts to bind the configuration instance to a new instance of type T. - public static T? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, typeof(T), configureOptions) ?? default(T)); - - /// Attempts to bind the configuration instance to a new instance of type T. - public static object? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, type, configureOptions: null); + using System; + using System.CodeDom.Compiler; - /// Attempts to bind the configuration instance to a new instance of type T. - public static object? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type, global::System.Action? configureOptions) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, type, configureOptions); + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -26,11 +24,30 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 12, 38)] + public static T? Get(this IConfiguration configuration) => (T?)(GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 14, 36)] + public static T? Get(this IConfiguration configuration, Action? configureOptions) => (T?)(GetCore(configuration, typeof(T), configureOptions) ?? default(T)); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 13, 56)] + public static object? Get(this IConfiguration configuration, Type type) => GetCore(configuration, type, configureOptions: null); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 15, 47)] + public static object? Get(this IConfiguration configuration, Type type, Action? configureOptions) => GetCore(configuration, type, configureOptions); + #endregion IConfiguration extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyArray", "MyDictionary" }); private readonly static Lazy> s_configKeys_ProgramMyClass2 = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyInt" }); @@ -50,123 +67,117 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (type == typeof(Program.MyClass)) { - var obj = new Program.MyClass(); - BindCore(configuration, ref obj, binderOptions); - return obj; + var instance = new Program.MyClass(); + BindCore(configuration, ref instance, defaultValueIfNotFound: true, binderOptions); + return instance; } - - if (type == typeof(Program.MyClass2)) + else if (type == typeof(Program.MyClass2)) { - var obj = new Program.MyClass2(); - BindCore(configuration, ref obj, binderOptions); - return obj; + var instance = new Program.MyClass2(); + BindCore(configuration, ref instance, defaultValueIfNotFound: true, binderOptions); + return instance; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref int[] obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref int[] instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) + var temp2 = new List(); + + foreach (IConfigurationSection section in configuration.GetChildren()) { - throw new ArgumentNullException(nameof(obj)); + if (section.Value is string value) + { + temp2.Add(ParseInt(value, () => section.Path)); + } } - var temp2 = new List(); - BindCore(configuration, ref temp2, binderOptions); - int originalCount = obj.Length; - Array.Resize(ref obj, originalCount + temp2.Count); - temp2.CopyTo(obj, originalCount); + int originalCount = instance.Length; + Array.Resize(ref instance, originalCount + temp2.Count); + temp2.CopyTo(instance, originalCount); } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; - - if (configuration["MyInt"] is string value5) + if (configuration["MyString"] is string value3) { - obj.MyInt = ParseInt(value5, () => configuration.GetSection("MyInt").Path); + instance.MyString = value3; } - if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section6) + if (configuration["MyInt"] is string value4) + { + instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) { - List? temp8 = obj.MyList; - temp8 ??= new List(); - BindCore(section6, ref temp8, binderOptions); - obj.MyList = temp8; + instance.MyInt = default; } - if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section9) + if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) { - int[]? temp11 = obj.MyArray; - temp11 ??= new int[0]; - BindCore(section9, ref temp11, binderOptions); - obj.MyArray = temp11; + List? temp7 = instance.MyList; + temp7 ??= new List(); + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp7; } - if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section12) + if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section8) { - Dictionary? temp14 = obj.MyDictionary; - temp14 ??= new Dictionary(); - BindCore(section12, ref temp14, binderOptions); - obj.MyDictionary = temp14; + int[]? temp10 = instance.MyArray; + temp10 ??= new int[0]; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyArray = temp10; } - } - public static void BindCore(IConfiguration configuration, ref Program.MyClass2 obj, BinderOptions? binderOptions) - { - if (obj is null) + if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) { - throw new ArgumentNullException(nameof(obj)); + Dictionary? temp13 = instance.MyDictionary; + temp13 ??= new Dictionary(); + BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp13; } + } + public static void BindCore(IConfiguration configuration, ref Program.MyClass2 instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); - if (configuration["MyInt"] is string value15) + if (configuration["MyInt"] is string value14) { - obj.MyInt = ParseInt(value15, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value14, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -219,7 +230,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -236,5 +247,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue.generated.txt index c9d49faa937244..bf7e64bd31f90c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue.generated.txt @@ -2,21 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Extracts the value with the specified key and converts it to the specified type. - public static T? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, string key) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, typeof(T), key) ?? default(T)); - - /// Extracts the value with the specified key and converts it to the specified type. - public static T? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, string key, T defaultValue) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, typeof(T), key) ?? defaultValue); - - /// Extracts the value with the specified key and converts it to the specified type. - public static object? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type, string key) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, type, key); + using System; + using System.CodeDom.Compiler; - /// Extracts the value with the specified key and converts it to the specified type. - public static object? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type, string key, object? defaultValue) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, type, key) ?? defaultValue; + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -25,11 +23,30 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System; using System.CodeDom.Compiler; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 13, 18)] + public static T? GetValue(this IConfiguration configuration, string key) => (T?)(BindingExtensions.GetValueCore(configuration, typeof(T), key) ?? default(T)); + + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 16, 24)] + public static T? GetValue(this IConfiguration configuration, string key, T defaultValue) => (T?)(BindingExtensions.GetValueCore(configuration, typeof(T), key) ?? defaultValue); + + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 14, 24)] + public static object? GetValue(this IConfiguration configuration, Type type, string key) => BindingExtensions.GetValueCore(configuration, type, key); + + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 17, 24)] + public static object? GetValue(this IConfiguration configuration, Type type, string key, object? defaultValue) => BindingExtensions.GetValueCore(configuration, type, key) ?? defaultValue; + #endregion IConfiguration extensions. + + #region Core binding extensions. public static object? GetValueCore(this IConfiguration configuration, Type type, string key) { if (configuration is null) @@ -48,18 +65,15 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { return ParseInt(value, () => section.Path); } - - if (type == typeof(bool?)) + else if (type == typeof(bool?)) { return ParseBool(value, () => section.Path); } - - if (type == typeof(byte[])) + else if (type == typeof(byte[])) { return ParseByteArray(value, () => section.Path); } - - if (type == typeof(CultureInfo)) + else if (type == typeof(CultureInfo)) { return ParseCultureInfo(value, () => section.Path); } @@ -114,5 +128,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(CultureInfo)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key.generated.txt index 17c963bd980a70..b86915b78303b2 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Extracts the value with the specified key and converts it to the specified type. - public static T? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, string key) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, typeof(T), key) ?? default(T)); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -16,11 +23,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System; using System.CodeDom.Compiler; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 10, 20)] + public static T? GetValue(this IConfiguration configuration, string key) => (T?)(BindingExtensions.GetValueCore(configuration, typeof(T), key) ?? default(T)); + #endregion IConfiguration extensions. + + #region Core binding extensions. public static object? GetValueCore(this IConfiguration configuration, Type type, string key) { if (configuration is null) @@ -54,5 +68,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key_DefaultValue.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key_DefaultValue.generated.txt index 1148109b9f5a81..697f710dff3027 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key_DefaultValue.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_T_Key_DefaultValue.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Extracts the value with the specified key and converts it to the specified type. - public static T? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, string key, T defaultValue) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, typeof(T), key) ?? defaultValue); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -16,11 +23,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System; using System.CodeDom.Compiler; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 12, 20)] + public static T? GetValue(this IConfiguration configuration, string key, T defaultValue) => (T?)(BindingExtensions.GetValueCore(configuration, typeof(T), key) ?? defaultValue); + #endregion IConfiguration extensions. + + #region Core binding extensions. public static object? GetValueCore(this IConfiguration configuration, Type type, string key) { if (configuration is null) @@ -54,5 +68,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key.generated.txt index c833b20f18dcfe..b5a22e71ccefbf 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Extracts the value with the specified key and converts it to the specified type. - public static object? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type, string key) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, type, key); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -16,11 +23,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System; using System.CodeDom.Compiler; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 10, 20)] + public static object? GetValue(this IConfiguration configuration, Type type, string key) => BindingExtensions.GetValueCore(configuration, type, key); + #endregion IConfiguration extensions. + + #region Core binding extensions. public static object? GetValueCore(this IConfiguration configuration, Type type, string key) { if (configuration is null) @@ -54,5 +68,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(bool)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key_DefaultValue.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key_DefaultValue.generated.txt index f773f79ce6c2c0..4a4564796e562f 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key_DefaultValue.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/GetValue_TypeOf_Key_DefaultValue.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Extracts the value with the specified key and converts it to the specified type. - public static object? GetValue(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type, string key, object? defaultValue) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetValueCore(configuration, type, key) ?? defaultValue; + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -16,11 +23,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System; using System.CodeDom.Compiler; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Extracts the value with the specified key and converts it to the specified type. + [InterceptsLocation(@"src-0.cs", 11, 20)] + public static object? GetValue(this IConfiguration configuration, Type type, string key, object? defaultValue) => BindingExtensions.GetValueCore(configuration, type, key) ?? defaultValue; + #endregion IConfiguration extensions. + + #region Core binding extensions. public static object? GetValueCore(this IConfiguration configuration, Type type, string key) { if (configuration is null) @@ -54,5 +68,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(CultureInfo)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_PrimitivesOnly.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_PrimitivesOnly.generated.txt new file mode 100644 index 00000000000000..b703fb5f1c864b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_PrimitivesOnly.generated.txt @@ -0,0 +1,182 @@ +// +#nullable enable +#pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. + +namespace System.Runtime.CompilerServices +{ + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } +} + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + using Microsoft.Extensions.Configuration; + using System; + using System.CodeDom.Compiler; + using System.Globalization; + using System.Runtime.CompilerServices; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + file static class BindingExtensions + { + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 10, 16)] + public static T? Get(this IConfiguration configuration) => (T?)(GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 12, 16)] + public static T? Get(this IConfiguration configuration, Action? configureOptions) => (T?)(GetCore(configuration, typeof(T), configureOptions) ?? default(T)); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 11, 16)] + public static object? Get(this IConfiguration configuration, Type type) => GetCore(configuration, type, configureOptions: null); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 13, 16)] + public static object? Get(this IConfiguration configuration, Type type, Action? configureOptions) => GetCore(configuration, type, configureOptions); + #endregion IConfiguration extensions. + + #region Core binding extensions. + public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) + { + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + + if (!HasValueOrChildren(configuration)) + { + return null; + } + + if (type == typeof(int)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + if (section.Value is string value) + { + return ParseInt(value, () => section.Path); + } + } + else if (type == typeof(string)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + return section.Value; + } + else if (type == typeof(float)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + if (section.Value is string value) + { + return ParseFloat(value, () => section.Path); + } + } + else if (type == typeof(double)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + if (section.Value is string value) + { + return ParseDouble(value, () => section.Path); + } + } + + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + } + + public static bool HasValueOrChildren(IConfiguration configuration) + { + if ((configuration as IConfigurationSection)?.Value is not null) + { + return true; + } + return AsConfigWithChildren(configuration) is not null; + } + + public static IConfiguration? AsConfigWithChildren(IConfiguration configuration) + { + foreach (IConfigurationSection _ in configuration.GetChildren()) + { + return configuration; + } + return null; + } + + public static BinderOptions? GetBinderOptions(Action? configureOptions) + { + if (configureOptions is null) + { + return null; + } + + BinderOptions binderOptions = new(); + configureOptions(binderOptions); + + if (binderOptions.BindNonPublicProperties) + { + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + } + + return binderOptions; + } + + public static int ParseInt(string value, Func getPath) + { + try + { + return int.Parse(value, NumberStyles.Integer, CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); + } + } + + public static float ParseFloat(string value, Func getPath) + { + try + { + return float.Parse(value, NumberStyles.Float, CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(float)}'.", exception); + } + } + + public static double ParseDouble(string value, Func getPath) + { + try + { + return double.Parse(value, NumberStyles.Float, CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(double)}'.", exception); + } + } + #endregion Core binding extensions. + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt index de8201fe6fed2c..c2e8f167bb4750 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the configuration instance to a new instance of type T. - public static T? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,11 +24,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 11, 40)] + public static T? Get(this IConfiguration configuration) => (T?)(GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); + #endregion IConfiguration extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyArray", "MyDictionary" }); public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) @@ -40,101 +54,97 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (type == typeof(Program.MyClass)) { - var obj = new Program.MyClass(); - BindCore(configuration, ref obj, binderOptions); - return obj; + var instance = new Program.MyClass(); + BindCore(configuration, ref instance, defaultValueIfNotFound: true, binderOptions); + return instance; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref int[] obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref int[] instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) + var temp1 = new List(); + + foreach (IConfigurationSection section in configuration.GetChildren()) { - throw new ArgumentNullException(nameof(obj)); + if (section.Value is string value) + { + temp1.Add(ParseInt(value, () => section.Path)); + } } - var temp1 = new List(); - BindCore(configuration, ref temp1, binderOptions); - int originalCount = obj.Length; - Array.Resize(ref obj, originalCount + temp1.Count); - temp1.CopyTo(obj, originalCount); + int originalCount = instance.Length; + Array.Resize(ref instance, originalCount + temp1.Count); + temp1.CopyTo(instance, originalCount); } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value2) + { + instance.MyString = value2; + } - if (configuration["MyInt"] is string value4) + if (configuration["MyInt"] is string value3) { - obj.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value3, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } - if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) + if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section4) { - List? temp7 = obj.MyList; - temp7 ??= new List(); - BindCore(section5, ref temp7, binderOptions); - obj.MyList = temp7; + List? temp6 = instance.MyList; + temp6 ??= new List(); + BindCore(section4, ref temp6, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp6; } - if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section8) + if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section7) { - int[]? temp10 = obj.MyArray; - temp10 ??= new int[0]; - BindCore(section8, ref temp10, binderOptions); - obj.MyArray = temp10; + int[]? temp9 = instance.MyArray; + temp9 ??= new int[0]; + BindCore(section7, ref temp9, defaultValueIfNotFound: false, binderOptions); + instance.MyArray = temp9; } - if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) + if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section10) { - Dictionary? temp13 = obj.MyDictionary; - temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, binderOptions); - obj.MyDictionary = temp13; + Dictionary? temp12 = instance.MyDictionary; + temp12 ??= new Dictionary(); + BindCore(section10, ref temp12, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp12; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -187,7 +197,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -204,5 +214,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt index 34fadacace146d..cd3f237917d4e3 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the configuration instance to a new instance of type T. - public static T? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) => (T?)(global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, typeof(T), configureOptions) ?? default(T)); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,11 +24,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 11, 40)] + public static T? Get(this IConfiguration configuration, Action? configureOptions) => (T?)(GetCore(configuration, typeof(T), configureOptions) ?? default(T)); + #endregion IConfiguration extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyArray", "MyDictionary" }); public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) @@ -40,101 +54,97 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (type == typeof(Program.MyClass)) { - var obj = new Program.MyClass(); - BindCore(configuration, ref obj, binderOptions); - return obj; + var instance = new Program.MyClass(); + BindCore(configuration, ref instance, defaultValueIfNotFound: true, binderOptions); + return instance; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref int[] obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref int[] instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) + var temp1 = new List(); + + foreach (IConfigurationSection section in configuration.GetChildren()) { - throw new ArgumentNullException(nameof(obj)); + if (section.Value is string value) + { + temp1.Add(ParseInt(value, () => section.Path)); + } } - var temp1 = new List(); - BindCore(configuration, ref temp1, binderOptions); - int originalCount = obj.Length; - Array.Resize(ref obj, originalCount + temp1.Count); - temp1.CopyTo(obj, originalCount); + int originalCount = instance.Length; + Array.Resize(ref instance, originalCount + temp1.Count); + temp1.CopyTo(instance, originalCount); } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value2) + { + instance.MyString = value2; + } - if (configuration["MyInt"] is string value4) + if (configuration["MyInt"] is string value3) { - obj.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value3, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } - if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) + if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section4) { - List? temp7 = obj.MyList; - temp7 ??= new List(); - BindCore(section5, ref temp7, binderOptions); - obj.MyList = temp7; + List? temp6 = instance.MyList; + temp6 ??= new List(); + BindCore(section4, ref temp6, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp6; } - if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section8) + if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section7) { - int[]? temp10 = obj.MyArray; - temp10 ??= new int[0]; - BindCore(section8, ref temp10, binderOptions); - obj.MyArray = temp10; + int[]? temp9 = instance.MyArray; + temp9 ??= new int[0]; + BindCore(section7, ref temp9, defaultValueIfNotFound: false, binderOptions); + instance.MyArray = temp9; } - if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) + if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section10) { - Dictionary? temp13 = obj.MyDictionary; - temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, binderOptions); - obj.MyDictionary = temp13; + Dictionary? temp12 = instance.MyDictionary; + temp12 ??= new Dictionary(); + BindCore(section10, ref temp12, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp12; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -187,7 +197,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -204,5 +214,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf.generated.txt index 16a98c931a705f..eca2001123ff01 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the configuration instance to a new instance of type T. - public static object? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, type, configureOptions: null); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,11 +24,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 11, 51)] + public static object? Get(this IConfiguration configuration, Type type) => GetCore(configuration, type, configureOptions: null); + #endregion IConfiguration extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass2 = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyInt" }); public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) @@ -40,29 +54,29 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (type == typeof(Program.MyClass2)) { - var obj = new Program.MyClass2(); - BindCore(configuration, ref obj, binderOptions); - return obj; + var instance = new Program.MyClass2(); + BindCore(configuration, ref instance, defaultValueIfNotFound: true, binderOptions); + return instance; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref Program.MyClass2 obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass2 instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -115,7 +129,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -132,5 +146,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf_BinderOptions.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf_BinderOptions.generated.txt index 8d1ee9ed3cd9a3..7883f4a50da0f6 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf_BinderOptions.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_TypeOf_BinderOptions.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the configuration instance to a new instance of type T. - public static object? Get(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Type type, global::System.Action? configureOptions) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.GetCore(configuration, type, configureOptions); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,11 +24,18 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 11, 20)] + public static object? Get(this IConfiguration configuration, Type type, Action? configureOptions) => GetCore(configuration, type, configureOptions); + #endregion IConfiguration extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass2 = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyInt" }); public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) @@ -40,29 +54,29 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (type == typeof(Program.MyClass2)) { - var obj = new Program.MyClass2(); - BindCore(configuration, ref obj, binderOptions); - return obj; + var instance = new Program.MyClass2(); + BindCore(configuration, ref instance, defaultValueIfNotFound: true, binderOptions); + return instance; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref Program.MyClass2 obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass2 instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -115,7 +129,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -132,5 +146,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/EmptyConfigType.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/EmptyConfigType.generated.txt new file mode 100644 index 00000000000000..404ce7561cc47c --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/EmptyConfigType.generated.txt @@ -0,0 +1,104 @@ +// +#nullable enable +#pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. + +namespace System.Runtime.CompilerServices +{ + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } +} + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + using Microsoft.Extensions.Configuration; + using System; + using System.CodeDom.Compiler; + using System.Collections.Generic; + using System.Globalization; + using System.Runtime.CompilerServices; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + file static class BindingExtensions + { + #region IConfiguration extensions. + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 12, 23)] + public static void Bind_TypeWithNoMembers(this IConfiguration configuration, object? instance) + { + } + + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 15, 23)] + public static void Bind_TypeWithNoMembers_Wrapper(this IConfiguration configuration, object? instance) + { + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + if (instance is null) + { + return; + } + + var typedObj = (TypeWithNoMembers_Wrapper)instance; + BindCore(configuration, ref typedObj, defaultValueIfNotFound: false, binderOptions: null); + } + #endregion IConfiguration extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_TypeWithNoMembers_Wrapper = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "Member" }); + + public static void BindCore(IConfiguration configuration, ref TypeWithNoMembers_Wrapper instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { + ValidateConfigurationKeys(typeof(TypeWithNoMembers_Wrapper), s_configKeys_TypeWithNoMembers_Wrapper, configuration, binderOptions); + + if (AsConfigWithChildren(configuration.GetSection("Member")) is IConfigurationSection section0) + { + instance.Member ??= new TypeWithNoMembers(); + } + } + + + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. + public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) + { + if (binderOptions?.ErrorOnUnknownConfiguration is true) + { + List? temp = null; + + foreach (IConfigurationSection section in configuration.GetChildren()) + { + if (!keys.Value.Contains(section.Key)) + { + (temp ??= new List()).Add($"'{section.Key}'"); + } + } + + if (temp is not null) + { + throw new InvalidOperationException($"'ErrorOnUnknownConfiguration' was set on the provided BinderOptions, but the following properties were not found on the instance of {type}: {string.Join(", ", temp)}"); + } + } + } + + public static IConfiguration? AsConfigWithChildren(IConfiguration configuration) + { + foreach (IConfigurationSection _ in configuration.GetChildren()) + { + return configuration; + } + return null; + } + #endregion Core binding extensions. + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/BindConfiguration.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/BindConfiguration.generated.txt index a74dbfdd04b5b9..44f1df2e78232d 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/BindConfiguration.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/BindConfiguration.generated.txt @@ -2,31 +2,18 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedOptionsBuilderBinder +namespace System.Runtime.CompilerServices { - /// Registers the dependency injection container to bind against the obtained from the DI service provider. - public static global::Microsoft.Extensions.Options.OptionsBuilder BindConfiguration(this global::Microsoft.Extensions.Options.OptionsBuilder optionsBuilder, string configSectionPath, global::System.Action? configureOptions = null) where TOptions : class - { - if (optionsBuilder is null) - { - throw new global::System.ArgumentNullException(nameof(optionsBuilder)); - } + using System; + using System.CodeDom.Compiler; - if (configSectionPath is null) + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) { - throw new global::System.ArgumentNullException(nameof(configSectionPath)); } - - optionsBuilder.Configure((obj, configuration) => - { - global::Microsoft.Extensions.Configuration.IConfiguration section = string.Equals(string.Empty, configSectionPath, global::System.StringComparison.OrdinalIgnoreCase) ? configuration : configuration.GetSection(configSectionPath); - global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCoreUntyped(section, obj, typeof(TOptions), configureOptions); - }); - - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton, global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource>(optionsBuilder.Services); - return optionsBuilder; } } @@ -34,82 +21,113 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; using System; using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region OptionsBuilder extensions. + /// Registers the dependency injection container to bind against the obtained from the DI service provider. + [InterceptsLocation(@"src-0.cs", 12, 24)] + public static OptionsBuilder BindConfiguration(this OptionsBuilder optionsBuilder, string configSectionPath, Action? configureBinder = null) where TOptions : class + { + if (optionsBuilder is null) + { + throw new ArgumentNullException(nameof(optionsBuilder)); + } + + if (configSectionPath is null) + { + throw new ArgumentNullException(nameof(configSectionPath)); + } + + optionsBuilder.Configure((instance, config) => + { + if (config is null) + { + throw new ArgumentNullException(nameof(config)); + } + + IConfiguration section = string.Equals(string.Empty, configSectionPath, StringComparison.OrdinalIgnoreCase) ? config : config.GetSection(configSectionPath); + BindCoreMain(section, instance, typeof(TOptions), configureBinder); + }); + + optionsBuilder.Services.AddSingleton, ConfigurationChangeTokenSource>(); + return optionsBuilder; + } + #endregion OptionsBuilder extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList" }); - public static void BindCoreUntyped(this IConfiguration configuration, object obj, Type type, Action? configureOptions) + public static void BindCoreMain(IConfiguration configuration, object instance, Type type, Action? configureOptions) { - if (configuration is null) + if (instance is null) { - throw new ArgumentNullException(nameof(configuration)); + return; } - BinderOptions? binderOptions = GetBinderOptions(configureOptions); - if (!HasValueOrChildren(configuration)) { return; } + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (type == typeof(Program.MyClass)) { - var temp = (Program.MyClass)obj; - BindCore(configuration, ref temp, binderOptions); + var temp = (Program.MyClass)instance; + BindCore(configuration, ref temp, defaultValueIfNotFound: false, binderOptions); return; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value1) + { + instance.MyString = value1; + } if (configuration["MyInt"] is string value2) { - obj.MyInt = ParseInt(value2, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value2, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section3) { - List? temp5 = obj.MyList; + List? temp5 = instance.MyList; temp5 ??= new List(); - BindCore(section3, ref temp5, binderOptions); - obj.MyList = temp5; + BindCore(section3, ref temp5, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp5; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -162,7 +180,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -179,5 +197,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T.generated.txt index ac53b58f24da23..8d7c70a27b3f2e 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T.generated.txt @@ -2,49 +2,18 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedOptionsBuilderBinder +namespace System.Runtime.CompilerServices { - /// Registers a configuration instance which will bind against. - public static global::Microsoft.Extensions.Options.OptionsBuilder Bind(this global::Microsoft.Extensions.Options.OptionsBuilder optionsBuilder, global::Microsoft.Extensions.Configuration.IConfiguration configuration) where TOptions : class - { - return global::GeneratedOptionsBuilderBinder.Bind(optionsBuilder, configuration, configureOptions: null); - } - - /// Registers a configuration instance which will bind against. - public static global::Microsoft.Extensions.Options.OptionsBuilder Bind(this global::Microsoft.Extensions.Options.OptionsBuilder optionsBuilder, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class - { - if (optionsBuilder is null) - { - throw new global::System.ArgumentNullException(nameof(optionsBuilder)); - } - - global::GeneratedServiceCollectionBinder.Configure(optionsBuilder.Services, optionsBuilder.Name, configuration, configureOptions); - return optionsBuilder; - } -} + using System; + using System.CodeDom.Compiler; -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedServiceCollectionBinder -{ - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, string? name, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute { - if (services is null) + public InterceptsLocationAttribute(string filePath, int line, int column) { - throw new global::System.ArgumentNullException(nameof(services)); } - - if (configuration is null) - { - throw new global::System.ArgumentNullException(nameof(configuration)); - } - - global::Microsoft.Extensions.DependencyInjection.OptionsServiceCollectionExtensions.AddOptions(services); - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource(name, configuration)); - return global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigureNamedOptions(name, obj => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCoreUntyped(configuration, obj, typeof(TOptions), configureOptions))); } } @@ -52,82 +21,123 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; using System; using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList" }); + #region OptionsBuilder extensions. + /// Registers a configuration instance which will bind against. + [InterceptsLocation(@"src-0.cs", 15, 24)] + public static OptionsBuilder Bind(this OptionsBuilder optionsBuilder, IConfiguration config) where TOptions : class + { + return Bind(optionsBuilder, config, configureBinder: null); + } - public static void BindCoreUntyped(this IConfiguration configuration, object obj, Type type, Action? configureOptions) + /// Registers a configuration instance which will bind against. + public static OptionsBuilder Bind(this OptionsBuilder optionsBuilder, IConfiguration config, Action? configureBinder) where TOptions : class { - if (configuration is null) + if (optionsBuilder is null) { - throw new ArgumentNullException(nameof(configuration)); + throw new ArgumentNullException(nameof(optionsBuilder)); } - BinderOptions? binderOptions = GetBinderOptions(configureOptions); + Configure(optionsBuilder.Services, optionsBuilder.Name, config, configureBinder); + return optionsBuilder; + } + #endregion OptionsBuilder extensions. + + #region IServiceCollection extensions. + /// Registers a configuration instance which TOptions will bind against. + public static IServiceCollection Configure(this IServiceCollection services, string? name, IConfiguration config, Action? configureOptions) where TOptions : class + { + if (services is null) + { + throw new ArgumentNullException(nameof(services)); + } + + if (config is null) + { + throw new ArgumentNullException(nameof(config)); + } + + OptionsServiceCollectionExtensions.AddOptions(services); + services.AddSingleton>(new ConfigurationChangeTokenSource(name, config)); + return services.AddSingleton>(new ConfigureNamedOptions(name, instance => BindCoreMain(config, instance, typeof(TOptions), configureOptions))); + } + #endregion IServiceCollection extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList" }); + + public static void BindCoreMain(IConfiguration configuration, object instance, Type type, Action? configureOptions) + { + if (instance is null) + { + return; + } if (!HasValueOrChildren(configuration)) { return; } + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (type == typeof(Program.MyClass)) { - var temp = (Program.MyClass)obj; - BindCore(configuration, ref temp, binderOptions); + var temp = (Program.MyClass)instance; + BindCore(configuration, ref temp, defaultValueIfNotFound: false, binderOptions); return; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value1) + { + instance.MyString = value1; + } if (configuration["MyInt"] is string value2) { - obj.MyInt = ParseInt(value2, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value2, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section3) { - List? temp5 = obj.MyList; + List? temp5 = instance.MyList; temp5 ??= new List(); - BindCore(section3, ref temp5, binderOptions); - obj.MyList = temp5; + BindCore(section3, ref temp5, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp5; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -180,7 +190,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -197,5 +207,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T_BinderOptions.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T_BinderOptions.generated.txt index fd3ec70a8a328b..385079af1709a4 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T_BinderOptions.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/OptionsBuilder/Bind_T_BinderOptions.generated.txt @@ -2,43 +2,18 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedOptionsBuilderBinder +namespace System.Runtime.CompilerServices { - /// Registers a configuration instance which will bind against. - public static global::Microsoft.Extensions.Options.OptionsBuilder Bind(this global::Microsoft.Extensions.Options.OptionsBuilder optionsBuilder, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class - { - if (optionsBuilder is null) - { - throw new global::System.ArgumentNullException(nameof(optionsBuilder)); - } - - global::GeneratedServiceCollectionBinder.Configure(optionsBuilder.Services, optionsBuilder.Name, configuration, configureOptions); - return optionsBuilder; - } -} + using System; + using System.CodeDom.Compiler; -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedServiceCollectionBinder -{ - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, string? name, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute { - if (services is null) + public InterceptsLocationAttribute(string filePath, int line, int column) { - throw new global::System.ArgumentNullException(nameof(services)); } - - if (configuration is null) - { - throw new global::System.ArgumentNullException(nameof(configuration)); - } - - global::Microsoft.Extensions.DependencyInjection.OptionsServiceCollectionExtensions.AddOptions(services); - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource(name, configuration)); - return global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigureNamedOptions(name, obj => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCoreUntyped(configuration, obj, typeof(TOptions), configureOptions))); } } @@ -46,82 +21,117 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; using System; using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList" }); + #region OptionsBuilder extensions. + /// Registers a configuration instance which will bind against. + [InterceptsLocation(@"src-0.cs", 15, 24)] + public static OptionsBuilder Bind(this OptionsBuilder optionsBuilder, IConfiguration config, Action? configureBinder) where TOptions : class + { + if (optionsBuilder is null) + { + throw new ArgumentNullException(nameof(optionsBuilder)); + } - public static void BindCoreUntyped(this IConfiguration configuration, object obj, Type type, Action? configureOptions) + Configure(optionsBuilder.Services, optionsBuilder.Name, config, configureBinder); + return optionsBuilder; + } + #endregion OptionsBuilder extensions. + + #region IServiceCollection extensions. + /// Registers a configuration instance which TOptions will bind against. + public static IServiceCollection Configure(this IServiceCollection services, string? name, IConfiguration config, Action? configureOptions) where TOptions : class { - if (configuration is null) + if (services is null) { - throw new ArgumentNullException(nameof(configuration)); + throw new ArgumentNullException(nameof(services)); } - BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (config is null) + { + throw new ArgumentNullException(nameof(config)); + } + + OptionsServiceCollectionExtensions.AddOptions(services); + services.AddSingleton>(new ConfigurationChangeTokenSource(name, config)); + return services.AddSingleton>(new ConfigureNamedOptions(name, instance => BindCoreMain(config, instance, typeof(TOptions), configureOptions))); + } + #endregion IServiceCollection extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList" }); + + public static void BindCoreMain(IConfiguration configuration, object instance, Type type, Action? configureOptions) + { + if (instance is null) + { + return; + } if (!HasValueOrChildren(configuration)) { return; } + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (type == typeof(Program.MyClass)) { - var temp = (Program.MyClass)obj; - BindCore(configuration, ref temp, binderOptions); + var temp = (Program.MyClass)instance; + BindCore(configuration, ref temp, defaultValueIfNotFound: false, binderOptions); return; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value1) + { + instance.MyString = value1; + } if (configuration["MyInt"] is string value2) { - obj.MyInt = ParseInt(value2, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value2, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section3) { - List? temp5 = obj.MyList; + List? temp5 = instance.MyList; temp5 ??= new List(); - BindCore(section3, ref temp5, binderOptions); - obj.MyList = temp5; + BindCore(section3, ref temp5, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp5; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -174,7 +184,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -191,5 +201,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Primitives.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Primitives.generated.txt index 5d30288a21e785..a8373ca527095a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Primitives.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Primitives.generated.txt @@ -2,12 +2,19 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedConfigurationBinder +namespace System.Runtime.CompilerServices { - /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. - public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::Program.MyClass obj) => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCore(configuration, ref obj, binderOptions: null); + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -17,157 +24,295 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { - private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "Prop0", "Prop1", "Prop2", "Prop3", "Prop4", "Prop5", "Prop6", "Prop8", "Prop9", "Prop10", "Prop13", "Prop14", "Prop15", "Prop16", "Prop17", "Prop19", "Prop20", "Prop21", "Prop23", "Prop24", "Prop25", "Prop26", "Prop27", "Prop7", "Prop11", "Prop12", "Prop18", "Prop22" }); - - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + #region IConfiguration extensions. + /// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively. + [InterceptsLocation(@"src-0.cs", 13, 16)] + public static void Bind_ProgramMyClass(this IConfiguration configuration, object? instance) { - if (obj is null) + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + if (instance is null) { - throw new ArgumentNullException(nameof(obj)); + return; } + var typedObj = (Program.MyClass)instance; + BindCore(configuration, ref typedObj, defaultValueIfNotFound: false, binderOptions: null); + } + #endregion IConfiguration extensions. + + #region Core binding extensions. + private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "Prop0", "Prop1", "Prop2", "Prop3", "Prop4", "Prop5", "Prop6", "Prop8", "Prop9", "Prop10", "Prop13", "Prop14", "Prop15", "Prop16", "Prop17", "Prop19", "Prop20", "Prop21", "Prop23", "Prop24", "Prop25", "Prop26", "Prop27", "Prop7", "Prop11", "Prop12", "Prop18", "Prop22", "Prop28", "Prop29", "Prop30" }); + + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); if (configuration["Prop0"] is string value0) { - obj.Prop0 = ParseBool(value0, () => configuration.GetSection("Prop0").Path); + instance.Prop0 = ParseBool(value0, () => configuration.GetSection("Prop0").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop0 = default; } if (configuration["Prop1"] is string value1) { - obj.Prop1 = ParseByte(value1, () => configuration.GetSection("Prop1").Path); + instance.Prop1 = ParseByte(value1, () => configuration.GetSection("Prop1").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop1 = default; } if (configuration["Prop2"] is string value2) { - obj.Prop2 = ParseSbyte(value2, () => configuration.GetSection("Prop2").Path); + instance.Prop2 = ParseSbyte(value2, () => configuration.GetSection("Prop2").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop2 = default; } if (configuration["Prop3"] is string value3) { - obj.Prop3 = ParseChar(value3, () => configuration.GetSection("Prop3").Path); + instance.Prop3 = ParseChar(value3, () => configuration.GetSection("Prop3").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop3 = default; } if (configuration["Prop4"] is string value4) { - obj.Prop4 = ParseDouble(value4, () => configuration.GetSection("Prop4").Path); + instance.Prop4 = ParseDouble(value4, () => configuration.GetSection("Prop4").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop4 = default; } - obj.Prop5 = configuration["Prop5"]!; + if (configuration["Prop5"] is string value5) + { + instance.Prop5 = value5; + } if (configuration["Prop6"] is string value6) { - obj.Prop6 = ParseInt(value6, () => configuration.GetSection("Prop6").Path); + instance.Prop6 = ParseInt(value6, () => configuration.GetSection("Prop6").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop6 = default; } if (configuration["Prop8"] is string value7) { - obj.Prop8 = ParseShort(value7, () => configuration.GetSection("Prop8").Path); + instance.Prop8 = ParseShort(value7, () => configuration.GetSection("Prop8").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop8 = default; } if (configuration["Prop9"] is string value8) { - obj.Prop9 = ParseLong(value8, () => configuration.GetSection("Prop9").Path); + instance.Prop9 = ParseLong(value8, () => configuration.GetSection("Prop9").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop9 = default; } if (configuration["Prop10"] is string value9) { - obj.Prop10 = ParseFloat(value9, () => configuration.GetSection("Prop10").Path); + instance.Prop10 = ParseFloat(value9, () => configuration.GetSection("Prop10").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop10 = default; } if (configuration["Prop13"] is string value10) { - obj.Prop13 = ParseUshort(value10, () => configuration.GetSection("Prop13").Path); + instance.Prop13 = ParseUshort(value10, () => configuration.GetSection("Prop13").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop13 = default; } if (configuration["Prop14"] is string value11) { - obj.Prop14 = ParseUint(value11, () => configuration.GetSection("Prop14").Path); + instance.Prop14 = ParseUint(value11, () => configuration.GetSection("Prop14").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop14 = default; } if (configuration["Prop15"] is string value12) { - obj.Prop15 = ParseUlong(value12, () => configuration.GetSection("Prop15").Path); + instance.Prop15 = ParseUlong(value12, () => configuration.GetSection("Prop15").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop15 = default; } - obj.Prop16 = configuration["Prop16"]!; + if (configuration["Prop16"] is string value13) + { + instance.Prop16 = value13; + } if (configuration["Prop17"] is string value14) { - obj.Prop17 = ParseCultureInfo(value14, () => configuration.GetSection("Prop17").Path); + instance.Prop17 = ParseCultureInfo(value14, () => configuration.GetSection("Prop17").Path); } if (configuration["Prop19"] is string value15) { - obj.Prop19 = ParseDateTime(value15, () => configuration.GetSection("Prop19").Path); + instance.Prop19 = ParseDateTime(value15, () => configuration.GetSection("Prop19").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop19 = default; } if (configuration["Prop20"] is string value16) { - obj.Prop20 = ParseDateTimeOffset(value16, () => configuration.GetSection("Prop20").Path); + instance.Prop20 = ParseDateTimeOffset(value16, () => configuration.GetSection("Prop20").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop20 = default; } if (configuration["Prop21"] is string value17) { - obj.Prop21 = ParseDecimal(value17, () => configuration.GetSection("Prop21").Path); + instance.Prop21 = ParseDecimal(value17, () => configuration.GetSection("Prop21").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop21 = default; } if (configuration["Prop23"] is string value18) { - obj.Prop23 = ParseInt(value18, () => configuration.GetSection("Prop23").Path); + instance.Prop23 = ParseTimeSpan(value18, () => configuration.GetSection("Prop23").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop23 = default; } if (configuration["Prop24"] is string value19) { - obj.Prop24 = ParseDateTime(value19, () => configuration.GetSection("Prop24").Path); + instance.Prop24 = ParseGuid(value19, () => configuration.GetSection("Prop24").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop24 = default; } if (configuration["Prop25"] is string value20) { - obj.Prop25 = ParseUri(value20, () => configuration.GetSection("Prop25").Path); + instance.Prop25 = ParseUri(value20, () => configuration.GetSection("Prop25").Path); } if (configuration["Prop26"] is string value21) { - obj.Prop26 = ParseVersion(value21, () => configuration.GetSection("Prop26").Path); + instance.Prop26 = ParseVersion(value21, () => configuration.GetSection("Prop26").Path); } if (configuration["Prop27"] is string value22) { - obj.Prop27 = ParseEnum(value22, () => configuration.GetSection("Prop27").Path); + instance.Prop27 = ParseEnum(value22, () => configuration.GetSection("Prop27").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop27 = default; } if (configuration["Prop7"] is string value23) { - obj.Prop7 = ParseInt128(value23, () => configuration.GetSection("Prop7").Path); + instance.Prop7 = ParseInt128(value23, () => configuration.GetSection("Prop7").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop7 = default; } if (configuration["Prop11"] is string value24) { - obj.Prop11 = ParseHalf(value24, () => configuration.GetSection("Prop11").Path); + instance.Prop11 = ParseHalf(value24, () => configuration.GetSection("Prop11").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop11 = default; } if (configuration["Prop12"] is string value25) { - obj.Prop12 = ParseUInt128(value25, () => configuration.GetSection("Prop12").Path); + instance.Prop12 = ParseUInt128(value25, () => configuration.GetSection("Prop12").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop12 = default; } if (configuration["Prop18"] is string value26) { - obj.Prop18 = ParseDateOnly(value26, () => configuration.GetSection("Prop18").Path); + instance.Prop18 = ParseDateOnly(value26, () => configuration.GetSection("Prop18").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop18 = default; } if (configuration["Prop22"] is string value27) { - obj.Prop22 = ParseByteArray(value27, () => configuration.GetSection("Prop22").Path); + instance.Prop22 = ParseTimeOnly(value27, () => configuration.GetSection("Prop22").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop22 = default; + } + + if (configuration["Prop28"] is string value28) + { + instance.Prop28 = ParseByteArray(value28, () => configuration.GetSection("Prop28").Path); + } + + if (configuration["Prop29"] is string value29) + { + instance.Prop29 = ParseInt(value29, () => configuration.GetSection("Prop29").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop29 = default; + } + + if (configuration["Prop30"] is string value30) + { + instance.Prop30 = ParseDateTime(value30, () => configuration.GetSection("Prop30").Path); + } + else if (defaultValueIfNotFound) + { + instance.Prop30 = default; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -517,5 +662,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(byte[])}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T.generated.txt index 461f6050e2fdfe..3e38a484c82552 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T.generated.txt @@ -2,177 +2,180 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedServiceCollectionBinder +namespace System.Runtime.CompilerServices { - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::Microsoft.Extensions.Configuration.IConfiguration configuration) where TOptions : class - { - return global::GeneratedServiceCollectionBinder.Configure(services, string.Empty, configuration, configureOptions: null); - } + using System; + using System.CodeDom.Compiler; - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, string? name, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute { - if (services is null) - { - throw new global::System.ArgumentNullException(nameof(services)); - } - - if (configuration is null) + public InterceptsLocationAttribute(string filePath, int line, int column) { - throw new global::System.ArgumentNullException(nameof(configuration)); } - - global::Microsoft.Extensions.DependencyInjection.OptionsServiceCollectionExtensions.AddOptions(services); - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource(name, configuration)); - return global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigureNamedOptions(name, obj => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCoreUntyped(configuration, obj, typeof(TOptions), configureOptions))); } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { using Microsoft.Extensions.Configuration; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; using System; using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IServiceCollection extensions. + /// Registers a configuration instance which TOptions will bind against. + [InterceptsLocation(@"src-0.cs", 14, 18)] + public static IServiceCollection Configure(this IServiceCollection services, IConfiguration config) where TOptions : class + { + return Configure(services, string.Empty, config, configureOptions: null); + } + + /// Registers a configuration instance which TOptions will bind against. + public static IServiceCollection Configure(this IServiceCollection services, string? name, IConfiguration config, Action? configureOptions) where TOptions : class + { + if (services is null) + { + throw new ArgumentNullException(nameof(services)); + } + + if (config is null) + { + throw new ArgumentNullException(nameof(config)); + } + + OptionsServiceCollectionExtensions.AddOptions(services); + services.AddSingleton>(new ConfigurationChangeTokenSource(name, config)); + return services.AddSingleton>(new ConfigureNamedOptions(name, instance => BindCoreMain(config, instance, typeof(TOptions), configureOptions))); + } + #endregion IServiceCollection extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass2 = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyInt" }); private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyList2", "MyDictionary" }); - public static void BindCoreUntyped(this IConfiguration configuration, object obj, Type type, Action? configureOptions) + public static void BindCoreMain(IConfiguration configuration, object instance, Type type, Action? configureOptions) { - if (configuration is null) + if (instance is null) { - throw new ArgumentNullException(nameof(configuration)); + return; } - BinderOptions? binderOptions = GetBinderOptions(configureOptions); - if (!HasValueOrChildren(configuration)) { return; } + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (type == typeof(Program.MyClass)) { - var temp = (Program.MyClass)obj; - BindCore(configuration, ref temp, binderOptions); + var temp = (Program.MyClass)instance; + BindCore(configuration, ref temp, defaultValueIfNotFound: false, binderOptions); return; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass2 obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass2 instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); } - } - - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) - { - if (obj is null) + else if (defaultValueIfNotFound) { - throw new ArgumentNullException(nameof(obj)); + instance.MyInt = default; } + } + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { var value = new Program.MyClass2(); - BindCore(section, ref value, binderOptions); - obj.Add(value); + BindCore(section, ref value, defaultValueIfNotFound: false, binderOptions); + instance.Add(value); } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value3) + { + instance.MyString = value3; + } if (configuration["MyInt"] is string value4) { - obj.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) { - List? temp7 = obj.MyList; + List? temp7 = instance.MyList; temp7 ??= new List(); - BindCore(section5, ref temp7, binderOptions); - obj.MyList = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyList2")) is IConfigurationSection section8) { - List? temp10 = obj.MyList2; + List? temp10 = instance.MyList2; temp10 ??= new List(); - BindCore(section8, ref temp10, binderOptions); - obj.MyList2 = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyList2 = temp10; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) { - Dictionary? temp13 = obj.MyDictionary; + Dictionary? temp13 = instance.MyDictionary; temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, binderOptions); - obj.MyDictionary = temp13; + BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp13; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -225,7 +228,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -242,5 +245,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_BinderOptions.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_BinderOptions.generated.txt index a57f652720bc8c..186e93a49207b2 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_BinderOptions.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_BinderOptions.generated.txt @@ -2,177 +2,180 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedServiceCollectionBinder +namespace System.Runtime.CompilerServices { - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class - { - return global::GeneratedServiceCollectionBinder.Configure(services, string.Empty, configuration, configureOptions); - } + using System; + using System.CodeDom.Compiler; - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, string? name, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute { - if (services is null) - { - throw new global::System.ArgumentNullException(nameof(services)); - } - - if (configuration is null) + public InterceptsLocationAttribute(string filePath, int line, int column) { - throw new global::System.ArgumentNullException(nameof(configuration)); } - - global::Microsoft.Extensions.DependencyInjection.OptionsServiceCollectionExtensions.AddOptions(services); - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource(name, configuration)); - return global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigureNamedOptions(name, obj => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCoreUntyped(configuration, obj, typeof(TOptions), configureOptions))); } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { using Microsoft.Extensions.Configuration; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; using System; using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IServiceCollection extensions. + /// Registers a configuration instance which TOptions will bind against. + [InterceptsLocation(@"src-0.cs", 14, 18)] + public static IServiceCollection Configure(this IServiceCollection services, IConfiguration config, Action? configureOptions) where TOptions : class + { + return Configure(services, string.Empty, config, configureOptions); + } + + /// Registers a configuration instance which TOptions will bind against. + public static IServiceCollection Configure(this IServiceCollection services, string? name, IConfiguration config, Action? configureOptions) where TOptions : class + { + if (services is null) + { + throw new ArgumentNullException(nameof(services)); + } + + if (config is null) + { + throw new ArgumentNullException(nameof(config)); + } + + OptionsServiceCollectionExtensions.AddOptions(services); + services.AddSingleton>(new ConfigurationChangeTokenSource(name, config)); + return services.AddSingleton>(new ConfigureNamedOptions(name, instance => BindCoreMain(config, instance, typeof(TOptions), configureOptions))); + } + #endregion IServiceCollection extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass2 = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyInt" }); private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyList2", "MyDictionary" }); - public static void BindCoreUntyped(this IConfiguration configuration, object obj, Type type, Action? configureOptions) + public static void BindCoreMain(IConfiguration configuration, object instance, Type type, Action? configureOptions) { - if (configuration is null) + if (instance is null) { - throw new ArgumentNullException(nameof(configuration)); + return; } - BinderOptions? binderOptions = GetBinderOptions(configureOptions); - if (!HasValueOrChildren(configuration)) { return; } + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (type == typeof(Program.MyClass)) { - var temp = (Program.MyClass)obj; - BindCore(configuration, ref temp, binderOptions); + var temp = (Program.MyClass)instance; + BindCore(configuration, ref temp, defaultValueIfNotFound: false, binderOptions); return; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass2 obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass2 instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); } - } - - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) - { - if (obj is null) + else if (defaultValueIfNotFound) { - throw new ArgumentNullException(nameof(obj)); + instance.MyInt = default; } + } + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { var value = new Program.MyClass2(); - BindCore(section, ref value, binderOptions); - obj.Add(value); + BindCore(section, ref value, defaultValueIfNotFound: false, binderOptions); + instance.Add(value); } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value3) + { + instance.MyString = value3; + } if (configuration["MyInt"] is string value4) { - obj.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) { - List? temp7 = obj.MyList; + List? temp7 = instance.MyList; temp7 ??= new List(); - BindCore(section5, ref temp7, binderOptions); - obj.MyList = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyList2")) is IConfigurationSection section8) { - List? temp10 = obj.MyList2; + List? temp10 = instance.MyList2; temp10 ??= new List(); - BindCore(section8, ref temp10, binderOptions); - obj.MyList2 = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyList2 = temp10; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) { - Dictionary? temp13 = obj.MyDictionary; + Dictionary? temp13 = instance.MyDictionary; temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, binderOptions); - obj.MyDictionary = temp13; + BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp13; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -225,7 +228,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -242,5 +245,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name.generated.txt index 66975c3164d745..7958adb1125338 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name.generated.txt @@ -2,177 +2,180 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedServiceCollectionBinder +namespace System.Runtime.CompilerServices { - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, string? name, global::Microsoft.Extensions.Configuration.IConfiguration configuration) where TOptions : class - { - return global::GeneratedServiceCollectionBinder.Configure(services, name, configuration, configureOptions: null); - } + using System; + using System.CodeDom.Compiler; - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, string? name, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute { - if (services is null) - { - throw new global::System.ArgumentNullException(nameof(services)); - } - - if (configuration is null) + public InterceptsLocationAttribute(string filePath, int line, int column) { - throw new global::System.ArgumentNullException(nameof(configuration)); } - - global::Microsoft.Extensions.DependencyInjection.OptionsServiceCollectionExtensions.AddOptions(services); - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource(name, configuration)); - return global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigureNamedOptions(name, obj => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCoreUntyped(configuration, obj, typeof(TOptions), configureOptions))); } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { using Microsoft.Extensions.Configuration; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; using System; using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IServiceCollection extensions. + /// Registers a configuration instance which TOptions will bind against. + [InterceptsLocation(@"src-0.cs", 14, 18)] + public static IServiceCollection Configure(this IServiceCollection services, string? name, IConfiguration config) where TOptions : class + { + return Configure(services, name, config, configureOptions: null); + } + + /// Registers a configuration instance which TOptions will bind against. + public static IServiceCollection Configure(this IServiceCollection services, string? name, IConfiguration config, Action? configureOptions) where TOptions : class + { + if (services is null) + { + throw new ArgumentNullException(nameof(services)); + } + + if (config is null) + { + throw new ArgumentNullException(nameof(config)); + } + + OptionsServiceCollectionExtensions.AddOptions(services); + services.AddSingleton>(new ConfigurationChangeTokenSource(name, config)); + return services.AddSingleton>(new ConfigureNamedOptions(name, instance => BindCoreMain(config, instance, typeof(TOptions), configureOptions))); + } + #endregion IServiceCollection extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass2 = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyInt" }); private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyList2", "MyDictionary" }); - public static void BindCoreUntyped(this IConfiguration configuration, object obj, Type type, Action? configureOptions) + public static void BindCoreMain(IConfiguration configuration, object instance, Type type, Action? configureOptions) { - if (configuration is null) + if (instance is null) { - throw new ArgumentNullException(nameof(configuration)); + return; } - BinderOptions? binderOptions = GetBinderOptions(configureOptions); - if (!HasValueOrChildren(configuration)) { return; } + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (type == typeof(Program.MyClass)) { - var temp = (Program.MyClass)obj; - BindCore(configuration, ref temp, binderOptions); + var temp = (Program.MyClass)instance; + BindCore(configuration, ref temp, defaultValueIfNotFound: false, binderOptions); return; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass2 obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass2 instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); } - } - - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) - { - if (obj is null) + else if (defaultValueIfNotFound) { - throw new ArgumentNullException(nameof(obj)); + instance.MyInt = default; } + } + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { var value = new Program.MyClass2(); - BindCore(section, ref value, binderOptions); - obj.Add(value); + BindCore(section, ref value, defaultValueIfNotFound: false, binderOptions); + instance.Add(value); } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value3) + { + instance.MyString = value3; + } if (configuration["MyInt"] is string value4) { - obj.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) { - List? temp7 = obj.MyList; + List? temp7 = instance.MyList; temp7 ??= new List(); - BindCore(section5, ref temp7, binderOptions); - obj.MyList = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyList2")) is IConfigurationSection section8) { - List? temp10 = obj.MyList2; + List? temp10 = instance.MyList2; temp10 ??= new List(); - BindCore(section8, ref temp10, binderOptions); - obj.MyList2 = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyList2 = temp10; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) { - Dictionary? temp13 = obj.MyDictionary; + Dictionary? temp13 = instance.MyDictionary; temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, binderOptions); - obj.MyDictionary = temp13; + BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp13; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -225,7 +228,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -242,5 +245,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name_BinderOptions.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name_BinderOptions.generated.txt index 0263ef12179401..b87d0c0a259cc8 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name_BinderOptions.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ServiceCollection/Configure_T_name_BinderOptions.generated.txt @@ -2,171 +2,174 @@ #nullable enable #pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. -/// Generated helper providing an AOT and linking compatible implementation for configuration binding. -[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] -internal static class GeneratedServiceCollectionBinder +namespace System.Runtime.CompilerServices { - /// Registers a configuration instance which TOptions will bind against. - public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Configure(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, string? name, global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::System.Action? configureOptions) where TOptions : class - { - if (services is null) - { - throw new global::System.ArgumentNullException(nameof(services)); - } + using System; + using System.CodeDom.Compiler; - if (configuration is null) + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) { - throw new global::System.ArgumentNullException(nameof(configuration)); } - - global::Microsoft.Extensions.DependencyInjection.OptionsServiceCollectionExtensions.AddOptions(services); - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigurationChangeTokenSource(name, configuration)); - return global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddSingleton>(services, new global::Microsoft.Extensions.Options.ConfigureNamedOptions(name, obj => global::Microsoft.Extensions.Configuration.Binder.SourceGeneration.CoreBindingHelper.BindCoreUntyped(configuration, obj, typeof(TOptions), configureOptions))); } } namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { using Microsoft.Extensions.Configuration; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; using System; using System.CodeDom.Compiler; using System.Collections.Generic; using System.Globalization; + using System.Runtime.CompilerServices; - /// Provide core binding logic. [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] - file static class CoreBindingHelper + file static class BindingExtensions { + #region IServiceCollection extensions. + /// Registers a configuration instance which TOptions will bind against. + [InterceptsLocation(@"src-0.cs", 14, 18)] + public static IServiceCollection Configure(this IServiceCollection services, string? name, IConfiguration config, Action? configureOptions) where TOptions : class + { + if (services is null) + { + throw new ArgumentNullException(nameof(services)); + } + + if (config is null) + { + throw new ArgumentNullException(nameof(config)); + } + + OptionsServiceCollectionExtensions.AddOptions(services); + services.AddSingleton>(new ConfigurationChangeTokenSource(name, config)); + return services.AddSingleton>(new ConfigureNamedOptions(name, instance => BindCoreMain(config, instance, typeof(TOptions), configureOptions))); + } + #endregion IServiceCollection extensions. + + #region Core binding extensions. private readonly static Lazy> s_configKeys_ProgramMyClass2 = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyInt" }); private readonly static Lazy> s_configKeys_ProgramMyClass = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "MyString", "MyInt", "MyList", "MyList2", "MyDictionary" }); - public static void BindCoreUntyped(this IConfiguration configuration, object obj, Type type, Action? configureOptions) + public static void BindCoreMain(IConfiguration configuration, object instance, Type type, Action? configureOptions) { - if (configuration is null) + if (instance is null) { - throw new ArgumentNullException(nameof(configuration)); + return; } - BinderOptions? binderOptions = GetBinderOptions(configureOptions); - if (!HasValueOrChildren(configuration)) { return; } + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + if (type == typeof(Program.MyClass)) { - var temp = (Program.MyClass)obj; - BindCore(configuration, ref temp, binderOptions); + var temp = (Program.MyClass)instance; + BindCore(configuration, ref temp, defaultValueIfNotFound: false, binderOptions); return; } - throw new global::System.NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); } - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj.Add(ParseInt(value, () => section.Path)); + instance.Add(ParseInt(value, () => section.Path)); } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass2 obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass2 instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); if (configuration["MyInt"] is string value1) { - obj.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value1, () => configuration.GetSection("MyInt").Path); } - } - - public static void BindCore(IConfiguration configuration, ref List obj, BinderOptions? binderOptions) - { - if (obj is null) + else if (defaultValueIfNotFound) { - throw new ArgumentNullException(nameof(obj)); + instance.MyInt = default; } + } + public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) + { foreach (IConfigurationSection section in configuration.GetChildren()) { var value = new Program.MyClass2(); - BindCore(section, ref value, binderOptions); - obj.Add(value); + BindCore(section, ref value, defaultValueIfNotFound: false, binderOptions); + instance.Add(value); } } - public static void BindCore(IConfiguration configuration, ref Dictionary obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - foreach (IConfigurationSection section in configuration.GetChildren()) { if (section.Value is string value) { - obj[section.Key] = value; + instance[section.Key] = value; } } } - public static void BindCore(IConfiguration configuration, ref Program.MyClass obj, BinderOptions? binderOptions) + public static void BindCore(IConfiguration configuration, ref Program.MyClass instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { - if (obj is null) - { - throw new ArgumentNullException(nameof(obj)); - } - ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - obj.MyString = configuration["MyString"]!; + if (configuration["MyString"] is string value3) + { + instance.MyString = value3; + } if (configuration["MyInt"] is string value4) { - obj.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + } + else if (defaultValueIfNotFound) + { + instance.MyInt = default; } if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) { - List? temp7 = obj.MyList; + List? temp7 = instance.MyList; temp7 ??= new List(); - BindCore(section5, ref temp7, binderOptions); - obj.MyList = temp7; + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp7; } if (AsConfigWithChildren(configuration.GetSection("MyList2")) is IConfigurationSection section8) { - List? temp10 = obj.MyList2; + List? temp10 = instance.MyList2; temp10 ??= new List(); - BindCore(section8, ref temp10, binderOptions); - obj.MyList2 = temp10; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyList2 = temp10; } if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) { - Dictionary? temp13 = obj.MyDictionary; + Dictionary? temp13 = instance.MyDictionary; temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, binderOptions); - obj.MyDictionary = temp13; + BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp13; } } + /// If required by the binder options, validates that there are no unknown keys in the input configuration object. public static void ValidateConfigurationKeys(Type type, Lazy> keys, IConfiguration configuration, BinderOptions? binderOptions) { @@ -219,7 +222,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (binderOptions.BindNonPublicProperties) { - throw new global::System.NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); } return binderOptions; @@ -236,5 +239,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); } } + #endregion Core binding extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigBindingGenTestDriver.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigBindingGenTestDriver.cs new file mode 100644 index 00000000000000..4373b404fc67f0 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigBindingGenTestDriver.cs @@ -0,0 +1,156 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Globalization; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.Extensions.Configuration.Binder.SourceGeneration; +using SourceGenerators.Tests; +using Xunit; + +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests +{ + [ActiveIssue("https://github.com/dotnet/runtime/issues/52062", TestPlatforms.Browser)] + public partial class ConfigurationBindingGeneratorTests : ConfigurationBinderTestsBase + { + internal sealed class ConfigBindingGenTestDriver + { + private readonly CSharpParseOptions _parseOptions; + private GeneratorDriver _generatorDriver; + private SourceGenerationSpec? _genSpec; + + private readonly LanguageVersion _langVersion; + private readonly IEnumerable? _assemblyReferences; + private Compilation _compilation = null; + + public ConfigBindingGenTestDriver( + LanguageVersion langVersion = LanguageVersion.LatestMajor, + IEnumerable? assemblyReferences = null) + { + _langVersion = langVersion; + + _assemblyReferences = assemblyReferences ?? s_compilationAssemblyRefs; + + _parseOptions = new CSharpParseOptions(langVersion).WithFeatures(new[] { + new KeyValuePair("InterceptorsPreview", "") , + new KeyValuePair("InterceptorsPreviewNamespaces", "Microsoft.Extensions.Configuration.Binder.SourceGeneration") + }); + + ConfigurationBindingGenerator generator = new() { OnSourceEmitting = spec => _genSpec = spec }; + _generatorDriver = CSharpGeneratorDriver.Create( + new ISourceGenerator[] { generator.AsSourceGenerator() }, + parseOptions: _parseOptions, + driverOptions: new GeneratorDriverOptions( + disabledOutputs: IncrementalGeneratorOutputKind.None, + trackIncrementalGeneratorSteps: true)); + } + + public async Task RunGeneratorAndUpdateCompilation(string? source = null) + { + await UpdateCompilationWithSource(source); + Assert.NotNull(_compilation); + + _generatorDriver = _generatorDriver.RunGeneratorsAndUpdateCompilation(_compilation, out Compilation outputCompilation, out _, CancellationToken.None); + GeneratorDriverRunResult runResult = _generatorDriver.GetRunResult(); + + return new ConfigBindingGenRunResult + { + OutputCompilation = outputCompilation, + Diagnostics = runResult.Diagnostics, + GeneratedSource = runResult.Results[0].GeneratedSources is { Length: not 0 } sources ? sources[0] : null, + TrackedSteps = runResult.Results[0].TrackedSteps[ConfigurationBindingGenerator.GenSpecTrackingName], + GenerationSpec = _genSpec + }; + } + + private async Task UpdateCompilationWithSource(string? source = null) + { + if (_compilation is not null && source is not null) + { + SyntaxTree newTree = CSharpSyntaxTree.ParseText(source, _parseOptions); + _compilation = _compilation.ReplaceSyntaxTree(_compilation.SyntaxTrees.First(), newTree); + } + else if (_compilation is null) + { + Assert.True(source is not null, "Generator test requires input source."); + using AdhocWorkspace workspace = RoslynTestUtils.CreateTestWorkspace(); + + Project project = RoslynTestUtils.CreateTestProject(workspace, _assemblyReferences, langVersion: _langVersion) + .WithCompilationOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary).WithNullableContextOptions(NullableContextOptions.Annotations)) + .WithParseOptions(_parseOptions) + .WithDocuments(new string[] { source }); + Assert.True(project.Solution.Workspace.TryApplyChanges(project.Solution)); + + _compilation = (await project.GetCompilationAsync(CancellationToken.None).ConfigureAwait(false))!; + } + } + } + } + + internal struct ConfigBindingGenRunResult + { + public required Compilation OutputCompilation { get; init; } + + public required GeneratedSourceResult? GeneratedSource { get; init; } + + /// + /// Diagnostics produced by the generator alone. Doesn't include any from other build participants. + /// + public required ImmutableArray Diagnostics { get; init; } + + public required ImmutableArray TrackedSteps { get; init; } + + public required SourceGenerationSpec? GenerationSpec { get; init; } + } + + internal enum ExpectedDiagnostics + { + None, + FromGeneratorOnly, + } + + internal static class ConfigBindingGenTestDriverExtensions + { + public static void ValidateIncrementalResult(this ConfigBindingGenRunResult result, + IncrementalStepRunReason inputReason, + IncrementalStepRunReason outputReason) + { + Assert.Collection(result.TrackedSteps, step => + { + Assert.Collection(step.Inputs, source => Assert.Equal(inputReason, source.Source.Outputs[source.OutputIndex].Reason)); + Assert.Collection(step.Outputs, output => Assert.Equal(outputReason, output.Reason)); + }); + } + + public static void ValidateDiagnostics(this ConfigBindingGenRunResult result, ExpectedDiagnostics expectedDiags) + { + ImmutableArray outputDiagnostics = result.OutputCompilation.GetDiagnostics(); + + if (expectedDiags is ExpectedDiagnostics.None) + { + foreach (Diagnostic diagnostic in outputDiagnostics) + { + Assert.True( + IsPermitted(diagnostic), + $"Generator caused dagnostic in output compilation: {diagnostic.GetMessage(CultureInfo.InvariantCulture)}."); + } + } + else + { + Debug.Assert(expectedDiags is ExpectedDiagnostics.FromGeneratorOnly); + + Assert.NotEmpty(result.Diagnostics); + Assert.False(outputDiagnostics.Any(diag => !IsPermitted(diag))); + } + + static bool IsPermitted(Diagnostic diagnostic) => diagnostic.Severity <= DiagnosticSeverity.Info; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBinderTests.Generator.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBinderTests.Generator.cs new file mode 100644 index 00000000000000..5b6d824e87dbe9 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBinderTests.Generator.cs @@ -0,0 +1,419 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Configuration; +using Xunit; + +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests +{ + public partial class ConfigurationBinderTests : ConfigurationBinderTestsBase + { + // These are regression tests for https://github.com/dotnet/runtime/issues/90851 + // Source Generator Interceptors rely on identifying an accurate invocation + // source location (line and character positions). These tests cover newline + // and whitespace scenarios to ensure the interceptors get wired up correctly. + + [Fact] + public void TestBindingInvocationsWithNewlines_GetMethodTypeArg() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}"); + + // Newline between the configuration instance and the binding invocation (with the dot on the first line) + GeolocationRecord record1 = (GeolocationRecord)configuration. + Get(typeof(GeolocationRecord), _ => { }); + + AssertRecordIsBound(record1, 1, 2); + + // Newline between the configuration instance and the binding invocation (with the dot on the second line) + GeolocationRecord record2 = (GeolocationRecord)configuration + .Get(typeof(GeolocationRecord), _ => { }); + + AssertRecordIsBound(record2, 1, 2); + + // Newlines between the instance, the invocation, and the arguments + GeolocationRecord record3 = (GeolocationRecord)configuration + .Get( + typeof(GeolocationRecord), + _ => { } + ); + + AssertRecordIsBound(record3, 1, 2); + + // Newlines before and after the instance (with the dot on the first line) + GeolocationRecord record4 = (GeolocationRecord) + configuration. + Get(typeof(GeolocationRecord), _ => { }); + + AssertRecordIsBound(record4, 1, 2); + + // Newlines before and after the instance (with the dot on the second line) + GeolocationRecord record5 = (GeolocationRecord) + configuration + .Get(typeof(GeolocationRecord), _ => { }); + + AssertRecordIsBound(record5, 1, 2); + + // Newlines in every place possible + GeolocationRecord + record6 + = + ( + GeolocationRecord + ) + configuration + . + Get + ( + typeof + ( + GeolocationRecord + ) + , + _ + => + { + } + ) + ; + + AssertRecordIsBound(record6, 1, 2); + } + + [Fact] + public void TestBindingInvocationsWithNewlines_GetMethodGeneric() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}"); + + // Newline between the invocation method name and the generic type argument + GeolocationRecord record1 = configuration.Get + (); + + AssertRecordIsBound(record1, 1, 2); + + // Newlines on either side of the generic type argument + GeolocationRecord record2 = configuration.Get< + GeolocationRecord + >(); + + AssertRecordIsBound(record2, 1, 2); + + // Newlines in every place possible + GeolocationRecord + record3 + = + configuration + . + Get + < + GeolocationRecord + > + () + ; + + AssertRecordIsBound(record3, 1, 2); + } + + [Fact] + public void TestBindingInvocationsWithNewlines_BindExtensionMethod() + { + // Newline between the configuration instance and the extension method invocation + GeolocationRecord record1 = new GeolocationRecord(); + TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}") + .Bind(record1); + + AssertRecordIsBound(record1, 1, 2); + + // Newlines between the method that returns the instance and the extension method invocation + GeolocationRecord record2 = new GeolocationRecord(); + TestHelpers + .GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}") + .Bind(record2); + + AssertRecordIsBound(record2, 1, 2); + + // Newlines within the argument to the method returning the configuration and around the extension method argument + GeolocationRecord record3 = new GeolocationRecord(); + TestHelpers + .GetConfigurationFromJsonString(@"{""Longitude"":1, + ""Latitude"":2} + ") + .Bind( + record3 + ); + + AssertRecordIsBound(record3, 1, 2); + + // Newlines in every place possible + GeolocationRecord record4 = new GeolocationRecord(); + TestHelpers + . + GetConfigurationFromJsonString + ( + @"{""Longitude"":1, ""Latitude"":2}" + ) + . + Bind + ( + record4 + ) + ; + + AssertRecordIsBound(record4, 1, 2); + } + + [Fact] + public void TestBindingInvocationsWithNewlines_BindStaticMethod() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}"); + + // Newline between the class and the static method invocation (with the dot on the first line) + GeolocationRecord record1 = new GeolocationRecord(); + ConfigurationBinder. + Bind(configuration, record1); + + // Newline between the class and the static method invocation (with the dot on the second line) + GeolocationRecord record2 = new GeolocationRecord(); + ConfigurationBinder + .Bind(configuration, record2); + + AssertRecordIsBound(record2, 1, 2); + + // Newline before the arguments + GeolocationRecord record3 = new GeolocationRecord(); + ConfigurationBinder.Bind( + configuration, record3); + + AssertRecordIsBound(record3, 1, 2); + + // Newlines in every place possible + GeolocationRecord record4 = new GeolocationRecord(); + ConfigurationBinder + . + Bind + ( + configuration + , + record4 + ) + ; + + AssertRecordIsBound(record4, 1, 2); + } + + [Fact] + public void TestBindingInvocationsWithNewlines_GetValueMethod() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Longitude"":1,""Latitude"":2}"); + + // Newline between the configuration instance and the binding invocation (with the dot on the first line) + int lat1 = configuration. + GetValue("Latitude"); + + Assert.Equal(2, lat1); + + // Newline between the configuration instance and the binding invocation (with the dot on the second line) + int lat2 = configuration + .GetValue("Latitude"); + + Assert.Equal(2, lat2); + + // Newlines in every place possible + long + lat3 + = + configuration + . + GetValue + < + int + > + ( + "Latitude" + ) + ; + Assert.Equal(2, lat3); + + // Newlines and pragmas wrapped around the generic type argument + long lat4 = configuration.GetValue< +#if DEBUG + int +#else + long +#endif + >("Latitude"); + + Assert.Equal(2, lat4); + } + + private static void AssertRecordIsBound(GeolocationRecord record, int longitude, int latitude) + { + Assert.Equal((longitude, latitude), (record.Longitude, record.Latitude)); + } + + // These are regression tests for https://github.com/dotnet/runtime/issues/90976 + // Ensure that every emitted identifier name is unique, otherwise name clashes + // will occur and cause compilation to fail. + + [Fact] + public void NameClashTests_NamingPatternsThatCouldCauseClashes() + { + // Potential class between type with closed generic & non generic type. + // Both types start with same substring. The generic arg type's name is + // the same as the suffix of the non generic type's name. + // Manifested in https://github.com/dotnet/runtime/issues/90976. + + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Value"":1}"); + + var c1 = new Cint(); + var c2 = new C(); + + configuration.Bind(c1); + configuration.Bind(c2); + Assert.Equal(1, c1.Value); + Assert.Equal(1, c2.Value); + } + + internal class C + { + public int Value { get; set; } + } + + internal class Cint + { + public int Value { get; set; } + } + + [Fact] + public void NameClashTests_SameTypeName() + { + // Both types have the same name, but one is a nested type. + + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Value"":1}"); + + var c1 = new ClassWithThisIdentifier(); + var c2 = new ClassWithThisIdentifier_Wrapper.ClassWithThisIdentifier(); + + configuration.Bind(c1); + configuration.Bind(c2); + Assert.Equal(1, c1.Value); + Assert.Equal(1, c2.Value); + } + + internal class ClassWithThisIdentifier + { + public int Value { get; set; } + } + + internal class ClassWithThisIdentifier_Wrapper + { + internal class ClassWithThisIdentifier + { + public int Value { get; set; } + } + } + + /// + /// These are regression tests for https://github.com/dotnet/runtime/issues/90909. + /// Ensure that we don't emit root interceptors to handle types/members that + /// are inaccessible to the generated helpers. Tests for inaccessible transitive members + /// are covered in the shared (reflection/src-gen) , + /// e.g. . + /// + /// + /// In these cases, binding calls will fallback to reflection, as with all cases where + /// we can't correctly resolve the type, such as generic call patterns and boxed objects. + /// + [Fact] + public void MemberAccessibility_InaccessibleNestedTypeAsRootConfig() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{""Value"":1}"); + + // Ensure no compilation errors; types are skipped. + +#pragma warning disable SYSLIB1104 // Binding logic was not generated for a binder call. + var c1 = new InaccessibleClass_1(); + configuration.Bind(c1); + var c2 = configuration.Get(); + var c3 = configuration.Get(); + + // Generic collections. + + configuration = TestHelpers + .GetConfigurationFromJsonString(@"{""Array"": [{""Value"":1}]}") + .GetSection("Array"); + var c4 = configuration.Get(); + var c5 = configuration.Get>(); + + // Generic types. + + Action? configureOptions = options => options.BindNonPublicProperties = true; + string GetNestedObjectPayload(string propName) => $$""" + { + "{{propName}}": { + "Value": 1 + } + } + """; + + configuration = TestHelpers.GetConfigurationFromJsonString(GetNestedObjectPayload("item1")); + var c6 = configuration.Get>(configureOptions); + + configuration = TestHelpers.GetConfigurationFromJsonString(GetNestedObjectPayload("protectedMember")); + var c7 = configuration.Get>(configureOptions); + var c8 = configuration.Get>(configureOptions); + + configuration = TestHelpers.GetConfigurationFromJsonString(GetNestedObjectPayload("publicMember")); + var c9 = configuration.Get>(configureOptions); + var c10 = configuration.Get>(configureOptions); +#pragma warning disable SYSLIB1104 + + // Reflection fallback. + + Assert.Equal(1, c1.Value); + Assert.Equal(1, c2.Value); + Assert.Equal(1, c3.Value); + + Assert.Equal(1, c4[0].Value); + Assert.Equal(1, c5[0].Value); + Assert.Equal(1, c6["item1"].Value); + + Assert.Equal(1, c7.GetProtectedMember.Value); + Assert.Equal(1, c8.GetProtectedMember.Value); + Assert.Equal(1, c9.PublicMember.Value); + Assert.Equal(1, c10.PublicMember.Value); + } + + private class InaccessibleClass_1() + { + public int Value { get; set; } + } + + protected record InaccessibleClass_2(int Value); + + protected internal class InaccessibleClass_3 + { + public InaccessibleClass_3(int value) => Value = value; + + public int Value { get; } + } + + internal class AccessibleGenericClass + { + protected T ProtectedMember { get; set; } + + public T GetProtectedMember => ProtectedMember; + } + + private class InaccessibleGenericClass + { + public T PublicMember { get; set; } + } + + public class AccessibleClass() + { + public int Value { get; set; } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.Baselines.Options.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.Options.cs similarity index 55% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.Baselines.Options.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.Options.cs index 8807dfb0962206..4480ab4066882c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.Baselines.Options.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.Options.cs @@ -4,10 +4,11 @@ using System.Threading.Tasks; using Xunit; -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests { public partial class ConfigurationBindingGeneratorTests { + #region IServiceCollection extensions. private string GetConfigureSource(string paramList) => $$""" using System.Collections.Generic; using Microsoft.Extensions.Configuration; @@ -19,7 +20,7 @@ public static void Main() { ConfigurationBuilder configurationBuilder = new(); IConfiguration config = configurationBuilder.Build(); - IConfigurationSection section = config.GetSection(""MySection""); + IConfigurationSection section = config.GetSection("MySection"); ServiceCollection services = new(); services.Configure({{paramList}}); @@ -40,7 +41,9 @@ public class MyClass2 } } """; + #endregion IServiceCollection extensions. + #region OptionsBuilder extensions. [Fact] public async Task Configure_T() => await VerifyAgainstBaselineUsingFile("Configure_T.generated.txt", GetConfigureSource("section"), extType: ExtensionClassType.ServiceCollection); @@ -57,6 +60,81 @@ public async Task Configure_T_BinderOptions() => public async Task Configure_T_name_BinderOptions() => await VerifyAgainstBaselineUsingFile("Configure_T_name_BinderOptions.generated.txt", GetConfigureSource(@""""", section, _ => { }"), extType: ExtensionClassType.ServiceCollection); + [Theory] + [InlineData("OptionsConfigurationServiceCollectionExtensions.Configure(config: section, services: services);")] + [InlineData("""OptionsConfigurationServiceCollectionExtensions.Configure(name: "", config: section, services: services);""")] + [InlineData("OptionsConfigurationServiceCollectionExtensions.Configure(configureBinder: _ => { }, config: section, services: services);")] + [InlineData("""OptionsConfigurationServiceCollectionExtensions.Configure(configureBinder: _ => { }, config: section, name: "", services: services);""")] + [InlineData("""OptionsConfigurationServiceCollectionExtensions.Configure(name: "", services: services, configureBinder: _ => { }, config: section);""")] + public async Task Configure_T_NamedParameters_OutOfOrder(string row) + { + string source = $$""" + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + using Microsoft.Extensions.DependencyInjection; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfiguration config = configurationBuilder.Build(); + IConfigurationSection section = config.GetSection("MySection"); + ServiceCollection services = new(); + + {{row}} + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + } + } + """; + + await VerifyThatSourceIsGenerated(source); + } + + [Theory] + [InlineData("OptionsBuilderConfigurationExtensions.Bind(config: config, optionsBuilder: optionsBuilder);")] + [InlineData("OptionsBuilderConfigurationExtensions.Bind(configureBinder: _ => { }, config: config, optionsBuilder: optionsBuilder);")] + [InlineData("OptionsBuilderConfigurationExtensions.Bind(config: config, configureBinder: _ => { }, optionsBuilder: optionsBuilder);")] + public async Task Bind_T_NamedParameters_OutOfOrder(string row) + { + string source = $$""" + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.Extensions.Options; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfiguration config = configurationBuilder.Build(); + var services = new ServiceCollection(); + OptionsBuilder optionsBuilder = new(services, ""); + + {{row}} + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + } + } + """; + + await VerifyThatSourceIsGenerated(source); + } + private string GetBindSource(string? configureActions = null) => $$""" using System.Collections.Generic; using Microsoft.Extensions.Configuration; @@ -126,5 +204,6 @@ public class MyClass await VerifyAgainstBaselineUsingFile("BindConfiguration.generated.txt", GetSource(), extType: ExtensionClassType.OptionsBuilder); await VerifyAgainstBaselineUsingFile("BindConfiguration.generated.txt", GetSource(@", _ => { }"), extType: ExtensionClassType.OptionsBuilder); } + #endregion OptionsBuilder extensions. } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.Baselines.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.cs similarity index 67% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.Baselines.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.cs index aba2a9f6184f2c..e05a7737137128 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.Baselines.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.cs @@ -2,51 +2,125 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Globalization; +using System.Collections.Immutable; using System.Linq; using System.Threading.Tasks; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Xunit; -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests { public partial class ConfigurationBindingGeneratorTests { - private const string BindCallSampleCode = @" - using System.Collections.Generic; - using Microsoft.Extensions.Configuration; + [Fact] + public async Task Bind() => + await VerifyAgainstBaselineUsingFile("Bind.generated.txt", BindCallSampleCode, extType: ExtensionClassType.ConfigurationBinder); - public class Program + [Theory] + [InlineData("ConfigurationBinder.Bind(instance: configObj, configuration: config);")] + [InlineData("""ConfigurationBinder.Bind(key: "", instance: configObj, configuration: config);""")] + [InlineData("""ConfigurationBinder.Bind(instance: configObj, key: "", configuration: config);""")] + [InlineData("ConfigurationBinder.Bind(configureOptions: _ => { }, configuration: config, instance: configObj);")] + [InlineData("ConfigurationBinder.Bind(configuration: config, configureOptions: _ => { }, instance: configObj);")] + public async Task Bind_NamedParameters_OutOfOrder(string row) { - public static void Main() - { - ConfigurationBuilder configurationBuilder = new(); - IConfigurationRoot config = configurationBuilder.Build(); + string source = $$""" + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; - MyClass configObj = new(); - config.Bind(configObj); - config.Bind(configObj, options => { }); - config.Bind(""key"", configObj); - } + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); - public class MyClass - { - public string MyString { get; set; } - public int MyInt { get; set; } - public List MyList { get; set; } - public Dictionary MyDictionary { get; set; } - public Dictionary MyComplexDictionary { get; set; } - } + MyClass configObj = new(); + {{row}} + } - public class MyClass2 { } - }"; + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + } + } + """; + + await VerifyThatSourceIsGenerated(source); + } + + [Theory] + [InlineData("var obj = ConfigurationBinder.Get(type: typeof(MyClass), configuration: config);")] + [InlineData("var obj = ConfigurationBinder.Get(configureOptions: _ => { }, configuration: config);")] + [InlineData("var obj = ConfigurationBinder.Get(configureOptions: _ => { }, type: typeof(MyClass), configuration: config);")] + [InlineData("var obj = ConfigurationBinder.Get(type: typeof(MyClass), configureOptions: _ => { }, configuration: config);")] + public async Task Get_TypeOf_NamedParametersOutOfOrder(string row) + { + string source = $$""" + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + {{row}} + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + } + } + """; + + await VerifyThatSourceIsGenerated(source); + } [Theory] - [InlineData(LanguageVersion.Preview)] - [InlineData(LanguageVersion.CSharp11)] - public async Task Bind(LanguageVersion langVersion) => - await VerifyAgainstBaselineUsingFile("Bind.generated.txt", BindCallSampleCode, langVersion, extType: ExtensionClassType.ConfigurationBinder); + [InlineData("""var str = ConfigurationBinder.GetValue(key: "key", configuration: config, type: typeof(string));""")] + [InlineData("""var str = ConfigurationBinder.GetValue(key: "key", configuration: config);""")] + [InlineData("""var str = ConfigurationBinder.GetValue(key: "key", defaultValue: "default", configuration: config);""")] + [InlineData("""var str = ConfigurationBinder.GetValue(configuration: config, key: "key", defaultValue: "default");""")] + [InlineData("""var str = ConfigurationBinder.GetValue(defaultValue: "default", key: "key", configuration: config, type: typeof(string));""")] + [InlineData("""var str = ConfigurationBinder.GetValue(defaultValue: "default", type: typeof(string), key: "key", configuration: config);""")] + public async Task GetValue_NamedParametersOutOfOrder(string row) + { + string source = $$""" + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + {{row}} + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + } + } + """; + + await VerifyThatSourceIsGenerated(source); + } [Fact] public async Task Bind_Instance() @@ -68,7 +142,7 @@ public static void Main() public class MyClass { - public string MyString { get; set; } + public string? MyString { get; set; } public int MyInt { get; set; } public List MyList { get; set; } public Dictionary MyDictionary { get; set; } @@ -150,6 +224,49 @@ public class MyClass2 { } await VerifyAgainstBaselineUsingFile("Bind_Key_Instance.generated.txt", source, extType: ExtensionClassType.ConfigurationBinder); } + [Fact] + public async Task Bind_CanParseTargetConfigType_FromMethodParam() + { + string source = """ + using System; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfiguration config = configurationBuilder.Build(); + + BindOptions(config, new MyClass0()); + BindOptions(config, new MyClass1(), _ => { }); + BindOptions(config, "", new MyClass2()); + } + + private static void BindOptions(IConfiguration config, MyClass0 instance) + { + config.Bind(instance); + } + + private static void BindOptions(IConfiguration config, MyClass1 instance, Action? configureOptions) + { + config.Bind(instance, configureOptions); + } + + private static void BindOptions(IConfiguration config, string path, MyClass2 instance) + { + config.Bind(path, instance); + } + + public class MyClass0 { } + public class MyClass1 { } + public class MyClass2 { } + } + """; + + await VerifyAgainstBaselineUsingFile("Bind_ParseTypeFromMethodParam.generated.txt", source, extType: ExtensionClassType.ConfigurationBinder); + } + [Fact] public async Task Get() { @@ -165,9 +282,9 @@ public static void Main() IConfigurationRoot config = configurationBuilder.Build(); MyClass configObj = config.Get(); - configObj = config.Get(typeof(MyClass2)); + MyClass2 configObj2 = (MyClass2)config.Get(typeof(MyClass2)); configObj = config.Get(binderOptions => { }); - configObj = config.Get(typeof(MyClass2), binderOptions => { }); + configObj2 = (MyClass2)config.Get(typeof(MyClass2), binderOptions => { }); } public class MyClass @@ -198,6 +315,30 @@ public class MyClass4 await VerifyAgainstBaselineUsingFile("Get.generated.txt", source, extType: ExtensionClassType.ConfigurationBinder); } + [Fact] + public async Task Get_PrimitivesOnly() + { + string source = """ + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + config.Get(); + config.Get(typeof(string)); + config.Get(binderOptions => { }); + config.Get(typeof(double), binderOptions => { }); + } + } + """; + + await VerifyAgainstBaselineUsingFile("Get_PrimitivesOnly.generated.txt", source, extType: ExtensionClassType.ConfigurationBinder); + } + [Fact] public async Task Get_T() { @@ -304,7 +445,7 @@ public static void Main() ConfigurationBuilder configurationBuilder = new(); IConfigurationRoot config = configurationBuilder.Build(); - MyClass configObj = config.Get(typeof(MyClass2)); + MyClass2 configObj = (MyClass2)config.Get(typeof(MyClass2)); } public class MyClass @@ -538,9 +679,9 @@ public class MyClass2 }" ; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Empty(d); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Empty(result.Diagnostics); } [Fact] @@ -592,9 +733,9 @@ public class MyClass public UInt128 Prop12 { get; set; } public DateOnly Prop18 { get; set; } public TimeOnly Prop22 { get; set; } - public byte[] Prop22 { get; set; } - public int Prop23 { get; set; } - public DateTime Prop24 { get; set; } + public byte[] Prop28 { get; set; } + public int Prop29 { get; set; } + public DateTime Prop30 { get; set; } } } """; @@ -615,11 +756,12 @@ public static void Main() { ConfigurationBuilder configurationBuilder = new(); IConfiguration config = configurationBuilder.Build(); - IConfigurationSection section = config.GetSection(""MySection""); + IConfigurationSection section = config.GetSection("MySection"); section.Get(); } + // Diagnostic warning because we don't know how to instantiate two properties on this type. public class MyClassWithCustomCollections { public CustomDictionary CustomDictionary { get; set; } @@ -627,6 +769,7 @@ public class MyClassWithCustomCollections public ICustomDictionary ICustomDictionary { get; set; } public ICustomSet ICustomCollection { get; set; } public IReadOnlyList IReadOnlyList { get; set; } + // Diagnostic warning because we don't know how to instantiate the property type. public IReadOnlyDictionary UnsupportedIReadOnlyDictionaryUnsupported { get; set; } public IReadOnlyDictionary IReadOnlyDictionary { get; set; } } @@ -639,22 +782,73 @@ public class CustomList : List { } + // Diagnostic warning because we don't know how to instantiate this type. public interface ICustomDictionary : IDictionary { } + // Diagnostic warning because we don't know how to instantiate this type. public interface ICustomSet : ISet { } } """; - await VerifyAgainstBaselineUsingFile("Collections.generated.txt", source, assessDiagnostics: (d) => - { - Console.WriteLine((d.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count() , d.Where(diag => diag.Id == Diagnostics.PropertyNotSupported.Id).Count())); - Assert.Equal(3, d.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); - Assert.Equal(6, d.Where(diag => diag.Id == Diagnostics.PropertyNotSupported.Id).Count()); - }); + ConfigBindingGenRunResult result = await VerifyAgainstBaselineUsingFile( + "Collections.generated.txt", + source, + expectedDiags: ExpectedDiagnostics.FromGeneratorOnly); + + ImmutableArray diagnostics = result.Diagnostics; + Assert.Equal(3, diagnostics.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); + Assert.Equal(3, diagnostics.Where(diag => diag.Id == Diagnostics.PropertyNotSupported.Id).Count()); + } + + [Fact] + public async Task MinimalGenerationIfNoBindableMembers() + { + string source = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfiguration configuration = configurationBuilder.Build(); + + TypeWithNoMembers obj = new(); + configuration.Bind(obj); + + TypeWithNoMembers_Wrapper obj2 = new(); + configuration.Bind(obj2); + + List obj3 = new(); + configuration.Bind(obj3); + } + } + + public class TypeWithNoMembers + { + } + + public class TypeWithNoMembers_Wrapper + { + public TypeWithNoMembers Member { get; set; } + } + + public abstract class AbstractType_CannotInit + { + } + """; + + ConfigBindingGenRunResult result = await VerifyAgainstBaselineUsingFile( + "EmptyConfigType.generated.txt", + source, + expectedDiags: ExpectedDiagnostics.FromGeneratorOnly); + + Assert.Equal(2, result.Diagnostics.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Helpers.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Helpers.cs new file mode 100644 index 00000000000000..cbbd34e7fc41da --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Helpers.cs @@ -0,0 +1,170 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Configuration.Binder.SourceGeneration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using SourceGenerators.Tests; +using Xunit; + +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests +{ + public partial class ConfigurationBindingGeneratorTests + { + /// + /// Keep in sync with variants, e.g. . + /// + private const string BindCallSampleCode = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + } + + public class MyClass2 { } + } + """; + + private static class Diagnostics + { + public static (string Id, string Title) TypeNotSupported = ("SYSLIB1100", "Did not generate binding logic for a type"); + public static (string Id, string Title) PropertyNotSupported = ("SYSLIB1101", "Did not generate binding logic for a property on a type"); + public static (string Id, string Title) ValueTypesInvalidForBind = ("SYSLIB1103", "Value types are invalid inputs to configuration 'Bind' methods"); + public static (string Id, string Title) CouldNotDetermineTypeInfo = ("SYSLIB1104", "The target type for a binder call could not be determined"); + } + + private static readonly Assembly[] s_compilationAssemblyRefs = new[] { + typeof(BitArray).Assembly, + typeof(ConfigurationBinder).Assembly, + typeof(ConfigurationBuilder).Assembly, + typeof(CultureInfo).Assembly, + typeof(Dictionary<,>).Assembly, + typeof(Enumerable).Assembly, + typeof(IConfiguration).Assembly, + typeof(IServiceCollection).Assembly, + typeof(IServiceProvider).Assembly, + typeof(IDictionary).Assembly, + typeof(OptionsBuilder<>).Assembly, + typeof(OptionsConfigurationServiceCollectionExtensions).Assembly, + typeof(Uri).Assembly, + }; + + private enum ExtensionClassType + { + None, + ConfigurationBinder, + OptionsBuilder, + ServiceCollection, + } + + private static async Task VerifyThatSourceIsGenerated(string testSourceCode) + { + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(testSourceCode); + GeneratedSourceResult? source = result.GeneratedSource; + + Assert.NotNull(source); + Assert.Empty(result.Diagnostics); + Assert.True(source.Value.SourceText.Lines.Count > 10); + } + + private static async Task VerifyAgainstBaselineUsingFile( + string filename, + string testSourceCode, + ExtensionClassType extType = ExtensionClassType.None, + ExpectedDiagnostics expectedDiags = ExpectedDiagnostics.None) + { + string path = extType is ExtensionClassType.None + ? Path.Combine("Baselines", filename) + : Path.Combine("Baselines", extType.ToString(), filename); + string baseline = LineEndingsHelper.Normalize(await File.ReadAllTextAsync(path).ConfigureAwait(false)); + string[] expectedLines = baseline.Replace("%VERSION%", typeof(ConfigurationBindingGenerator).Assembly.GetName().Version?.ToString()) + .Split(Environment.NewLine); + + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(testSourceCode); + result.ValidateDiagnostics(expectedDiags); + + SourceText resultSourceText = result.GeneratedSource.Value.SourceText; + bool resultEqualsBaseline = RoslynTestUtils.CompareLines(expectedLines, resultSourceText, out string errorMessage); + +#if UPDATE_BASELINES + if (!resultEqualsBaseline) + { + const string envVarName = "RepoRootDir"; + string errMessage = $"To update baselines, specify a '{envVarName}' environment variable. See this assembly's README.md doc for more details."; + + string? repoRootDir = Environment.GetEnvironmentVariable(envVarName); + Assert.True(repoRootDir is not null, errMessage); + + IEnumerable lines = resultSourceText.Lines.Select(l => l.ToString()); + string source = string.Join(Environment.NewLine, lines).TrimEnd(Environment.NewLine.ToCharArray()) + Environment.NewLine; + path = Path.Combine($"{repoRootDir}\\src\\libraries\\Microsoft.Extensions.Configuration.Binder\\tests\\SourceGenerationTests\\", path); + + await File.WriteAllTextAsync(path, source).ConfigureAwait(false); + resultEqualsBaseline = true; + } +#endif + + Assert.True(resultEqualsBaseline, errorMessage); + + return result; + } + + private static async Task RunGeneratorAndUpdateCompilation( + string source, + LanguageVersion langVersion = LanguageVersion.CSharp12, + IEnumerable? assemblyReferences = null) + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(langVersion, assemblyReferences); + return await driver.RunGeneratorAndUpdateCompilation(source); + } + + private static List GetAssemblyRefsWithAdditional(params Type[] additional) + { + List assemblies = new(s_compilationAssemblyRefs); + assemblies.AddRange(additional.Select(t => t.Assembly)); + return assemblies; + } + + private static HashSet GetFilteredAssemblyRefs(IEnumerable exclusions) + { + HashSet assemblies = new(s_compilationAssemblyRefs); + foreach (Type exclusion in exclusions) + { + assemblies.Remove(exclusion.Assembly); + } + return assemblies; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Incremental.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Incremental.cs new file mode 100644 index 00000000000000..aff9a0c20364ca --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Incremental.cs @@ -0,0 +1,362 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.Extensions.Configuration.Binder.SourceGeneration; +using SourceGenerators.Tests; +using Xunit; + +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests +{ + public partial class ConfigurationBindingGeneratorTests : ConfigurationBinderTestsBase + { + [ActiveIssue("https://github.com/dotnet/runtime/issues/52062", TestPlatforms.Browser)] + public sealed class IncrementalTests + { + [Fact] + public async Task CompilingTheSameSourceResultsInEqualModels() + { + SourceGenerationSpec spec1 = (await new ConfigBindingGenTestDriver().RunGeneratorAndUpdateCompilation(BindCallSampleCode)).GenerationSpec; + SourceGenerationSpec spec2 = (await new ConfigBindingGenTestDriver().RunGeneratorAndUpdateCompilation(BindCallSampleCode)).GenerationSpec; + + Assert.NotSame(spec1, spec2); + GeneratorTestHelpers.AssertStructurallyEqual(spec1, spec2); + + Assert.Equal(spec1, spec2); + Assert.Equal(spec1.GetHashCode(), spec2.GetHashCode()); + } + + [Fact] + public async Task RunWithNoDiags_Then_NoEdit() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Unchanged); + } + + [Fact] + public async Task RunWithNoDiags_Then_ChangeInputOrder() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + // We expect different spec because diag locations are different. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_ReorderedInvocations); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + + // We expect different spec because members are reordered. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_ReorderedConfigTypeMembers); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithNoDiags_Then_EditWithNoDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithDifferentConfigTypeName); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithNoDiags_Then_EditWithDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithDiags_Then_NoEdit() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Unchanged); + } + + [Fact] + public async Task RunWithDiags_Then_ChangeInputOrder() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + // We expect different spec because diag locations are different. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedInvocations); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + + // We expect different spec because members are reordered. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedConfigTypeMembers); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithDiags_Then_EditWithNoDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithDiags_Then_EditWithDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember_WithDiffMemberName); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + } + + #region Incremental test sources. + /// + /// Keep in sync with . + /// + private const string BindCallSampleCodeVariant_ReorderedInvocations = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + config.Bind(configObj); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + } + + public class MyClass2 { } + } + """; + + /// + /// Keep in sync with . + /// + private const string BindCallSampleCodeVariant_ReorderedConfigTypeMembers = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + config.Bind(configObj); + } + + public class MyClass + { + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public string MyString { get; set; } + public int MyInt { get; set; } + public Dictionary MyComplexDictionary { get; set; } + } + + public class MyClass2 { } + } + """; + + /// + /// Keep in sync with . + /// + private const string BindCallSampleCodeVariant_WithDifferentConfigTypeName = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass0 configObj = new(); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + config.Bind(configObj); + } + + public class MyClass0 + { + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public string MyString { get; set; } + public int MyInt { get; set; } + public Dictionary MyComplexDictionary { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public int[,] UnsupportedMember { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedInvocations = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind("key", configObj); + config.Bind(configObj); + config.Bind(configObj, options => { }); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public int[,] UnsupportedMember { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedConfigTypeMembers = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind("key", configObj); + config.Bind(configObj); + config.Bind(configObj, options => { }); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public int[,] UnsupportedMember { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public List MyList { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember_WithDiffMemberName = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public int[,] UnsupportedMember_DiffMemberName { get; set; } + } + + public class MyClass2 { } + } + """; + #endregion Incremental test sources. + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.cs similarity index 53% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.cs index 5bc5145739daac..d93607d3763996 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigurationBindingGeneratorTests.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.cs @@ -6,61 +6,33 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Globalization; -using System.IO; using System.Linq; -using System.Reflection; using System.Text; using System.Text.Json; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; -using Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests; -using SourceGenerators.Tests; using Xunit; -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests { [ActiveIssue("https://github.com/dotnet/runtime/issues/52062", TestPlatforms.Browser)] public partial class ConfigurationBindingGeneratorTests : ConfigurationBinderTestsBase { - private static class Diagnostics + [Theory] + [InlineData(LanguageVersion.CSharp11)] + [InlineData(LanguageVersion.CSharp10)] + public async Task LangVersionMustBeCharp12OrHigher(LanguageVersion langVersion) { - public static (string Id, string Title) TypeNotSupported = ("SYSLIB1100", "Did not generate binding logic for a type"); - public static (string Id, string Title) PropertyNotSupported = ("SYSLIB1101", "Did not generate binding logic for a property on a type"); - public static (string Id, string Title) ValueTypesInvalidForBind = ("SYSLIB1103", "Value types are invalid inputs to configuration 'Bind' methods"); - public static (string Id, string Title) CouldNotDetermineTypeInfo = ("SYSLIB1104", "The target type for a binder call could not be determined"); - } - - private static readonly Assembly[] s_compilationAssemblyRefs = new[] { - typeof(ConfigurationBinder).Assembly, - typeof(CultureInfo).Assembly, - typeof(IConfiguration).Assembly, - typeof(IServiceCollection).Assembly, - typeof(IDictionary).Assembly, - typeof(OptionsBuilder<>).Assembly, - typeof(OptionsConfigurationServiceCollectionExtensions).Assembly, - typeof(Uri).Assembly, - }; - - private enum ExtensionClassType - { - None, - ConfigurationBinder, - OptionsBuilder, - ServiceCollection, - } + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(BindCallSampleCode, langVersion: langVersion); + Assert.False(result.GeneratedSource.HasValue); - [Fact] - public async Task LangVersionMustBeCharp11OrHigher() - { - var (d, r) = await RunGenerator(BindCallSampleCode, LanguageVersion.CSharp10); - Assert.Empty(r); - - Diagnostic diagnostic = Assert.Single(d); + Diagnostic diagnostic = Assert.Single(result.Diagnostics); Assert.True(diagnostic.Id == "SYSLIB1102"); - Assert.Contains("C# 11", diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); + Assert.Contains("C# 12", diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); Assert.Equal(DiagnosticSeverity.Error, diagnostic.Severity); } @@ -103,11 +75,11 @@ public record struct MyRecordStruct { } } """; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Equal(7, d.Count()); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Equal(7, result.Diagnostics.Count()); - foreach (Diagnostic diagnostic in d) + foreach (Diagnostic diagnostic in result.Diagnostics) { Assert.True(diagnostic.Id == Diagnostics.ValueTypesInvalidForBind.Id); Assert.Contains(Diagnostics.ValueTypesInvalidForBind.Title, diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); @@ -139,11 +111,11 @@ public record struct MyRecordStruct { } } """; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Equal(2, d.Count()); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Equal(2, result.Diagnostics.Count()); - foreach (Diagnostic diagnostic in d) + foreach (Diagnostic diagnostic in result.Diagnostics) { Assert.True(diagnostic.Id == Diagnostics.CouldNotDetermineTypeInfo.Id); Assert.Contains(Diagnostics.CouldNotDetermineTypeInfo.Title, diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); @@ -191,11 +163,11 @@ public class MyClass { } } """; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Equal(6, d.Count()); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Equal(6, result.Diagnostics.Count()); - foreach (Diagnostic diagnostic in d) + foreach (Diagnostic diagnostic in result.Diagnostics) { Assert.True(diagnostic.Id == Diagnostics.CouldNotDetermineTypeInfo.Id); Assert.Contains(Diagnostics.CouldNotDetermineTypeInfo.Title, diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); @@ -205,60 +177,7 @@ public class MyClass { } } [Fact] - public async Task BindCanParseMethodParam() - { - string source = """ - using System; - using Microsoft.AspNetCore.Builder; - using Microsoft.Extensions.Configuration; - using Microsoft.Extensions.DependencyInjection; - - public class Program - { - public static void Main() - { - ConfigurationBuilder configurationBuilder = new(); - IConfiguration config = configurationBuilder.Build(); - - BindOptions(config, new MyClass0()); - BindOptions(config, new MyClass1(), (_) => { }); - BindOptions(config, "", new MyClass2()); - } - - private void BindOptions(IConfiguration config, MyClass0 instance) - { - config.Bind(instance); - } - - private void BindOptions(IConfiguration config, MyClass1 instance, Action? configureOptions) - { - config.Bind(instance, configureOptions); - } - - private void BindOptions(IConfiguration config, string path, MyClass2 instance) - { - config.Bind(path, instance); - } - - public class MyClass0 { } - public class MyClass1 { } - public class MyClass2 { } - } - """; - - var (d, r) = await RunGenerator(source); - Assert.Single(r); - - string generatedSource = string.Join('\n', r[0].SourceText.Lines.Select(x => x.ToString())); - Assert.Contains($"public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::Program.MyClass0 obj) => {{ }};", generatedSource); - Assert.Contains($"public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, global::Program.MyClass1 obj, global::System.Action? configureOptions) => {{ }};", generatedSource); - Assert.Contains($"public static void Bind(this global::Microsoft.Extensions.Configuration.IConfiguration configuration, string key, global::Program.MyClass2 obj) => {{ }};", generatedSource); - - Assert.Empty(d); - } - - [Fact] - public async Task SucceedForMinimalInput() + public async Task SucceedWhenGivenMinimumRequiredReferences() { string source = """ using System; @@ -299,22 +218,15 @@ public class MyClass0 { } async Task Test(bool expectOutput) { - var (d, r) = await RunGenerator(source, references: GetFilteredAssemblyRefs(exclusions)); - - Assert.Empty(d); - - if (expectOutput) - { - Assert.Single(r); - } - else - { - Assert.Empty(r); - } + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source, assemblyReferences: GetFilteredAssemblyRefs(exclusions)); + Assert.Empty(result.Diagnostics); + Action ValidateSourceResult = expectOutput ? () => Assert.NotNull(result.GeneratedSource) : () => Assert.False(result.GeneratedSource.HasValue); + ValidateSourceResult(); } } [Fact] + [ActiveIssue("Work out why we aren't getting all the expected diagnostics.")] public async Task IssueDiagnosticsForAllOffendingCallsites() { string source = """ @@ -363,61 +275,10 @@ public class AnotherGraphWithUnsupportedMembers } """; - var (d, r) = await RunGenerator(source, references: GetAssemblyRefsWithAdditional(typeof(ImmutableArray<>), typeof(Encoding), typeof(JsonSerializer))); - Assert.Single(r); - Assert.Equal(12, d.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); - Assert.Equal(10, d.Where(diag => diag.Id == Diagnostics.PropertyNotSupported.Id).Count()); - } - - private static async Task VerifyAgainstBaselineUsingFile( - string filename, - string testSourceCode, - LanguageVersion languageVersion = LanguageVersion.Preview, - Action>? assessDiagnostics = null, - ExtensionClassType extType = ExtensionClassType.None) - { - string path = extType is ExtensionClassType.None - ? Path.Combine("Baselines", filename) - : Path.Combine("Baselines", extType.ToString(), filename); - string baseline = LineEndingsHelper.Normalize(await File.ReadAllTextAsync(path).ConfigureAwait(false)); - string[] expectedLines = baseline.Replace("%VERSION%", typeof(ConfigurationBindingGenerator).Assembly.GetName().Version?.ToString()) - .Split(Environment.NewLine); - - var (d, r) = await RunGenerator(testSourceCode, languageVersion); - bool success = RoslynTestUtils.CompareLines(expectedLines, r[0].SourceText, out string errorMessage); - -#if !SKIP_BASELINES - Assert.Single(r); - (assessDiagnostics ?? ((d) => Assert.Empty(d))).Invoke(d); - Assert.True(success, errorMessage); -#endif - } - - private static async Task<(ImmutableArray, ImmutableArray)> RunGenerator( - string testSourceCode, - LanguageVersion langVersion = LanguageVersion.CSharp11, - IEnumerable? references = null) => - await RoslynTestUtils.RunGenerator( - new ConfigurationBindingGenerator(), - references ?? s_compilationAssemblyRefs, - new[] { testSourceCode }, - langVersion: langVersion).ConfigureAwait(false); - - public static List GetAssemblyRefsWithAdditional(params Type[] additional) - { - List assemblies = new(s_compilationAssemblyRefs); - assemblies.AddRange(additional.Select(t => t.Assembly)); - return assemblies; - } - - public static HashSet GetFilteredAssemblyRefs(IEnumerable exclusions) - { - HashSet assemblies = new(s_compilationAssemblyRefs); - foreach (Type exclusion in exclusions) - { - assemblies.Remove(exclusion.Assembly); - } - return assemblies; + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source, assemblyReferences: GetAssemblyRefsWithAdditional(typeof(ImmutableArray<>), typeof(Encoding), typeof(JsonSerializer))); + Assert.NotNull(result.GeneratedSource); + Assert.True(result.Diagnostics.Any(diag => diag.Id == Diagnostics.TypeNotSupported.Id)); + Assert.True(result.Diagnostics.Any(diag => diag.Id == Diagnostics.PropertyNotSupported.Id)); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj index 2108bc2574ed2c..848d93b32a475a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj @@ -2,16 +2,20 @@ $(NetCoreAppCurrent);$(NetFrameworkMinimum) true - - SYSLIB1100,SYSLIB1101 - + + $(NoWarn);SYSLIB1100,SYSLIB1101 + + $(NoWarn);SYSLIB1103,SYSLIB1104 + $(Features);InterceptorsPreview + + $(InterceptorsPreviewNamespaces);Microsoft.Extensions.Configuration.Binder.SourceGeneration true $(DefineConstants);BUILDING_SOURCE_GENERATOR_TESTS;ROSLYN4_0_OR_GREATER;ROSLYN4_4_OR_GREATER $(DefineConstants);LAUNCH_DEBUGGER - $(DefineConstants);SKIP_BASELINES + $(DefineConstants);UPDATE_BASELINES @@ -20,6 +24,7 @@ + @@ -28,6 +33,7 @@ + @@ -43,16 +49,16 @@ - + PreserveNewest - - - + + + + + + diff --git a/src/libraries/Microsoft.Extensions.Configuration.CommandLine/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.CommandLine/src/PACKAGE.md index 39daac6e4ec6c0..3714573d5d449b 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.CommandLine/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.CommandLine/src/PACKAGE.md @@ -1,14 +1,16 @@ ## About + + Command line configuration provider implementation for [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). This package enables you to read configuration parameters from the command line arguments of your application. You can use [CommandLineConfigurationExtensions.AddCommandLine](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.commandlineconfigurationextensions.addcommandline) extension method on `IConfigurationBuilder` to add the command line configuration provider to the configuration builder. -For more information, see the documentation: [Command-line configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#command-line-configuration-provider). +## How to Use -## Example + The following example shows how to read application configuration from the command line. You can use a command like `dotnet run --InputPath "c:\fizz" --OutputPath "c:\buzz"` to run it. -```cs +```C# using System; using Microsoft.Extensions.Configuration; @@ -20,10 +22,23 @@ class Program IConfiguration config = new ConfigurationBuilder() .AddCommandLine(args) .Build(); - + // Read configuration values Console.WriteLine($"InputPath: {config["InputPath"]}"); Console.WriteLine($"OutputPath: {config["OutputPath"]}"); } } ``` + +## Additional Documentation + + + +* [Command-line configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#command-line-configuration-provider) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.commandline) + +## Feedback & Contributing + + + +Microsoft.Extensions.Configuration.CommandLine is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Configuration.EnvironmentVariables/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.EnvironmentVariables/src/PACKAGE.md index 84d2d9412cce76..eb9a67bfbfda25 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.EnvironmentVariables/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.EnvironmentVariables/src/PACKAGE.md @@ -1,10 +1,13 @@ ## About + + Environment variables configuration provider implementation for [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). This package enables you to read configuration parameters from environment variables. You can use [EnvironmentVariablesExtensions.AddEnvironmentVariables](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.environmentvariablesextensions.addenvironmentvariables) extension method on `IConfigurationBuilder` to add the environment variables configuration provider to the configuration builder. -For more information, see the documentation: [Environment variable configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#environment-variable-configuration-provider). +## How to Use + + -## Example The following example shows how to read application configuration from environment variables. ```cs @@ -26,3 +29,16 @@ class Program } } ``` + +## Additional Documentation + + + +* [Environment variable configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#environment-variable-configuration-provider) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.environmentvariables) + +## Feedback & Contributing + + + +Microsoft.Extensions.Configuration.EnvironmentVariables is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Configuration.FileExtensions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.FileExtensions/src/PACKAGE.md index e43c909d83225c..4a4f404102c89a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.FileExtensions/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.FileExtensions/src/PACKAGE.md @@ -1,9 +1,19 @@ ## About + + Provides a base class for file-based configuration providers used with [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/) and extension methods for configuring them. -For more information, see the documentation: +## Additional Documentation + + + +* [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration) +* [Microsoft.Extensions.Configuration.FileConfigurationProvider](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.fileconfigurationprovider) +* [Microsoft.Extensions.Configuration.FileConfigurationExtensions](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.fileconfigurationextensions) + +## Feedback & Contributing + + -- [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration) -- [Microsoft.Extensions.Configuration.FileConfigurationProvider](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.fileconfigurationprovider) -- [Microsoft.Extensions.Configuration.FileConfigurationExtensions](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.fileconfigurationextensions) +Microsoft.Extensions.Configuration.FileExtensions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Configuration.Ini/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.Ini/src/PACKAGE.md index 7322364392c600..a987b88924d64d 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Ini/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.Ini/src/PACKAGE.md @@ -1,13 +1,14 @@ ## About + + INI configuration provider implementation for [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). This package enables you to read configuration parameters from [INI files](https://en.wikipedia.org/wiki/INI_file). You can use [IniConfigurationExtensions.AddIniFile](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.iniconfigurationextensions.addinifile) extension method on `IConfigurationBuilder` to add INI configuration provider to the configuration builder. -For more information, see the documentation: [INI configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#ini-configuration-provider). +## How to Use -## Example -The following example shows how to read the application configuration from INI file. + -```cs +```C# using System; using Microsoft.Extensions.Configuration; @@ -47,3 +48,16 @@ You can include a configuration file using a code like this in your `.csproj` fi ``` + +## Additional Documentation + + + +* [INI configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#ini-configuration-provider) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.ini) + +## Feedback & Contributing + + + +Microsoft.Extensions.Configuration.Ini is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Configuration.Json/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.Json/src/PACKAGE.md index ae1c6355100f3a..825b3dbe412d4d 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Json/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.Json/src/PACKAGE.md @@ -1,10 +1,13 @@ ## About + + JSON configuration provider implementation for [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). This package enables you to read your application's settings from a JSON file. You can use [JsonConfigurationExtensions.AddJsonFile](https://docs.microsoft.com/dotnet/api/microsoft.extensions.configuration.jsonconfigurationextensions.addjsonfile) extension method on `IConfigurationBuilder` to add the JSON configuration provider to the configuration builder. -For more information, see the documentation: [JSON configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#json-configuration-provider). +## How to Use + + -## Example The following example shows how to read application settings from the JSON configuration file. ```cs @@ -60,3 +63,16 @@ You can include a configuration file using a code like this in your `.csproj` fi ``` + +## Additional Documentation + + + +* [JSON configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#json-configuration-provider) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.json) + +## Feedback & Contributing + + + +Microsoft.Extensions.Configuration.Json is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Configuration.UserSecrets/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.UserSecrets/src/PACKAGE.md index 4dba7ccd6bbb5d..1fe9ee6c98939e 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.UserSecrets/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.UserSecrets/src/PACKAGE.md @@ -1,8 +1,19 @@ ## About + + User secrets configuration provider implementation for [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). User secrets mechanism enables you to override application configuration settings with values stored in the local secrets file. You can use [UserSecretsConfigurationExtensions.AddUserSecrets](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.usersecretsconfigurationextensions.addusersecrets) extension method on `IConfigurationBuilder` to add user secrets provider to the configuration builder. -For more information, see the documentation: +## Additional Documentation + + + +* [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration) +* [Safe storage of app secrets in development in ASP.NET Core](https://learn.microsoft.com/aspnet/core/security/app-secrets) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.usersecrets) + +## Feedback & Contributing + + -- [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration) -- [Safe storage of app secrets in development in ASP.NET Core](https://learn.microsoft.com/aspnet/core/security/app-secrets) +Microsoft.Extensions.Configuration.UserSecrets is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Configuration.Xml/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration.Xml/src/PACKAGE.md index d47be06c8dea3b..209645bc5a4732 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Xml/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration.Xml/src/PACKAGE.md @@ -1,10 +1,13 @@ ## About + + XML configuration provider implementation for [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). This package enables you to read configuration parameters from XML files. You can use [XmlConfigurationExtensions.AddXmlFile](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.xmlconfigurationextensions.addxmlfile) extension method on `IConfigurationBuilder` to add XML configuration provider to the configuration builder. -For more information, see the documentation: [XML configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#xml-configuration-provider). +## How to Use + + -## Example The following example shows how to read the application configuration from XML file. ```cs @@ -59,3 +62,16 @@ You can include a configuration file using a code like this in your `.csproj` fi ``` + +## Additional Documentation + + + +* [XML configuration provider](https://learn.microsoft.com/dotnet/core/extensions/configuration-providers#xml-configuration-provider) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration.xml) + +## Feedback & Contributing + + + +Microsoft.Extensions.Configuration.Xml is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Configuration/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Configuration/src/PACKAGE.md index 92e93652f7a609..1d193afbda8667 100644 --- a/src/libraries/Microsoft.Extensions.Configuration/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.Configuration/src/PACKAGE.md @@ -1,8 +1,82 @@ ## About + + `Microsoft.Extensions.Configuration` is combined with a core configuration abstraction under `Microsoft.Extensions.Configuration.Abstractions` that allows for building different kinds of configuration providers to retrieve key/value pair configuration values from in the form of `IConfiguration`. There are a number of built-in configuration provider implementations to read from environment variables, in-memory collections, JSON, INI or XML files. Aside from the built-in variations, there are more shipped libraries shipped by community for integration with various configuration service and other data sources. -For more information, see the documentation: +## Key Features + + + +* In-memory configuration provider +* Chained configuration provider for chaining multiple confiugration providers together. +* Base types that implement configuration abstraction interfaces that can be used when implementing other configuration providers. + +## How to Use + + + +```C# +using Microsoft.Extensions.Configuration; + +var configurationBuilder = new ConfigurationBuilder(); + +configurationBuilder.AddInMemoryCollection( + new Dictionary + { + ["Setting1"] = "value", + ["MyOptions:Enabled"] = bool.TrueString, + }); + +configurationBuilder.AddInMemoryCollection( + new Dictionary + { + ["Setting2"] = "value2", + ["MyOptions:Enabled"] = bool.FalseString, + }); + +var config = configurationBuilder.Build(); + +// note case-insensitive +Console.WriteLine(config["setting1"]); +Console.WriteLine(config["setting2"]); + +// note last in wins +Console.WriteLine(config["MyOptions:Enabled"]); +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Configuration.ConfigurationBuilder` +* `Microsoft.Extensions.Configuration.ConfigurationManager` +* `Microsoft.Extensions.Configuration.ConfigurationRoot` +* `Microsoft.Extensions.Configuration.ConfigurationSection` + +## Additional Documentation + + - [Configuration in .NET](https://learn.microsoft.com/dotnet/core/extensions/configuration) - [Microsoft.Extensions.Configuration namespace](https://learn.microsoft.com/dotnet/api/microsoft.extensions.configuration) + +## Related Packages + + +* [Microsoft.Extensions.Configuration.Binder](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Binder) +* [Microsoft.Extensions.Configuration.CommandLine](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.CommandLine) +* [Microsoft.Extensions.Configuration.EnvironmentVariables](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.EnvironmentVariables) +* [Microsoft.Extensions.Configuration.FileExtensions](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.FileExtensions) +* [Microsoft.Extensions.Configuration.Ini](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Ini) +* [Microsoft.Extensions.Configuration.Json](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Json) +* [Microsoft.Extensions.Configuration.UserSecrets](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.UserSecrets) +* [Microsoft.Extensions.Configuration.Xml](https://www.nuget.org/packages/Microsoft.Extensions.Configuration.Xml) + +## Feedback & Contributing + + + +Microsoft.Extensions.Configuration is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs index eb7931489b557e..3ba30724d97f2b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Concurrent; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq.Expressions; @@ -10,6 +11,10 @@ using System.Runtime.ExceptionServices; using Microsoft.Extensions.Internal; +#if NETCOREAPP +[assembly: System.Reflection.Metadata.MetadataUpdateHandler(typeof(Microsoft.Extensions.DependencyInjection.ActivatorUtilities.ActivatorUtilitiesUpdateHandler))] +#endif + namespace Microsoft.Extensions.DependencyInjection { /// @@ -17,13 +22,21 @@ namespace Microsoft.Extensions.DependencyInjection /// public static class ActivatorUtilities { +#if NETCOREAPP + // Support caching of constructor metadata for the common case of types in non-collectible assemblies. + private static readonly ConcurrentDictionary s_constructorInfos = new(); + + // Support caching of constructor metadata for types in collectible assemblies. + private static readonly Lazy> s_collectibleConstructorInfos = new(); +#endif + #if NET8_0_OR_GREATER // Maximum number of fixed arguments for ConstructorInvoker.Invoke(arg1, etc). private const int FixedArgumentThreshold = 4; #endif private static readonly MethodInfo GetServiceInfo = - GetMethodInfo>((sp, t, r, c) => GetService(sp, t, r, c)); + GetMethodInfo>((sp, t, r, c, k) => GetService(sp, t, r, c, k)); /// /// Instantiate a type with constructor arguments provided directly and/or from an . @@ -47,6 +60,17 @@ public static object CreateInstance( throw new InvalidOperationException(SR.CannotCreateAbstractClasses); } + ConstructorInfoEx[]? constructors; +#if NETCOREAPP + if (!s_constructorInfos.TryGetValue(instanceType, out constructors)) + { + constructors = GetOrAddConstructors(instanceType); + } +#else + constructors = CreateConstructorInfoExs(instanceType); +#endif + + ConstructorInfoEx? constructor; IServiceProviderIsService? serviceProviderIsService = provider.GetService(); // if container supports using IServiceProviderIsService, we try to find the longest ctor that // (a) matches all parameters given to CreateInstance @@ -61,10 +85,11 @@ public static object CreateInstance( ConstructorMatcher bestMatcher = default; bool multipleBestLengthFound = false; - foreach (ConstructorInfo? constructor in instanceType.GetConstructors()) + for (int i = 0; i < constructors.Length; i++) { - var matcher = new ConstructorMatcher(constructor); - bool isPreferred = constructor.IsDefined(typeof(ActivatorUtilitiesConstructorAttribute), false); + constructor = constructors[i]; + ConstructorMatcher matcher = new(constructor); + bool isPreferred = constructor.IsPreferred; int length = matcher.Match(parameters, serviceProviderIsService); if (isPreferred) @@ -105,18 +130,79 @@ public static object CreateInstance( } } - Type?[] argumentTypes = new Type[parameters.Length]; - for (int i = 0; i < argumentTypes.Length; i++) + Type?[] argumentTypes; + if (parameters.Length == 0) { - argumentTypes[i] = parameters[i]?.GetType(); + argumentTypes = Type.EmptyTypes; + } + else + { + argumentTypes = new Type[parameters.Length]; + for (int i = 0; i < argumentTypes.Length; i++) + { + argumentTypes[i] = parameters[i]?.GetType(); + } } FindApplicableConstructor(instanceType, argumentTypes, out ConstructorInfo constructorInfo, out int?[] parameterMap); - var constructorMatcher = new ConstructorMatcher(constructorInfo); + + // Find the ConstructorInfoEx from the given constructorInfo. + constructor = null; + foreach (ConstructorInfoEx ctor in constructors) + { + if (ReferenceEquals(ctor.Info, constructorInfo)) + { + constructor = ctor; + break; + } + } + + Debug.Assert(constructor != null); + + var constructorMatcher = new ConstructorMatcher(constructor); constructorMatcher.MapParameters(parameterMap, parameters); return constructorMatcher.CreateInstance(provider); } +#if NETCOREAPP + private static ConstructorInfoEx[] GetOrAddConstructors( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type) + { + // Not found. Do the slower work of checking for the value in the correct cache. + // Null and non-collectible load contexts use the default cache. + if (!type.Assembly.IsCollectible) + { + return s_constructorInfos.GetOrAdd(type, CreateConstructorInfoExs(type)); + } + + // Collectible load contexts should use the ConditionalWeakTable so they can be unloaded. + if (s_collectibleConstructorInfos.Value.TryGetValue(type, out ConstructorInfoEx[]? value)) + { + return value; + } + + value = CreateConstructorInfoExs(type); + + // ConditionalWeakTable doesn't support GetOrAdd() so use AddOrUpdate(). This means threads + // can have different instances for the same type, but that is OK since they are equivalent. + s_collectibleConstructorInfos.Value.AddOrUpdate(type, value); + return value; + } +#endif // NETCOREAPP + + private static ConstructorInfoEx[] CreateConstructorInfoExs( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type) + { + ConstructorInfo[] constructors = type.GetConstructors(); + ConstructorInfoEx[]? value = new ConstructorInfoEx[constructors.Length]; + for (int i = 0; i < constructors.Length; i++) + { + value[i] = new ConstructorInfoEx(constructors[i]); + } + + return value; + } + /// /// Create a delegate that will instantiate a type with constructor arguments provided directly /// and/or from an . @@ -238,9 +324,9 @@ private static MethodInfo GetMethodInfo(Expression expr) return mc.Method; } - private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue) + private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue, object? key) { - object? service = sp.GetService(type); + object? service = key == null ? sp.GetService(type) : GetKeyedService(sp, type, key); if (service is null && !hasDefaultValue) { ThrowHelperUnableToResolveService(type, requiredBy); @@ -275,10 +361,12 @@ private static BlockExpression BuildFactoryExpression( } else { + var keyAttribute = (FromKeyedServicesAttribute?) Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false); var parameterTypeExpression = new Expression[] { serviceProvider, Expression.Constant(parameterType, typeof(Type)), Expression.Constant(constructor.DeclaringType, typeof(Type)), - Expression.Constant(hasDefaultValue) }; + Expression.Constant(hasDefaultValue), + Expression.Constant(keyAttribute?.Key) }; constructorArguments[i] = Expression.Call(GetServiceInfo, parameterTypeExpression); } @@ -349,10 +437,10 @@ private static ObjectFactory CreateFactoryReflection( if (matchedArgCount == 0) { // All injected; use a fast path. - Type[] types = GetParameterTypes(); + FactoryParameterContext[] parameters = GetFactoryParameterContext(); return useFixedValues ? - (serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, types, declaringType, serviceProvider) : - (serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, types, declaringType, serviceProvider); + (serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, parameters, declaringType, serviceProvider) : + (serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, parameters, declaringType, serviceProvider); } if (matchedArgCount == constructorParameters.Length) @@ -370,16 +458,6 @@ ObjectFactory InvokeCanonical() (serviceProvider, arguments) => ReflectionFactoryCanonicalFixed(invoker, parameters, declaringType, serviceProvider, arguments) : (serviceProvider, arguments) => ReflectionFactoryCanonicalSpan(invoker, parameters, declaringType, serviceProvider, arguments); } - - Type[] GetParameterTypes() - { - Type[] types = new Type[constructorParameters.Length]; - for (int i = 0; i < constructorParameters.Length; i++) - { - types[i] = constructorParameters[i].ParameterType; - } - return types; - } #else ParameterInfo[] constructorParameters = constructor.GetParameters(); if (constructorParameters.Length == 0) @@ -398,8 +476,15 @@ FactoryParameterContext[] GetFactoryParameterContext() for (int i = 0; i < constructorParameters.Length; i++) { ParameterInfo constructorParameter = constructorParameters[i]; + FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?) + Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false); bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue); - parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1); + parameters[i] = new FactoryParameterContext( + constructorParameter.ParameterType, + hasDefaultValue, + defaultValue, + parameterMap[i] ?? -1, + attr?.Key); } return parameters; @@ -409,18 +494,20 @@ FactoryParameterContext[] GetFactoryParameterContext() private readonly struct FactoryParameterContext { - public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex) + public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex, object? serviceKey) { ParameterType = parameterType; HasDefaultValue = hasDefaultValue; DefaultValue = defaultValue; ArgumentIndex = argumentIndex; + ServiceKey = serviceKey; } public Type ParameterType { get; } public bool HasDefaultValue { get; } public object? DefaultValue { get; } public int ArgumentIndex { get; } + public object? ServiceKey { get; } } private static void FindApplicableConstructor( @@ -551,58 +638,82 @@ private static bool TryCreateParameterMap(ParameterInfo[] constructorParameters, return true; } - private static object? GetService(IServiceProvider serviceProvider, ParameterInfo parameterInfo) + private sealed class ConstructorInfoEx { - // Handle keyed service - if (TryGetServiceKey(parameterInfo, out object? key)) + public readonly ConstructorInfo Info; + public readonly ParameterInfo[] Parameters; + public readonly bool IsPreferred; + private readonly object?[]? _parameterKeys; + + public ConstructorInfoEx(ConstructorInfo constructor) { - if (serviceProvider is IKeyedServiceProvider keyedServiceProvider) + Info = constructor; + Parameters = constructor.GetParameters(); + IsPreferred = constructor.IsDefined(typeof(ActivatorUtilitiesConstructorAttribute), inherit: false); + + for (int i = 0; i < Parameters.Length; i++) { - return keyedServiceProvider.GetKeyedService(parameterInfo.ParameterType, key); + FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?) + Attribute.GetCustomAttribute(Parameters[i], typeof(FromKeyedServicesAttribute), inherit: false); + + if (attr is not null) + { + _parameterKeys ??= new object?[Parameters.Length]; + _parameterKeys[i] = attr.Key; + } } - throw new InvalidOperationException(SR.KeyedServicesNotSupported); } - // Try non keyed service - return serviceProvider.GetService(parameterInfo.ParameterType); - } - private static bool IsService(IServiceProviderIsService serviceProviderIsService, ParameterInfo parameterInfo) - { - // Handle keyed service - if (TryGetServiceKey(parameterInfo, out object? key)) + public bool IsService(IServiceProviderIsService serviceProviderIsService, int parameterIndex) { - if (serviceProviderIsService is IServiceProviderIsKeyedService serviceProviderIsKeyedService) + ParameterInfo parameterInfo = Parameters[parameterIndex]; + + // Handle keyed service + object? key = _parameterKeys?[parameterIndex]; + if (key is not null) { - return serviceProviderIsKeyedService.IsKeyedService(parameterInfo.ParameterType, key); + if (serviceProviderIsService is IServiceProviderIsKeyedService serviceProviderIsKeyedService) + { + return serviceProviderIsKeyedService.IsKeyedService(parameterInfo.ParameterType, key); + } + + throw new InvalidOperationException(SR.KeyedServicesNotSupported); } - throw new InvalidOperationException(SR.KeyedServicesNotSupported); + + // Use non-keyed service + return serviceProviderIsService.IsService(parameterInfo.ParameterType); } - // Try non keyed service - return serviceProviderIsService.IsService(parameterInfo.ParameterType); - } - private static bool TryGetServiceKey(ParameterInfo parameterInfo, out object? key) - { - foreach (var attribute in parameterInfo.GetCustomAttributes(false)) + public object? GetService(IServiceProvider serviceProvider, int parameterIndex) { - key = attribute.Key; - return true; + ParameterInfo parameterInfo = Parameters[parameterIndex]; + + // Handle keyed service + object? key = _parameterKeys?[parameterIndex]; + if (key is not null) + { + if (serviceProvider is IKeyedServiceProvider keyedServiceProvider) + { + return keyedServiceProvider.GetKeyedService(parameterInfo.ParameterType, key); + } + + throw new InvalidOperationException(SR.KeyedServicesNotSupported); + } + + // Use non-keyed service + return serviceProvider.GetService(parameterInfo.ParameterType); } - key = null; - return false; } private readonly struct ConstructorMatcher { - private readonly ConstructorInfo _constructor; - private readonly ParameterInfo[] _parameters; + private readonly ConstructorInfoEx _constructor; private readonly object?[] _parameterValues; - public ConstructorMatcher(ConstructorInfo constructor) + public ConstructorMatcher(ConstructorInfoEx constructor) { _constructor = constructor; - _parameters = _constructor.GetParameters(); - _parameterValues = new object?[_parameters.Length]; + _parameterValues = new object[constructor.Parameters.Length]; } public int Match(object[] givenParameters, IServiceProviderIsService serviceProviderIsService) @@ -612,10 +723,10 @@ public int Match(object[] givenParameters, IServiceProviderIsService serviceProv Type? givenType = givenParameters[givenIndex]?.GetType(); bool givenMatched = false; - for (int applyIndex = 0; applyIndex < _parameters.Length; applyIndex++) + for (int applyIndex = 0; applyIndex < _constructor.Parameters.Length; applyIndex++) { if (_parameterValues[applyIndex] == null && - _parameters[applyIndex].ParameterType.IsAssignableFrom(givenType)) + _constructor.Parameters[applyIndex].ParameterType.IsAssignableFrom(givenType)) { givenMatched = true; _parameterValues[applyIndex] = givenParameters[givenIndex]; @@ -630,12 +741,12 @@ public int Match(object[] givenParameters, IServiceProviderIsService serviceProv } // confirms the rest of ctor arguments match either as a parameter with a default value or as a service registered - for (int i = 0; i < _parameters.Length; i++) + for (int i = 0; i < _constructor.Parameters.Length; i++) { if (_parameterValues[i] == null && - !IsService(serviceProviderIsService, _parameters[i])) + !_constructor.IsService(serviceProviderIsService, i)) { - if (ParameterDefaultValue.TryGetDefaultValue(_parameters[i], out object? defaultValue)) + if (ParameterDefaultValue.TryGetDefaultValue(_constructor.Parameters[i], out object? defaultValue)) { _parameterValues[i] = defaultValue; } @@ -646,21 +757,21 @@ public int Match(object[] givenParameters, IServiceProviderIsService serviceProv } } - return _parameters.Length; + return _constructor.Parameters.Length; } public object CreateInstance(IServiceProvider provider) { - for (int index = 0; index < _parameters.Length; index++) + for (int index = 0; index < _constructor.Parameters.Length; index++) { if (_parameterValues[index] == null) { - object? value = GetService(provider, _parameters[index]); + object? value = _constructor.GetService(provider, index); if (value == null) { - if (!ParameterDefaultValue.TryGetDefaultValue(_parameters[index], out object? defaultValue)) + if (!ParameterDefaultValue.TryGetDefaultValue(_constructor.Parameters[index], out object? defaultValue)) { - throw new InvalidOperationException(SR.Format(SR.UnableToResolveService, _parameters[index].ParameterType, _constructor.DeclaringType)); + throw new InvalidOperationException(SR.Format(SR.UnableToResolveService, _constructor.Parameters[index].ParameterType, _constructor.Info.DeclaringType)); } else { @@ -677,7 +788,7 @@ public object CreateInstance(IServiceProvider provider) #if NETFRAMEWORK || NETSTANDARD2_0 try { - return _constructor.Invoke(_parameterValues); + return _constructor.Info.Invoke(_parameterValues); } catch (TargetInvocationException ex) when (ex.InnerException != null) { @@ -686,13 +797,13 @@ public object CreateInstance(IServiceProvider provider) throw; } #else - return _constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: _parameterValues, culture: null); + return _constructor.Info.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: _parameterValues, culture: null); #endif } public void MapParameters(int?[] parameterMap, object[] givenParameters) { - for (int i = 0; i < _parameters.Length; i++) + for (int i = 0; i < _constructor.Parameters.Length; i++) { if (parameterMap[i] != null) { @@ -715,39 +826,39 @@ private static void ThrowMarkedCtorDoesNotTakeAllProvidedArguments() #if NET8_0_OR_GREATER // Use the faster ConstructorInvoker which also has alloc-free APIs when <= 4 parameters. private static object ReflectionFactoryServiceOnlyFixed( ConstructorInvoker invoker, - Type[] parameterTypes, + FactoryParameterContext[] parameters, Type declaringType, IServiceProvider serviceProvider) { - Debug.Assert(parameterTypes.Length >= 1 && parameterTypes.Length <= FixedArgumentThreshold); + Debug.Assert(parameters.Length >= 1 && parameters.Length <= FixedArgumentThreshold); Debug.Assert(FixedArgumentThreshold == 4); if (serviceProvider is null) ThrowHelperArgumentNullExceptionServiceProvider(); - switch (parameterTypes.Length) + switch (parameters.Length) { case 1: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey)); case 2: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false), - GetService(serviceProvider, parameterTypes[1], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey), + GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey)); case 3: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false), - GetService(serviceProvider, parameterTypes[1], declaringType, false), - GetService(serviceProvider, parameterTypes[2], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey), + GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey), + GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey)); case 4: return invoker.Invoke( - GetService(serviceProvider, parameterTypes[0], declaringType, false), - GetService(serviceProvider, parameterTypes[1], declaringType, false), - GetService(serviceProvider, parameterTypes[2], declaringType, false), - GetService(serviceProvider, parameterTypes[3], declaringType, false)); + GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey), + GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey), + GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey), + GetService(serviceProvider, parameters[3].ParameterType, declaringType, false, parameters[3].ServiceKey)); } return null!; @@ -755,17 +866,17 @@ private static object ReflectionFactoryServiceOnlyFixed( private static object ReflectionFactoryServiceOnlySpan( ConstructorInvoker invoker, - Type[] parameterTypes, + FactoryParameterContext[] parameters, Type declaringType, IServiceProvider serviceProvider) { if (serviceProvider is null) ThrowHelperArgumentNullExceptionServiceProvider(); - object?[] arguments = new object?[parameterTypes.Length]; - for (int i = 0; i < parameterTypes.Length; i++) + object?[] arguments = new object?[parameters.Length]; + for (int i = 0; i < parameters.Length; i++) { - arguments[i] = GetService(serviceProvider, parameterTypes[i], declaringType, false); + arguments[i] = GetService(serviceProvider, parameters[i].ParameterType, declaringType, false, parameters[i].ServiceKey); } return invoker.Invoke(arguments.AsSpan()); @@ -797,7 +908,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue); + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue); case 2: { ref FactoryParameterContext parameter2 = ref parameters[1]; @@ -810,7 +922,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue, ((parameter2.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter2.ArgumentIndex] @@ -818,7 +931,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter2.ParameterType, declaringType, - parameter2.HasDefaultValue)) ?? parameter2.DefaultValue); + parameter2.HasDefaultValue, + parameter2.ServiceKey)) ?? parameter2.DefaultValue); } case 3: { @@ -833,7 +947,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue, ((parameter2.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter2.ArgumentIndex] @@ -841,7 +956,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter2.ParameterType, declaringType, - parameter2.HasDefaultValue)) ?? parameter2.DefaultValue, + parameter2.HasDefaultValue, + parameter2.ServiceKey)) ?? parameter2.DefaultValue, ((parameter3.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter3.ArgumentIndex] @@ -849,7 +965,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter3.ParameterType, declaringType, - parameter3.HasDefaultValue)) ?? parameter3.DefaultValue); + parameter3.HasDefaultValue, + parameter3.ServiceKey)) ?? parameter3.DefaultValue); } case 4: { @@ -865,7 +982,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter1.ParameterType, declaringType, - parameter1.HasDefaultValue)) ?? parameter1.DefaultValue, + parameter1.HasDefaultValue, + parameter1.ServiceKey)) ?? parameter1.DefaultValue, ((parameter2.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter2.ArgumentIndex] @@ -873,7 +991,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter2.ParameterType, declaringType, - parameter2.HasDefaultValue)) ?? parameter2.DefaultValue, + parameter2.HasDefaultValue, + parameter2.ServiceKey)) ?? parameter2.DefaultValue, ((parameter3.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter3.ArgumentIndex] @@ -881,7 +1000,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter3.ParameterType, declaringType, - parameter3.HasDefaultValue)) ?? parameter3.DefaultValue, + parameter3.HasDefaultValue, + parameter3.ServiceKey)) ?? parameter3.DefaultValue, ((parameter4.ArgumentIndex != -1) // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory. ? arguments![parameter4.ArgumentIndex] @@ -889,7 +1009,8 @@ private static object ReflectionFactoryCanonicalFixed( serviceProvider, parameter4.ParameterType, declaringType, - parameter4.HasDefaultValue)) ?? parameter4.DefaultValue); + parameter4.HasDefaultValue, + parameter4.ServiceKey)) ?? parameter4.DefaultValue); } } @@ -918,7 +1039,8 @@ private static object ReflectionFactoryCanonicalSpan( serviceProvider, parameter.ParameterType, declaringType, - parameter.HasDefaultValue)) ?? parameter.DefaultValue; + parameter.HasDefaultValue, + parameter.ServiceKey)) ?? parameter.DefaultValue; } return invoker.Invoke(constructorArguments.AsSpan()); @@ -968,11 +1090,39 @@ private static object ReflectionFactoryCanonical( serviceProvider, parameter.ParameterType, declaringType, - parameter.HasDefaultValue)) ?? parameter.DefaultValue; + parameter.HasDefaultValue, + parameter.ServiceKey)) ?? parameter.DefaultValue; } return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null); } #endif // NET8_0_OR_GREATER + +#if NETCOREAPP + internal static class ActivatorUtilitiesUpdateHandler + { + public static void ClearCache(Type[]? _) + { + // Ignore the Type[] argument; just clear the caches. + s_constructorInfos.Clear(); + if (s_collectibleConstructorInfos.IsValueCreated) + { + s_collectibleConstructorInfos.Value.Clear(); + } + } + } +#endif + + private static object? GetKeyedService(IServiceProvider provider, Type type, object? serviceKey) + { + ThrowHelper.ThrowIfNull(provider); + + if (provider is IKeyedServiceProvider keyedServiceProvider) + { + return keyedServiceProvider.GetKeyedService(type, serviceKey); + } + + throw new InvalidOperationException(SR.KeyedServicesNotSupported); + } } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/PACKAGE.md new file mode 100644 index 00000000000000..6c8a654b7f4ff1 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/PACKAGE.md @@ -0,0 +1,34 @@ +## About +Supports the lower-level abstractions for the dependency injection (DI) software design pattern which is a technique for achieving Inversion of Control (IoC) between classes and their dependencies. + +## Key Features +- Interfaces for DI implementations which are provided in other packages including `Microsoft.Extensions.DependencyInjection`. +- An implementation of a service collection, which is used to add services to and later retrieve them either directly or through constructor injection. +- Interfaces, attributes and extensions methods to support various DI concepts including specifying a service's lifetime and supporting keyed services. + +## How to Use +This package is typically used with an implementation of the DI abstractions, such as `Microsoft.Extensions.DependencyInjection`. + +## Main Types +The main types provided by this library are: +* `Microsoft.Extensions.DependencyInjection.ActivatorUtilities` +* `Microsoft.Extensions.DependencyInjection.IServiceCollection` +* `Microsoft.Extensions.DependencyInjection.ServiceCollection` +* `Microsoft.Extensions.DependencyInjection.ServiceCollectionDescriptorExtensions` +* `Microsoft.Extensions.DependencyInjection.ServiceDescriptor` +* `Microsoft.Extensions.DependencyInjection.IServiceProviderFactory` + +## Additional Documentation +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/dependency-injection) +* API documentation + - [ActivatorUtilities](https://learn.microsoft.com/dotnet/api/microsoft.extensions.dependencyinjection.defaultserviceproviderfactory) + - [ServiceCollection](https://learn.microsoft.com/dotnet/api/microsoft.extensions.dependencyinjection.servicecollection) + - [ServiceDescriptor](https://learn.microsoft.com/dotnet/api/microsoft.extensions.dependencyinjection.servicedescriptor) + +## Related Packages +- `Microsoft.Extensions.DependencyInjection` +- `Microsoft.Extensions.Hosting` +- `Microsoft.Extensions.Options` + +## Feedback & Contributing +Microsoft.Extensions.DependencyInjection.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs index 32112c3e5f7691..a0dc73d58f820b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs @@ -476,5 +476,41 @@ public ServiceProviderAccessor(IServiceProvider serviceProvider) public IServiceProvider ServiceProvider { get; } } + + [Fact] + public void SimpleServiceKeyedResolution() + { + // Arrange + var services = new ServiceCollection(); + services.AddKeyedTransient("simple"); + services.AddKeyedTransient("another"); + services.AddTransient(); + var provider = CreateServiceProvider(services); + var sut = provider.GetService(); + + // Act + var result = sut!.GetService("simple"); + + // Assert + Assert.True(result.GetType() == typeof(SimpleService)); + } + + public class SimpleParentWithDynamicKeyedService + { + private readonly IServiceProvider _serviceProvider; + + public SimpleParentWithDynamicKeyedService(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider; + } + + public ISimpleService GetService(string name) => _serviceProvider.GetKeyedService(name)!; + } + + public interface ISimpleService { } + + public class SimpleService : ISimpleService { } + + public class AnotherSimpleService : ISimpleService { } } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ILLink/ILLink.Substitutions.xml b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ILLink/ILLink.Substitutions.xml index eb381de19d6153..6aa354ee23683c 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ILLink/ILLink.Substitutions.xml +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ILLink/ILLink.Substitutions.xml @@ -3,5 +3,8 @@ + + + diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.DependencyInjection/src/PACKAGE.md new file mode 100644 index 00000000000000..91474fd46a2f29 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/PACKAGE.md @@ -0,0 +1,50 @@ +## About +Supports the dependency injection (DI) software design pattern which is a technique for achieving Inversion of Control (IoC) between classes and their dependencies. + +## Key Features +Provides an implementation of the DI interfaces found in the `Microsoft.Extensions.DependencyInjection.Abstractions` package. + +## How to Use +```cs +ServiceCollection services = new (); +services.AddSingleton(); +using ServiceProvider provider = services.BuildServiceProvider(); + +// The code below, following the IoC pattern, is typically only aware of the IMessageWriter interface, not the implementation. +IMessageWriter messageWriter = provider.GetService()!; +messageWriter.Write("Hello"); + +public interface IMessageWriter +{ + void Write(string message); +} + +internal class MessageWriter : IMessageWriter +{ + public void Write(string message) + { + Console.WriteLine($"MessageWriter.Write(message: \"{message}\")"); + } +} +``` + +## Main Types +The main types provided by this library are: +* `Microsoft.Extensions.DependencyInjection.DefaultServiceProviderFactory` +* `Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions` +* `Microsoft.Extensions.DependencyInjection.ServiceProvider` + +## Additional Documentation +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/dependency-injection) +* API documentation + - [DefaultServiceProviderFactory](https://learn.microsoft.com/dotnet/api/microsoft.extensions.dependencyinjection.defaultserviceproviderfactory) + - [ServiceCollectionContainerBuilderExtensions](https://learn.microsoft.com/dotnet/api/microsoft.extensions.dependencyinjection.servicecollectioncontainerbuilderextensions) + - [ServiceProvider](https://learn.microsoft.com/dotnet/api/microsoft.extensions.dependencyinjection.serviceprovider) + +## Related Packages +- `Microsoft.Extensions.DependencyInjection.Abstractions` +- `Microsoft.Extensions.Hosting` +- `Microsoft.Extensions.Options` + +## Feedback & Contributing +Microsoft.Extensions.DependencyInjection is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/Expressions/ExpressionResolverBuilder.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/Expressions/ExpressionResolverBuilder.cs index c28b9f00bb4e7d..670c62ee2f69a8 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/Expressions/ExpressionResolverBuilder.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/Expressions/ExpressionResolverBuilder.cs @@ -123,6 +123,15 @@ static MethodInfo GetArrayEmptyMethodInfo(Type elementType) return ServiceLookupHelpers.GetArrayEmptyMethodInfo(elementType); } + [UnconditionalSuppressMessage("AotAnalysis", "IL3050:RequiresDynamicCode", + Justification = "VerifyAotCompatibility ensures elementType is not a ValueType")] + static NewArrayExpression NewArrayInit(Type elementType, IEnumerable expr) + { + Debug.Assert(!ServiceProvider.VerifyAotCompatibility || !elementType.IsValueType, "VerifyAotCompatibility=true will throw during building the IEnumerableCallSite if elementType is a ValueType."); + + return Expression.NewArrayInit(elementType, expr); + } + if (callSite.ServiceCallSites.Length == 0) { return Expression.Constant( @@ -130,7 +139,7 @@ static MethodInfo GetArrayEmptyMethodInfo(Type elementType) .Invoke(obj: null, parameters: Array.Empty())); } - return Expression.NewArrayInit( + return NewArrayInit( callSite.ItemType, callSite.ServiceCallSites.Select(cs => Convert( diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs index aa178741fc4516..ff5efbe98cf334 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs @@ -37,6 +37,9 @@ public sealed class ServiceProvider : IServiceProvider, IKeyedServiceProvider, I internal static bool VerifyOpenGenericServiceTrimmability { get; } = AppContext.TryGetSwitch("Microsoft.Extensions.DependencyInjection.VerifyOpenGenericServiceTrimmability", out bool verifyOpenGenerics) ? verifyOpenGenerics : false; + internal static bool DisableDynamicEngine { get; } = + AppContext.TryGetSwitch("Microsoft.Extensions.DependencyInjection.DisableDynamicEngine", out bool disableDynamicEngine) ? disableDynamicEngine : false; + internal static bool VerifyAotCompatibility => #if NETFRAMEWORK || NETSTANDARD2_0 false; @@ -246,7 +249,7 @@ private ServiceProviderEngine GetEngine() #if NETFRAMEWORK || NETSTANDARD2_0 engine = CreateDynamicEngine(); #else - if (RuntimeFeature.IsDynamicCodeCompiled) + if (RuntimeFeature.IsDynamicCodeCompiled && !DisableDynamicEngine) { engine = CreateDynamicEngine(); } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectableClasses.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectableClasses.cs new file mode 100644 index 00000000000000..cc9e925ae0ad8e --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectableClasses.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; + +namespace CollectibleAssembly +{ + public class ClassToCreate + { + public object ClassAsCtorArgument { get; set; } + + public ClassToCreate(ClassAsCtorArgument obj) { ClassAsCtorArgument = obj; } + + public static object Create(ServiceProvider provider) + { + // Both the type to create (ClassToCreate) and the ctor's arg type (ClassAsCtorArgument) are + // located in this assembly, so both types need to be GC'd for this assembly to be collected. + return ActivatorUtilities.CreateInstance(provider, new ClassAsCtorArgument()); + } + } + + public class ClassAsCtorArgument + { + } +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectibleAssembly.csproj b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectibleAssembly.csproj new file mode 100644 index 00000000000000..82159cece28227 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/CollectibleAssembly/CollectibleAssembly.csproj @@ -0,0 +1,11 @@ + + + $(NetCoreAppCurrent);$(NetFrameworkMinimum) + true + + + + + + + diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs index 7572e6977a4c49..f6e7c2f3a8eb7c 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs @@ -2,8 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.IO; +using System.Reflection; using Microsoft.DotNet.RemoteExecutor; using Xunit; +using System.Runtime.CompilerServices; + +#if NETCOREAPP +using System.Runtime.Loader; +#endif namespace Microsoft.Extensions.DependencyInjection.Tests { @@ -233,6 +240,100 @@ public void CreateFactory_CreatesFactoryMethod_5Types_5Injected() Assert.NotNull(item.Z); } + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] +#if NETCOREAPP + [InlineData(false)] +#endif + public void CreateFactory_CreatesFactoryMethod_KeyedParams(bool useDynamicCode) + { + var options = new RemoteInvokeOptions(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + var factory = ActivatorUtilities.CreateFactory(Type.EmptyTypes); + + var services = new ServiceCollection(); + services.AddSingleton(new A()); + services.AddKeyedSingleton("b", new B()); + services.AddKeyedSingleton("c", new C()); + using var provider = services.BuildServiceProvider(); + ClassWithAKeyedBKeyedC item = factory(provider, null); + + Assert.IsType>(factory); + Assert.NotNull(item.A); + Assert.NotNull(item.B); + Assert.NotNull(item.C); + }, options); + } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] +#if NETCOREAPP + [InlineData(false)] +#endif + public void CreateFactory_CreatesFactoryMethod_KeyedParams_5Types(bool useDynamicCode) + { + var options = new RemoteInvokeOptions(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + var factory = ActivatorUtilities.CreateFactory(Type.EmptyTypes); + + var services = new ServiceCollection(); + services.AddSingleton(new A()); + services.AddKeyedSingleton("b", new B()); + services.AddKeyedSingleton("c", new C()); + services.AddSingleton(new S()); + services.AddSingleton(new Z()); + using var provider = services.BuildServiceProvider(); + ClassWithAKeyedBKeyedCSZ item = factory(provider, null); + + Assert.IsType>(factory); + Assert.NotNull(item.A); + Assert.NotNull(item.B); + Assert.NotNull(item.C); + }, options); + } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] +#if NETCOREAPP + [InlineData(false)] +#endif + public void CreateFactory_CreatesFactoryMethod_KeyedParams_1Injected(bool useDynamicCode) + { + var options = new RemoteInvokeOptions(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + var factory = ActivatorUtilities.CreateFactory(new Type[] { typeof(A) }); + + var services = new ServiceCollection(); + services.AddKeyedSingleton("b", new B()); + services.AddKeyedSingleton("c", new C()); + using var provider = services.BuildServiceProvider(); + ClassWithAKeyedBKeyedC item = factory(provider, new object?[] { new A() }); + + Assert.IsType>(factory); + Assert.NotNull(item.A); + Assert.NotNull(item.B); + Assert.NotNull(item.C); + }, options); + } + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] [InlineData(true)] #if NETCOREAPP @@ -386,6 +487,125 @@ public void CreateFactory_RemoteExecutor_NoParameters_Success(bool useDynamicCod }, options); } +#if NETCOREAPP + [ActiveIssue("https://github.com/dotnet/runtime/issues/34072", TestRuntimes.Mono)] + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] + [InlineData(false)] + public void CreateInstance_CollectibleAssembly(bool useDynamicCode) + { + if (PlatformDetection.IsNonBundledAssemblyLoadingSupported) + { + RemoteInvokeOptions options = new(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + Assert.False(Collectible_IsAssemblyLoaded()); + Collectible_LoadAndCreate(useCollectibleAssembly : true, out WeakReference asmWeakRef, out WeakReference typeWeakRef); + + for (int i = 0; (typeWeakRef.IsAlive || asmWeakRef.IsAlive) && (i < 10); i++) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + + // These should be GC'd. + Assert.False(asmWeakRef.IsAlive, "asmWeakRef.IsAlive"); + Assert.False(typeWeakRef.IsAlive, "typeWeakRef.IsAlive"); + Assert.False(Collectible_IsAssemblyLoaded()); + }, options); + } + } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(true)] + [InlineData(false)] + public void CreateInstance_NormalAssembly(bool useDynamicCode) + { + RemoteInvokeOptions options = new(); + if (!useDynamicCode) + { + DisableDynamicCode(options); + } + + using var remoteHandle = RemoteExecutor.Invoke(static () => + { + Assert.False(Collectible_IsAssemblyLoaded()); + Collectible_LoadAndCreate(useCollectibleAssembly: false, out WeakReference asmWeakRef, out WeakReference typeWeakRef); + + for (int i = 0; (typeWeakRef.IsAlive || asmWeakRef.IsAlive) && (i < 10); i++) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + + // These will not be GC'd. + Assert.True(asmWeakRef.IsAlive, "alcWeakRef.IsAlive"); + Assert.True(typeWeakRef.IsAlive, "typeWeakRef.IsAlive"); + Assert.True(Collectible_IsAssemblyLoaded()); + }, options); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Collectible_LoadAndCreate(bool useCollectibleAssembly, out WeakReference asmWeakRef, out WeakReference typeWeakRef) + { + Assembly asm; + object obj; + + if (useCollectibleAssembly) + { + asm = MyLoadContext.LoadAsCollectable(); + obj = CreateWithActivator(asm); + Assert.True(obj.GetType().Assembly.IsCollectible); + } + else + { + asm = MyLoadContext.LoadNormal(); + obj = CreateWithActivator(asm); + Assert.False(obj.GetType().Assembly.IsCollectible); + } + + Assert.True(Collectible_IsAssemblyLoaded()); + asmWeakRef = new WeakReference(asm); + typeWeakRef = new WeakReference(obj.GetType()); + + static object CreateWithActivator(Assembly asm) + { + Type t = asm.GetType("CollectibleAssembly.ClassToCreate"); + MethodInfo mi = t.GetMethod("Create", BindingFlags.Static | BindingFlags.Public, new Type[] { typeof(ServiceProvider) }); + + object instance; + ServiceCollection services = new(); + using (ServiceProvider provider = services.BuildServiceProvider()) + { + instance = mi.Invoke(null, new object[] { provider }); + } + + return instance; + } + } + + static bool Collectible_IsAssemblyLoaded() + { + Assembly[] assemblies = AppDomain.CurrentDomain.GetAssemblies(); + for (int i = 0; i < assemblies.Length; i++) + { + Assembly asm = assemblies[i]; + string asmName = Path.GetFileName(asm.Location); + if (asmName == "CollectibleAssembly.dll") + { + return true; + } + } + + return false; + } +#endif + private static void DisableDynamicCode(RemoteInvokeOptions options) { // We probably only need to set 'IsDynamicCodeCompiled' since only that is checked, @@ -401,6 +621,13 @@ internal class C { } internal class S { } internal class Z { } + internal class ClassWithAKeyedBKeyedC : ClassWithABC + { + public ClassWithAKeyedBKeyedC(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c) + : base(a, b, c) + { } + } + internal class ClassWithABCS : ClassWithABC { public S S { get; } @@ -414,6 +641,13 @@ internal class ClassWithABCSZ : ClassWithABCS public ClassWithABCSZ(A a, B b, C c, S s, Z z) : base(a, b, c, s) { Z = z; } } + internal class ClassWithAKeyedBKeyedCSZ : ClassWithABCSZ + { + public ClassWithAKeyedBKeyedCSZ(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c, S s, Z z) + : base(a, b, c, s, z) + { } + } + internal class ClassWithABC_FirstConstructorWithAttribute : ClassWithABC { [ActivatorUtilitiesConstructor] @@ -581,5 +815,36 @@ public ClassWithStringDefaultValue(string text = "DEFAULT") Text = text; } } -} +#if NETCOREAPP + internal class MyLoadContext : AssemblyLoadContext + { + private MyLoadContext() : base(isCollectible: true) + { + } + + public Assembly LoadAssembly() + { + Assembly asm = LoadFromAssemblyPath(GetPath()); + Assert.Equal(GetLoadContext(asm), this); + return asm; + } + + public static Assembly LoadAsCollectable() + { + MyLoadContext alc = new MyLoadContext(); + return alc.LoadAssembly(); + } + + public static Assembly LoadNormal() + { + return Assembly.LoadFrom(GetPath()); + } + + private static string GetPath() + { + return Path.Combine(Directory.GetCurrentDirectory(), "CollectibleAssembly.dll"); + } + } +#endif +} diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj index 4ac3c02d7157a3..067508506a82fe 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Microsoft.Extensions.DependencyInjection.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent);$(NetFrameworkMinimum) @@ -24,6 +24,7 @@ + diff --git a/src/libraries/Microsoft.Extensions.DependencyModel/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.DependencyModel/src/PACKAGE.md index f4d4a485899ebb..19b13040a8aea9 100644 --- a/src/libraries/Microsoft.Extensions.DependencyModel/src/PACKAGE.md +++ b/src/libraries/Microsoft.Extensions.DependencyModel/src/PACKAGE.md @@ -1,16 +1,14 @@ ## About + + Provides abstractions for reading `.deps` files. When a .NET application is compiled, the SDK generates a JSON manifest file (`.deps.json`) that contains information about application dependencies. You can use `Microsoft.Extensions.DependencyModel` to read information from this manifest at run time. This is useful when you want to dynamically compile code (for example, using Roslyn Emit API) referencing the same dependencies as your main application. By default, the dependency manifest contains information about the application's target framework and runtime dependencies. Set the [PreserveCompilationContext](https://docs.microsoft.com/dotnet/core/project-sdk/msbuild-props#preservecompilationcontext) project property to `true` to additionally include information about reference assemblies used during compilation. -For more information, see the documentation: - -- [.deps.json file format](https://github.com/dotnet/sdk/blob/main/documentation/specs/runtime-configuration-file.md#appnamedepsjson) -- [Microsoft.Extensions.DependencyModel namespace](https://docs.microsoft.com/dotnet/api/microsoft.extensions.dependencymodel) -- [Microsoft.Extensions.DependencyModel.DependencyContext](https://docs.microsoft.com/dotnet/api/microsoft.extensions.dependencymodel.dependencycontext) +## How to Use -## Example + The following example shows how to display the list of assemblies used when compiling the current application. Include `true` in your project file to run this example. @@ -35,3 +33,17 @@ class Program } } ``` + +## Additional Documentation + + + +* [.deps.json file format](https://github.com/dotnet/sdk/blob/main/documentation/specs/runtime-configuration-file.md#appnamedepsjson) +* [Microsoft.Extensions.DependencyModel namespace](https://docs.microsoft.com/dotnet/api/microsoft.extensions.dependencymodel) +* [Microsoft.Extensions.DependencyModel.DependencyContext](https://docs.microsoft.com/dotnet/api/microsoft.extensions.dependencymodel.dependencycontext) + +## Feedback & Contributing + + + +Microsoft.Extensions.DependencyModel is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Diagnostics/Microsoft.Extensions.Diagnostics.sln b/src/libraries/Microsoft.Extensions.Diagnostics/Microsoft.Extensions.Diagnostics.sln index 5e4931ef73fb1d..5b249c373b11c6 100644 --- a/src/libraries/Microsoft.Extensions.Diagnostics/Microsoft.Extensions.Diagnostics.sln +++ b/src/libraries/Microsoft.Extensions.Diagnostics/Microsoft.Extensions.Diagnostics.sln @@ -59,6 +59,20 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Config EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Options", "..\Microsoft.Extensions.Options\ref\Microsoft.Extensions.Options.csproj", "{DBAB1C82-A3A0-4ADC-95BC-B87557C61C42}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Primitives", "..\Microsoft.Extensions.Primitives\ref\Microsoft.Extensions.Primitives.csproj", "{6BB43905-3DBD-47E4-A38F-2BE319300B15}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Primitives", "..\Microsoft.Extensions.Primitives\src\Microsoft.Extensions.Primitives.csproj", "{711B2905-FDC7-4D67-B40B-9DEFF042CB01}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Options", "..\Microsoft.Extensions.Options\src\Microsoft.Extensions.Options.csproj", "{B233AB55-788C-48B6-9557-098B8D0DDBFF}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Configuration.Binder", "..\Microsoft.Extensions.Configuration.Binder\src\Microsoft.Extensions.Configuration.Binder.csproj", "{ECF14067-8633-4DDA-8EAE-124989F8E09E}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Configuration.Binder", "..\Microsoft.Extensions.Configuration.Binder\ref\Microsoft.Extensions.Configuration.Binder.csproj", "{D835A0A8-C213-461F-8B41-6F2715DBEC43}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Options.ConfigurationExtensions", "..\Microsoft.Extensions.Options.ConfigurationExtensions\ref\Microsoft.Extensions.Options.ConfigurationExtensions.csproj", "{F2C0D619-8CAF-4F81-B681-3F75AF79661F}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.Configuration", "..\Microsoft.Extensions.Configuration\ref\Microsoft.Extensions.Configuration.csproj", "{57AF678A-3671-4B9E-9608-053E2197D0A4}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -161,6 +175,34 @@ Global {DBAB1C82-A3A0-4ADC-95BC-B87557C61C42}.Debug|Any CPU.Build.0 = Debug|Any CPU {DBAB1C82-A3A0-4ADC-95BC-B87557C61C42}.Release|Any CPU.ActiveCfg = Release|Any CPU {DBAB1C82-A3A0-4ADC-95BC-B87557C61C42}.Release|Any CPU.Build.0 = Release|Any CPU + {6BB43905-3DBD-47E4-A38F-2BE319300B15}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6BB43905-3DBD-47E4-A38F-2BE319300B15}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6BB43905-3DBD-47E4-A38F-2BE319300B15}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6BB43905-3DBD-47E4-A38F-2BE319300B15}.Release|Any CPU.Build.0 = Release|Any CPU + {711B2905-FDC7-4D67-B40B-9DEFF042CB01}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {711B2905-FDC7-4D67-B40B-9DEFF042CB01}.Debug|Any CPU.Build.0 = Debug|Any CPU + {711B2905-FDC7-4D67-B40B-9DEFF042CB01}.Release|Any CPU.ActiveCfg = Release|Any CPU + {711B2905-FDC7-4D67-B40B-9DEFF042CB01}.Release|Any CPU.Build.0 = Release|Any CPU + {B233AB55-788C-48B6-9557-098B8D0DDBFF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B233AB55-788C-48B6-9557-098B8D0DDBFF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B233AB55-788C-48B6-9557-098B8D0DDBFF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B233AB55-788C-48B6-9557-098B8D0DDBFF}.Release|Any CPU.Build.0 = Release|Any CPU + {ECF14067-8633-4DDA-8EAE-124989F8E09E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {ECF14067-8633-4DDA-8EAE-124989F8E09E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {ECF14067-8633-4DDA-8EAE-124989F8E09E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {ECF14067-8633-4DDA-8EAE-124989F8E09E}.Release|Any CPU.Build.0 = Release|Any CPU + {D835A0A8-C213-461F-8B41-6F2715DBEC43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D835A0A8-C213-461F-8B41-6F2715DBEC43}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D835A0A8-C213-461F-8B41-6F2715DBEC43}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D835A0A8-C213-461F-8B41-6F2715DBEC43}.Release|Any CPU.Build.0 = Release|Any CPU + {F2C0D619-8CAF-4F81-B681-3F75AF79661F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F2C0D619-8CAF-4F81-B681-3F75AF79661F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F2C0D619-8CAF-4F81-B681-3F75AF79661F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F2C0D619-8CAF-4F81-B681-3F75AF79661F}.Release|Any CPU.Build.0 = Release|Any CPU + {57AF678A-3671-4B9E-9608-053E2197D0A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {57AF678A-3671-4B9E-9608-053E2197D0A4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {57AF678A-3671-4B9E-9608-053E2197D0A4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {57AF678A-3671-4B9E-9608-053E2197D0A4}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -190,6 +232,13 @@ Global {A2853038-B04A-4BAA-B0B4-0481457003B8} = {A447D0CB-601B-479E-A2B2-76E48F5D4D61} {A77E804D-4576-4962-A248-92E538ED997C} = {A447D0CB-601B-479E-A2B2-76E48F5D4D61} {DBAB1C82-A3A0-4ADC-95BC-B87557C61C42} = {9BF048D0-411D-4C2A-8C32-3A3255501D27} + {6BB43905-3DBD-47E4-A38F-2BE319300B15} = {9BF048D0-411D-4C2A-8C32-3A3255501D27} + {711B2905-FDC7-4D67-B40B-9DEFF042CB01} = {A447D0CB-601B-479E-A2B2-76E48F5D4D61} + {B233AB55-788C-48B6-9557-098B8D0DDBFF} = {A447D0CB-601B-479E-A2B2-76E48F5D4D61} + {ECF14067-8633-4DDA-8EAE-124989F8E09E} = {A447D0CB-601B-479E-A2B2-76E48F5D4D61} + {D835A0A8-C213-461F-8B41-6F2715DBEC43} = {9BF048D0-411D-4C2A-8C32-3A3255501D27} + {F2C0D619-8CAF-4F81-B681-3F75AF79661F} = {9BF048D0-411D-4C2A-8C32-3A3255501D27} + {57AF678A-3671-4B9E-9608-053E2197D0A4} = {9BF048D0-411D-4C2A-8C32-3A3255501D27} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {7D279EE5-E38F-4125-AE82-6ADE52D72F26} diff --git a/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/ListenerSubscription.cs b/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/ListenerSubscription.cs index 24c6d0b7b7f297..c6eb978af5441a 100644 --- a/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/ListenerSubscription.cs +++ b/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/ListenerSubscription.cs @@ -165,36 +165,39 @@ internal static bool RuleMatches(InstrumentRule rule, Instrument instrument, str // Meter - var ruleMeterName = rule.MeterName.AsSpan(); - // Don't allow "*" anywhere except at the end. - var starIndex = ruleMeterName.IndexOf('*'); - if (starIndex != -1 && starIndex != ruleMeterName.Length - 1) + // The same logic as Microsoft.Extensions.Logging.LoggerRuleSelector.IsBetter for category names + var meterName = rule.MeterName; + if (meterName != null) { - return false; - } - // Rule "System.Net.*" matches meter "System.Net" and "System.Net.Http" - if (ruleMeterName.EndsWith(".*".AsSpan(), StringComparison.Ordinal)) - { - ruleMeterName = ruleMeterName.Slice(0, ruleMeterName.Length - 2); - } - // System.Net* matches System.Net and System.Net.Http - else if (starIndex != -1) - { - ruleMeterName = ruleMeterName.Slice(0, ruleMeterName.Length - 1); - } + const char WildcardChar = '*'; - // Rule "" matches everything - if (ruleMeterName.IsEmpty) - { - return true; + int wildcardIndex = meterName.IndexOf(WildcardChar); + if (wildcardIndex >= 0 && + meterName.IndexOf(WildcardChar, wildcardIndex + 1) >= 0) + { + throw new InvalidOperationException(SR.MoreThanOneWildcard); + } + + ReadOnlySpan prefix, suffix; + if (wildcardIndex < 0) + { + prefix = meterName.AsSpan(); + suffix = default; + } + else + { + prefix = meterName.AsSpan(0, wildcardIndex); + suffix = meterName.AsSpan(wildcardIndex + 1); + } + + if (!instrument.Meter.Name.AsSpan().StartsWith(prefix, StringComparison.OrdinalIgnoreCase) || + !instrument.Meter.Name.AsSpan().EndsWith(suffix, StringComparison.OrdinalIgnoreCase)) + { + return false; + } } - // "System.Net" matches "System.Net" and "System.Net.Http" - return instrument.Meter.Name.AsSpan().StartsWith(ruleMeterName, StringComparison.OrdinalIgnoreCase) - // Exact match +/- ".*" - && (ruleMeterName.Length == instrument.Meter.Name.Length - // Only allow StartsWith on segment boundaries - || instrument.Meter.Name[ruleMeterName.Length] == '.'); + return true; } // Everything must already match the Instrument and listener, or be blank. diff --git a/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/MetricsServiceExtensions.cs b/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/MetricsServiceExtensions.cs index 79a08c78b82e7c..47fa1e0dee8193 100644 --- a/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/MetricsServiceExtensions.cs +++ b/src/libraries/Microsoft.Extensions.Diagnostics/src/Metrics/MetricsServiceExtensions.cs @@ -4,6 +4,7 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Diagnostics.Metrics; using Microsoft.Extensions.Diagnostics.Metrics.Configuration; +using Microsoft.Extensions.Options; using System; using System.Diagnostics.Metrics; @@ -32,7 +33,9 @@ public static IServiceCollection AddMetrics(this IServiceCollection services) services.TryAddSingleton(); // Make sure the subscription manager is started when the host starts. // The host will trigger options validation. - services.AddOptions().Configure((_, manager) => manager.Initialize()).ValidateOnStart(); + services.AddOptions().ValidateOnStart(); + // Make sure this is only registered/run once. + services.TryAddSingleton, SubscriptionActivator>(); services.TryAddSingleton(); @@ -66,5 +69,10 @@ private sealed class MetricsBuilder(IServiceCollection services) : IMetricsBuild } private sealed class NoOpOptions { } + + private sealed class SubscriptionActivator(MetricsSubscriptionManager manager) : IConfigureOptions + { + public void Configure(NoOpOptions options) => manager.Initialize(); + } } } diff --git a/src/libraries/Microsoft.Extensions.Diagnostics/src/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Diagnostics/src/Resources/Strings.resx index ab0d5ca5a461b3..0589f84ba89e7d 100644 --- a/src/libraries/Microsoft.Extensions.Diagnostics/src/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Diagnostics/src/Resources/Strings.resx @@ -1,4 +1,64 @@ - + + + @@ -60,4 +120,7 @@ The meter factory does not allow a custom scope value when creating a meter. - + + Only one wildcard character is allowed in category name. + + \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Diagnostics/tests/ListenerSubscriptionTests.cs b/src/libraries/Microsoft.Extensions.Diagnostics/tests/ListenerSubscriptionTests.cs index da59cffc28d6a6..a0c5d281bfaa0f 100644 --- a/src/libraries/Microsoft.Extensions.Diagnostics/tests/ListenerSubscriptionTests.cs +++ b/src/libraries/Microsoft.Extensions.Diagnostics/tests/ListenerSubscriptionTests.cs @@ -214,13 +214,20 @@ public void RuleCanBeTurnedOffAndOnAgain() [InlineData("", "", "")] [InlineData("*", "", "")] [InlineData("lonG", "", "")] + [InlineData("lonG.", "", "")] [InlineData("lonG*", "", "")] [InlineData("lonG.*", "", "")] + [InlineData("lonG.sil", "", "")] + [InlineData("lonG.sil*", "", "")] [InlineData("lonG.sillY.meteR", "", "")] [InlineData("lonG.sillY.meteR*", "", "")] [InlineData("lonG.sillY.meteR.*", "", "")] + [InlineData("*namE", "", "")] + [InlineData("*.namE", "", "")] + [InlineData("*.sillY.meteR.Name", "", "")] + [InlineData("long*Name", "", "")] + [InlineData("lonG.sillY.meter*MeteR.namE", "", "")] // Shouldn't match, but does, left for compatibility with Logging. [InlineData("lonG.sillY.meteR.namE", "", "")] - [InlineData("lonG.sillY.meteR.namE.*", "", "")] [InlineData("", "instrumenTnamE", "")] [InlineData("lonG.sillY.meteR.namE", "instrumenTnamE", "")] [InlineData("", "", "listeneRnamE")] @@ -237,15 +244,13 @@ public void RuleMatchesTest(string meterName, string instrumentName, string list } [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] - [InlineData("*.*", "", "")] [InlineData("", "*", "")] [InlineData("", "", "*")] - [InlineData("lonG.", "", "")] - [InlineData("lonG.sil", "", "")] - [InlineData("lonG.sil*", "", "")] [InlineData("sillY.meteR.namE", "", "")] + [InlineData(".*", "", "")] + [InlineData("*.", "", "")] + [InlineData("lonG.sillY.meteR.namE.*", "", "")] [InlineData("namE", "", "")] - [InlineData("*.namE", "", "")] [InlineData("wrongMeter", "", "")] [InlineData("wrongMeter", "InstrumentName", "")] [InlineData("wrongMeter", "", "ListenerName")] @@ -261,6 +266,17 @@ public void RuleMatchesNegativeTest(string meterName, string instrumentName, str }, meterName, instrumentName, listenerName).Dispose(); } + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public void MultipleWildcardsThrows() + { + RemoteExecutor.Invoke(() => { + var rule = new InstrumentRule("*.*", null, null, MeterScope.Global, enable: true); + var meter = new Meter("Long.Silly.Meter.Name"); + var instrument = meter.CreateCounter("InstrumentName"); + Assert.Throws< InvalidOperationException>(() => ListenerSubscription.RuleMatches(rule, instrument, "ListenerName", new FakeMeterFactory())); + }).Dispose(); + } + [Theory] [MemberData(nameof(IsMoreSpecificTestData))] public void IsMoreSpecificTest(InstrumentRule rule, InstrumentRule? best, bool isLocalScope) @@ -388,3 +404,8 @@ private class FakeMeterFactory : IMeterFactory } } } + +internal class SR +{ + public static string MoreThanOneWildcard => "More than one wildcard is not allowed in a rule."; +} diff --git a/src/libraries/Microsoft.Extensions.Diagnostics/tests/MetricsSubscriptionManagerTests.cs b/src/libraries/Microsoft.Extensions.Diagnostics/tests/MetricsSubscriptionManagerTests.cs new file mode 100644 index 00000000000000..60049f6820e963 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Diagnostics/tests/MetricsSubscriptionManagerTests.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.Metrics; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Diagnostics.Metrics; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.Extensions.Diagnostics.Tests +{ + public class MetricsSubscriptionManagerTests + { + [Fact] + public void AddMetrics_InitializesListeners() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddMetrics(); // Duplicate call, should not add things twice. + serviceCollection.AddMetrics(l => l.AddListener()); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + // Make sure the subscription manager is started. + serviceProvider.GetRequiredService().Validate(); + + var listeners = serviceProvider.GetRequiredService>(); + + var listener = Assert.Single(listeners); + var fakeListener = Assert.IsType(listener); + Assert.Equal(1, fakeListener.InitializeCount); + } + + private class FakeListener : IMetricsListener + { + public string Name => "Fake"; + public int InitializeCount { get; private set; } + public MeasurementHandlers GetMeasurementHandlers() => new MeasurementHandlers(); + public void Initialize(IObservableInstrumentsSource source) => InitializeCount++; + public bool InstrumentPublished(Instrument instrument, out object? userState) => throw new NotImplementedException(); + public void MeasurementsCompleted(Instrument instrument, object? userState) => throw new NotImplementedException(); + } + } +} diff --git a/src/libraries/Microsoft.Extensions.FileProviders.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.FileProviders.Abstractions/src/PACKAGE.md new file mode 100644 index 00000000000000..bbb9b68a06e4a6 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.FileProviders.Abstractions/src/PACKAGE.md @@ -0,0 +1,51 @@ +## About + + + +Serves as the foundation for creating file providers in .NET, offering core abstractions to develop custom file providers capable of fetching files from various sources. + +## Key Features + + + +* Core abstractions for creating and managing file providers. +* Flexibility to develop custom file providers for fetching files from distinct sources. + +## How to Use + + + +This package is typically used with an implementation of the file provider abstractions, such as `Microsoft.Extensions.FileProviders.Composite` or `Microsoft.Extensions.FileProviders.Physical`. + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.FileProviders.IFileProvider` +* `Microsoft.Extensions.FileProviders.IDirectoryContents` +* `Microsoft.Extensions.FileProviders.IFileInfo` +* `Microsoft.Extensions.FileProviders.NullFileProvider` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/aspnet/core/fundamentals/file-providers) +* [Detect changes with change tokens](https://learn.microsoft.com/aspnet/core/fundamentals/change-tokens) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.fileproviders) + +## Related Packages + + + +* File provider for physical files: [Microsoft.Extensions.FileProviders.Physical](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Physical/) +* File provider for files in embedded resources: [Microsoft.Extensions.FileProviders.Embedded](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Embedded/) +* Composite file and directory providers: [Microsoft.Extensions.FileProviders.Composite](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Composite/) + +## Feedback & Contributing + + + +Microsoft.Extensions.FileProviders.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.FileProviders.Physical/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.FileProviders.Physical/src/PACKAGE.md new file mode 100644 index 00000000000000..6ffcd733120209 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.FileProviders.Physical/src/PACKAGE.md @@ -0,0 +1,70 @@ +## About + + + +Provides an implementation of a physical file provider, facilitating file access and monitoring on the disk. The primary type, [`PhysicalFileProvider`](https://learn.microsoft.com/dotnet/api/microsoft.extensions.fileproviders.physicalfileprovider), enables the lookup of files on disk and can watch for changes either via `FileSystemWatcher` or polling mechanisms. + + +## Key Features + + + +* Easy access and monitoring of files on the disk. +* Ability to watch for file changes either by using `FileSystemWatcher` or through polling. + +## How to Use + + + +This library can be used to look up files on disk and monitor file changes effectively. +Below is an example of how to use the `PhysicalFileProvider` to access files on disk and monitor changes: + +```c# +using Microsoft.Extensions.FileProviders; +using Microsoft.Extensions.FileProviders.Physical; + +using var provider = new PhysicalFileProvider(AppContext.BaseDirectory); + +Environment.SetEnvironmentVariable("DOTNET_USE_POLLING_FILE_WATCHER", "1"); + +var contents = provider.GetDirectoryContents(string.Empty); +foreach (PhysicalFileInfo fileInfo in contents) +{ + Console.WriteLine(fileInfo.PhysicalPath); +} + +var changeToken = provider.Watch("*.txt"); +changeToken.RegisterChangeCallback(_ => Console.WriteLine("Text file changed"), null); + +Console.ReadLine(); +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.FileProviders.PhysicalFileProvider` +* `Microsoft.Extensions.FileProviders.PhysicalDirectoryInfo` +* `Microsoft.Extensions.FileProviders.PhysicalFileInfo` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/aspnet/core/fundamentals/file-providers#physical-file-provider) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.fileproviders.physical) + +## Related Packages + + + +* Abstractions of files and directories: [Microsoft.Extensions.FileProviders.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Abstractions/) +* File system globbing to find files matching a specified pattern: [Microsoft.Extensions.FileSystemGlobbing](https://www.nuget.org/packages/Microsoft.Extensions.FileSystemGlobbing/) + +## Feedback & Contributing + + + +Microsoft.Extensions.FileProviders.Physical is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.FileSystemGlobbing/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.FileSystemGlobbing/src/PACKAGE.md new file mode 100644 index 00000000000000..25bd9129c3968b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.FileSystemGlobbing/src/PACKAGE.md @@ -0,0 +1,52 @@ +## About + + + +Provides support for matching file system names/paths using [glob patterns](https://en.wikipedia.org/wiki/Glob_(programming)). + +## Key Features + + + +* Contains the `Matcher` type, which can be used to match files in the file system based on user-defined patterns. + +## How to Use + + + +Get all matching files: + +```c# +using Microsoft.Extensions.FileSystemGlobbing; + +Matcher matcher = new(); +matcher.AddIncludePatterns(new[] { "*.txt", "*.asciidoc", "*.md" }); + +string searchDirectory = "../starting-folder/"; + +IEnumerable matchingFiles = matcher.GetResultsInFullPath(searchDirectory); + +// Use matchingFiles if there are any found. +// The files in this collection are fully qualified file system paths. +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.FileSystemGlobbing.Matcher` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/file-globbing) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.filesystemglobbing) + +## Feedback & Contributing + + + +Microsoft.Extensions.FileSystemGlobbing is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Hosting.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Hosting.Abstractions/src/PACKAGE.md new file mode 100644 index 00000000000000..f017c0b7193271 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Hosting.Abstractions/src/PACKAGE.md @@ -0,0 +1,43 @@ +## About +Contains abstractions to host user code in an application by encapsulating an application's resources and lifetime functionality including: +- Dependency injection (DI). +- Logging. +- Configuration. +- Starting, stopping and obtaining notifications. + +This package is also used to wire up specific application models like ASP.NET Core that are built on top of hosting. + +## Key Features +* Provides the `BackgroundService` base class and the `IHostedService` interface for implementing worker services. +* Provides interfaces used to configure and start\stop a host. +* Provides types to obtain environment settings such as an application name and paths. + +## How to Use +See the Conceptual documentation below for using `BackgroundService` and `IHostedService` to host worker services. + +## Main Types +The main types provided by this library are: + +* `Microsoft.Extensions.Hosting.BackgroundService` +* `Microsoft.Extensions.Hosting.IHostBuilder` +* `Microsoft.Extensions.Hosting.IHostedService` + +## Additional Documentation +* Conceptual documentation + - [Worker services in .NET](https://learn.microsoft.com/dotnet/core/extensions/workers) + - [Implement the IHostedService interface](https://learn.microsoft.com/dotnet/core/extensions/timer-service) +* API documentation + - [BackgroundService](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.backgroundservice) + - [IHostBuilder](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.ihostbuilder) + - [IHostedService](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.ihostedservice) + +## Related Packages +- `Microsoft.Extensions.Hosting` +- `Microsoft.Extensions.Configuration.Abstractions` +- `Microsoft.Extensions.DependencyInjection.Abstractions` +- `Microsoft.Extensions.Diagnostics.Abstractions` +- `Microsoft.Extensions.FileProviders.Abstractions` +- `Microsoft.Extensions.Logging.Abstractions` + +## Feedback & Contributing +Microsoft.Extensions.Hosting.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/PACKAGE.md new file mode 100644 index 00000000000000..65908c6aaa7b63 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/PACKAGE.md @@ -0,0 +1,44 @@ +## About +Supports using Windows Services with the hosting infrastructure. + +## Key Features +* Can configure a host to be a Windows Service. + +## How to Use +From a Worker Service app created using the Visual Studio template: +```cs +IHost host = Host.CreateDefaultBuilder(args) + .ConfigureServices(services => + { + services.AddHostedService(); + }) + // Configure as a Windows Service + .UseWindowsService(options => + { + options.ServiceName = "My Service"; + }) + .Build(); + +host.Run(); +``` + +## Main Types +The main types provided by this library are: +* `Microsoft.Extensions.Hosting.WindowsServiceLifetimeHostBuilderExtensions` +* `Microsoft.Extensions.Hosting.WindowsServices.WindowsServiceLifetime` + +## Additional Documentation +* [WindowsServiceLifetime](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.windowsservices.windowsservicelifetime) +* [WindowsServiceLifetimeHostBuilderExtensions](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.windowsservicelifetimehostbuilderextensions) +* [Create Windows Service using BackgroundService](https://learn.microsoft.com/dotnet/core/extensions/windows-service) +* [Host ASP.NET Core in a Windows Service](https://learn.microsoft.com/aspnet/core/host-and-deploy/windows-service?tabs=visual-studio) + +## Related Packages +- `Microsoft.Extensions.Hosting` +- `System.ServiceProcess.ServiceController` + +## Feedback & Contributing + + + +Microsoft.Extensions.Hosting.WindowsServices is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Hosting/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Hosting/src/PACKAGE.md new file mode 100644 index 00000000000000..836433b1a4bba5 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Hosting/src/PACKAGE.md @@ -0,0 +1,81 @@ +## About + + + +Contains the .NET Generic Host `HostBuilder` which layers on the `Microsoft.Extensions.Hosting.Abstractions` package. + +## Key Features + + + +* Contains the .NET Generic Host `HostBuilder`. + +## How to Use + + + +For a console app project: +```C# + using (IHost host = new HostBuilder().Build()) + { + var lifetime = host.Services.GetRequiredService(); + + lifetime.ApplicationStarted.Register(() => + { + Console.WriteLine("Started"); + }); + lifetime.ApplicationStopping.Register(() => + { + Console.WriteLine("Stopping firing"); + Console.WriteLine("Stopping end"); + }); + lifetime.ApplicationStopped.Register(() => + { + Console.WriteLine("Stopped firing"); + Console.WriteLine("Stopped end"); + }); + + host.Start(); + + // Listens for Ctrl+C. + host.WaitForShutdown(); + } +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Host`. +* `Microsoft.Extensions.Hosting.HostApplicationBuilder` +* `Microsoft.Extensions.Hosting.HostBuilder` +* `Microsoft.Extensions.Hosting.IHostedService` +* `Microsoft.Extensions.Hosting.IHostedLifecycleService` + +## Additional Documentation + + + +* [Generic host](https://learn.microsoft.com/dotnet/core/extensions/generic-host) +* API documentation + - [Host](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.host) + - [HostApplicationBuilder](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.hostapplicationbuilder) + - [HostBuilder](https://learn.microsoft.com/dotnet/api/microsoft.extensions.hosting.hostbuilder) + +## Related Packages + + + +- `Microsoft.Extensions.Configuration` +- `Microsoft.Extensions.DependencyInjection` +- `Microsoft.Extensions.Hosting.Abstractions` +- `Microsoft.Extensions.Logging` +- `Microsoft.Extensions.Options` + +## Feedback & Contributing + + + +Microsoft.Extensions.Hosting is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Http/ref/Microsoft.Extensions.Http.cs b/src/libraries/Microsoft.Extensions.Http/ref/Microsoft.Extensions.Http.cs index 9071751734f946..51891e8ebcf754 100644 --- a/src/libraries/Microsoft.Extensions.Http/ref/Microsoft.Extensions.Http.cs +++ b/src/libraries/Microsoft.Extensions.Http/ref/Microsoft.Extensions.Http.cs @@ -57,7 +57,7 @@ public static partial class HttpClientFactoryServiceCollectionExtensions public static Microsoft.Extensions.DependencyInjection.IHttpClientBuilder AddHttpClient(this Microsoft.Extensions.DependencyInjection.IServiceCollection services, string name, System.Action configureClient) where TClient : class where TImplementation : class, TClient { throw null; } public static Microsoft.Extensions.DependencyInjection.IHttpClientBuilder AddHttpClient(this Microsoft.Extensions.DependencyInjection.IServiceCollection services, string name, System.Func factory) where TClient : class where TImplementation : class, TClient { throw null; } public static Microsoft.Extensions.DependencyInjection.IHttpClientBuilder AddHttpClient(this Microsoft.Extensions.DependencyInjection.IServiceCollection services, string name, System.Func factory) where TClient : class where TImplementation : class, TClient { throw null; } - public static Microsoft.Extensions.DependencyInjection.IServiceCollection ConfigureHttpClientDefaults(Microsoft.Extensions.DependencyInjection.IServiceCollection services, System.Action configure) { throw null; } + public static Microsoft.Extensions.DependencyInjection.IServiceCollection ConfigureHttpClientDefaults(this Microsoft.Extensions.DependencyInjection.IServiceCollection services, System.Action configure) { throw null; } } public partial interface IHttpClientBuilder { diff --git a/src/libraries/Microsoft.Extensions.Http/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Http/src/PACKAGE.md new file mode 100644 index 00000000000000..294cb308cc7a9b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Http/src/PACKAGE.md @@ -0,0 +1,81 @@ +## About + + + +[Microsoft.Extensions.Http](https://www.nuget.org/packages/Microsoft.Extensions.Http) package provides `AddHttpClient` extension methods for `IServiceCollection`, `IHttpClientFactory` interface and its default implementation. This provides the ability to set up named `HttpClient` configurations in a DI container and later retrieve them via an injected `IHttpClientFactory` instance. + +## Key Features + + + +* The package allows to fluently set up multiple `HttpClient` configurations for applications that use DI via `AddHttpClient` extension method. +* `HttpClientFactory` caches `HttpMessageHandler` instances per configuration name, which allows to reuse resources between `HttpClient` instances to avoid port exhaustion. +* `HttpClientFactory` manages lifetime of `HttpMessageHandler` instances and recycles connections to track DNS changes. + +## How to Use + + + +Note that lifetime management of `HttpClient` instances created by `HttpClientFactory` is completely different from instances created manually. The strategies are to use either short-lived clients created by `HttpClientFactory` or long-lived clients with `PooledConnectionLifetime` set up. For more information, see the [HttpClient lifetime management section](https://learn.microsoft.com/dotnet/core/extensions/httpclient-factory#httpclient-lifetime-management) in the conceptual docs and [Guidelines for using HTTP clients](https://learn.microsoft.com/dotnet/fundamentals/networking/http/httpclient-guidelines). + +### Configuring HttpClient + +```c# +builder.Services.AddHttpClient("foo"); // adding an HttpClient named "foo" with a default configuration + +builder.Services.AddHttpClient("example", c => c.BaseAddress = new Uri("https://www.example.com")) // configuring HttpClient itself + .AddHttpMessageHandler() // adding additional delegating handlers to form a message handler chain + .ConfigurePrimaryHttpMessageHandler(b => new HttpClientHandler() { AllowAutoRedirect = false }) // configuring primary handler + .SetHandlerLifetime(TimeSpan.FromMinutes(30)); // changing the handler recycling interval +``` + +### Using the configured HttpClient + +```c# +public class MyService +{ + public MyService(IHttpClientFactory httpClientFactory) + { + _httpClientFactory = httpClientFactory; // injecting the factory + } + + private Task GetExampleAsync(Uri uri, CancellationToken ct) + { + HttpClient exampleClient = _httpClientFactory.CreateClient("example"); // creating the client for the specified name + return exampleClient.GetStringAsync(uri, ct); // using the client + } +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `IHttpClientFactory` +* `IHttpMessageHandlerFactory` +* `HttpClientFactoryServiceCollectionExtensions` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/httpclient-factory) + * Also see [HttpClient guidelines](https://learn.microsoft.com/dotnet/fundamentals/networking/http/httpclient-guidelines) conceptual doc +* [API documentation](https://learn.microsoft.com/dotnet/api/system.net.http?view=dotnet-plat-ext-7.0) + * Also see [`AddHttpClient` extension method](https://learn.microsoft.com/dotnet/api/microsoft.extensions.dependencyinjection.httpclientfactoryservicecollectionextensions?view=dotnet-plat-ext-7.0) API doc + +## Related Packages + + + +* [Microsoft.Extensions.DependencyInjection](https://www.nuget.org/packages/Microsoft.Extensions.DependencyInjection/) +* [Microsoft.Extensions.Http.Polly](https://www.nuget.org/packages/Microsoft.Extensions.Http.Polly) +* [Microsoft.Extensions.Http.Telemetry](https://www.nuget.org/packages/Microsoft.Extensions.Http.Telemetry) + +## Feedback & Contributing + + + +Microsoft.Extensions.Http is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/DiagnosticDescriptors.cs b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/DiagnosticDescriptors.cs index ed51ed6164188d..409ab6de3f51c9 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/DiagnosticDescriptors.cs +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/DiagnosticDescriptors.cs @@ -200,5 +200,13 @@ public static class DiagnosticDescriptors category: "LoggingGenerator", DiagnosticSeverity.Warning, isEnabledByDefault: true); + + public static DiagnosticDescriptor LoggingUnsupportedLanguageVersion { get; } = DiagnosticDescriptorHelper.Create( + id: "SYSLIB1026", + title: new LocalizableResourceString(nameof(SR.LoggingUnsupportedLanguageVersionTitle), SR.ResourceManager, typeof(FxResources.Microsoft.Extensions.Logging.Generators.SR)), + messageFormat: new LocalizableResourceString(nameof(SR.LoggingUnsupportedLanguageVersionMessageFormat), SR.ResourceManager, typeof(FxResources.Microsoft.Extensions.Logging.Generators.SR)), + category: "LoggingGenerator", + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true); } } diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/LoggerMessageGenerator.Parser.cs b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/LoggerMessageGenerator.Parser.cs index 2b5b8ca6a1a498..038b71bef0759b 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/LoggerMessageGenerator.Parser.cs +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/LoggerMessageGenerator.Parser.cs @@ -591,6 +591,13 @@ static bool IsAllowedKind(SyntaxKind kind) => } } + if (results.Count > 0 && _compilation is CSharpCompilation { LanguageVersion : LanguageVersion version and < LanguageVersion.CSharp8 }) + { + // we only support C# 8.0 and above + Diag(DiagnosticDescriptors.LoggingUnsupportedLanguageVersion, null, version.ToDisplayString(), LanguageVersion.CSharp8.ToDisplayString()); + return Array.Empty(); + } + return results; } diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/Strings.resx index 529263fda731fb..602e6f4730d216 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/Strings.resx @@ -231,4 +231,10 @@ Logging method contains malformed format strings + + C# language version not supported by the source generator. + + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.cs.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.cs.xlf index ac5ddc07ceb05f..6f3218ac6dd5b7 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.cs.xlf @@ -67,6 +67,16 @@ Metody protokolování musí být statické. + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + Zdrojový generátor protokolování není k dispozici v jazyce C# {0}. Použijte prosím jazykovou verzi {1} nebo vyšší. + + + + C# language version not supported by the source generator. + Zdrojový generátor nepodporuje jazykovou verzi jazyka C#. + + Logging method '{0}' contains malformed format strings Metoda protokolování {0} obsahuje řetězce s poškozeným formátem. diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.de.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.de.xlf index c57008e516a5be..b838765022a4ea 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.de.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.de.xlf @@ -67,6 +67,16 @@ Protokollierungsmethoden müssen statisch sein. + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + Der Protokollierungsquellgenerator ist in C# {0} nicht verfügbar. Verwenden Sie die Sprachversion {1} oder höher. + + + + C# language version not supported by the source generator. + Die C#-Sprachversion wird vom Quellgenerator nicht unterstützt. + + Logging method '{0}' contains malformed format strings Die Protokollierungsmethode „{0}“ enthält nicht wohlgeformte Formatzeichenfolgen. diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.es.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.es.xlf index 29cce5c24f3f77..30000484775d0a 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.es.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.es.xlf @@ -67,6 +67,16 @@ Los métodos de registro deben ser estáticos + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + El generador de origen de registro no está disponible en C# {0}. Use la versión de idioma {1} o superior. + + + + C# language version not supported by the source generator. + La versión del idioma C# no es compatible con el generador de origen. + + Logging method '{0}' contains malformed format strings El método de registro “{0}” contiene cadenas con formato incorrecto diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.fr.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.fr.xlf index afd2c184fc42fa..04ccccc8536aa0 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.fr.xlf @@ -67,6 +67,16 @@ Les méthodes de journalisation doivent être statiques + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + Le générateur de source de connexion n'est pas disponible en C# « {0} ». Veuillez utiliser la version linguistique {1} ou supérieure. + + + + C# language version not supported by the source generator. + Version du langage C# non prise en charge par le générateur de source. + + Logging method '{0}' contains malformed format strings La méthode de journalisation « {0} »contient des chaînes de format incorrectes diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.it.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.it.xlf index 99b2e817511ae8..0c007fdce54828 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.it.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.it.xlf @@ -67,6 +67,16 @@ I metodi di registrazione devono essere statici + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + Il generatore dell'origine di registrazione non è disponibile in C# {0}. Usare la versione del linguaggio {1} o successiva. + + + + C# language version not supported by the source generator. + Versione del linguaggio C# non supportata dal generatore di origine. + + Logging method '{0}' contains malformed format strings Il metodo di registrazione '{0}' contiene stringhe in formato non valido diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ja.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ja.xlf index a8c79393a42f23..36f5b1baad528f 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ja.xlf @@ -67,6 +67,16 @@ ログ メソッドは静的である必要があります + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + ログ ソース ジェネレーターは、C# {0} では使用できません。言語バージョン {1} 以降を使用してください。 + + + + C# language version not supported by the source generator. + ソース ジェネレーターでサポートされていない C# 言語バージョン。 + + Logging method '{0}' contains malformed format strings ログ メソッド '{0}' に、形式の正しくない文字列が含まれています diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ko.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ko.xlf index 0b77d9c65adb5f..547c49b3e33dc9 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ko.xlf @@ -67,6 +67,16 @@ 로깅 메서드는 정적이어야 함 + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + 로깅 원본 생성기는 C# {0}에서 사용할 수 없습니다. {1} 이상의 언어 버전을 사용하세요. + + + + C# language version not supported by the source generator. + 원본 생성기에서 지원되지 않는 C# 언어 버전입니다. + + Logging method '{0}' contains malformed format strings 로깅 메서드 '{0}'에 잘못된 형식의 문자열이 포함되어 있습니다. diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pl.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pl.xlf index 2784dfa601576b..a40b72d24ffcea 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pl.xlf @@ -67,6 +67,16 @@ Metody rejestrowania muszą być statyczne + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + Generator źródła rejestrowania nie jest dostępny w języku C# {0}. Użyj wersji językowej {1} lub nowszej. + + + + C# language version not supported by the source generator. + Wersja języka C# nie jest obsługiwana przez generator źródła. + + Logging method '{0}' contains malformed format strings Metoda rejestrowania „{0}” zawiera źle sformułowane ciągi formatu diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pt-BR.xlf index 22b9fabb3b2965..0560fb6b5907c8 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.pt-BR.xlf @@ -67,6 +67,16 @@ Os métodos de registro em log devem ser estáticos + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + O gerador de fonte de log não está disponível em C# {0}. Use a versão do idioma {1} ou superior. + + + + C# language version not supported by the source generator. + Versão da linguagem C# não suportada pelo gerador de origem. + + Logging method '{0}' contains malformed format strings O método de registro '{0}' contém strings de formato malformado diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ru.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ru.xlf index 4446666caf4828..589635d5721bb9 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.ru.xlf @@ -67,6 +67,16 @@ Методы ведения журнала должны быть статическими + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + Генератор исходного кода ведения журнала недоступен в C# {0}. Используйте языковую версию {1} или более позднюю. + + + + C# language version not supported by the source generator. + Версия языка C# не поддерживается генератором исходного кода. + + Logging method '{0}' contains malformed format strings Метод ведения журнала событий "{0}" содержит строки неправильного формата diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.tr.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.tr.xlf index fb21d7e0bd62ed..c40d99aa4602c5 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.tr.xlf @@ -67,6 +67,16 @@ Günlüğe kaydetme yöntemleri statik olmalıdır + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + Günlüğe kaydetme kaynak oluşturucusu C# {0} sürümünde kullanılamıyor. Lütfen dil sürümü {1} veya üstü bir sürümü kullanın. + + + + C# language version not supported by the source generator. + C# dil sürümü kaynak oluşturucu tarafından desteklenmiyor. + + Logging method '{0}' contains malformed format strings '{0}' günlüğe kaydetme yöntemi hatalı biçimlendirilmiş biçim dizeleri içeriyor diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hans.xlf index 021f321ddf4b9f..04c39c8843a1e1 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hans.xlf @@ -67,6 +67,16 @@ 日志记录方法必须为静态方法 + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + 记录源生成器在 C#“{0}”中不可用。请使用{1}或更高版本的语言版本。 + + + + C# language version not supported by the source generator. + 源生成器不支持 C# 语言版本。 + + Logging method '{0}' contains malformed format strings 日志记录方法“{0}”包含格式错误的字符串 diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hant.xlf index c50c7054e6ee2a..af008cf098ff8a 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/gen/Resources/xlf/Strings.zh-Hant.xlf @@ -67,6 +67,16 @@ 記錄方法必須是靜態 + + The Logging source generator is not available in C# {0}. Please use language version {1} or greater. + 記錄來源產生器在 C# {0} 中不可用。請使用語言版本 {1} 或更新版本。 + + + + C# language version not supported by the source generator. + 來源產生器不支援 C# 語言版本。 + + Logging method '{0}' contains malformed format strings 記錄方法 '{0}' 包含格式錯誤的格式字串 diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/src/Microsoft.Extensions.Logging.Abstractions.csproj b/src/libraries/Microsoft.Extensions.Logging.Abstractions/src/Microsoft.Extensions.Logging.Abstractions.csproj index db0c78fb49df05..085cade3966b96 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/src/Microsoft.Extensions.Logging.Abstractions.csproj +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/src/Microsoft.Extensions.Logging.Abstractions.csproj @@ -4,8 +4,6 @@ $(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum) true true - - $(NoWarn);AD0001 true Logging abstractions for Microsoft.Extensions.Logging. diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Logging.Abstractions/src/PACKAGE.md new file mode 100644 index 00000000000000..400958a0a194a4 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/src/PACKAGE.md @@ -0,0 +1,164 @@ +## About + + + +`Microsoft.Extensions.Logging.Abstractions` provides abstractions of logging. Interfaces defined in this package are implemented by classes in [Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging/) and other logging packages. + +This package includes a logging source generator that produces highly efficient and optimized code for logging message methods. + +## Key Features + + + +* Define main logging abstraction interfaces like ILogger, ILoggerFactory, ILoggerProvider, etc. + +## How to Use + + + +#### Custom logger provider implementation example + +```C# +using Microsoft.Extensions.Logging; + +public sealed class ColorConsoleLogger : ILogger +{ + private readonly string _name; + private readonly Func _getCurrentConfig; + + public ColorConsoleLogger( + string name, + Func getCurrentConfig) => + (_name, _getCurrentConfig) = (name, getCurrentConfig); + + public IDisposable? BeginScope(TState state) where TState : notnull => default!; + + public bool IsEnabled(LogLevel logLevel) => + _getCurrentConfig().LogLevelToColorMap.ContainsKey(logLevel); + + public void Log( + LogLevel logLevel, + EventId eventId, + TState state, + Exception? exception, + Func formatter) + { + if (!IsEnabled(logLevel)) + { + return; + } + + ColorConsoleLoggerConfiguration config = _getCurrentConfig(); + if (config.EventId == 0 || config.EventId == eventId.Id) + { + ConsoleColor originalColor = Console.ForegroundColor; + + Console.ForegroundColor = config.LogLevelToColorMap[logLevel]; + Console.WriteLine($"[{eventId.Id,2}: {logLevel,-12}]"); + + Console.ForegroundColor = originalColor; + Console.Write($" {_name} - "); + + Console.ForegroundColor = config.LogLevelToColorMap[logLevel]; + Console.Write($"{formatter(state, exception)}"); + + Console.ForegroundColor = originalColor; + Console.WriteLine(); + } + } +} + +``` + +#### Create logs + +```csharp + +// Worker class that uses logger implementation of teh interface ILogger + +public sealed class Worker : BackgroundService +{ + private readonly ILogger _logger; + + public Worker(ILogger logger) => + _logger = logger; + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + while (!stoppingToken.IsCancellationRequested) + { + _logger.LogInformation("Worker running at: {time}", DateTimeOffset.UtcNow); + await Task.Delay(1_000, stoppingToken); + } + } +} + +``` + +#### Use source generator + +```csharp +public static partial class Log +{ + [LoggerMessage( + EventId = 0, + Level = LogLevel.Critical, + Message = "Could not open socket to `{hostName}`")] + public static partial void CouldNotOpenSocket(this ILogger logger, string hostName); +} + +public partial class InstanceLoggingExample +{ + private readonly ILogger _logger; + + public InstanceLoggingExample(ILogger logger) + { + _logger = logger; + } + + [LoggerMessage( + EventId = 0, + Level = LogLevel.Critical, + Message = "Could not open socket to `{hostName}`")] + public partial void CouldNotOpenSocket(string hostName); +} + +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Logging.ILogger` +* `Microsoft.Extensions.Logging.ILoggerProvider` +* `Microsoft.Extensions.Logging.ILoggerFactory` +* `Microsoft.Extensions.Logging.ILogger` +* `Microsoft.Extensions.Logging.LogLevel` +* `Microsoft.Extensions.Logging.Logger` +* `Microsoft.Extensions.Logging.LoggerMessage` +* `Microsoft.Extensions.Logging.Abstractions.NullLogger` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/logging) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging) + +## Related Packages + + +[Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging) +[Microsoft.Extensions.Logging.Console](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Console) +[Microsoft.Extensions.Logging.Debug](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Debug) +[Microsoft.Extensions.Logging.EventSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventSource) +[Microsoft.Extensions.Logging.EventLog](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventLog) +[Microsoft.Extensions.Logging.TraceSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.TraceSource) + +## Feedback & Contributing + + + +Microsoft.Extensions.Logging.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Logging.Abstractions/tests/Microsoft.Extensions.Logging.Generators.Tests/LoggerMessageGeneratorParserTests.cs b/src/libraries/Microsoft.Extensions.Logging.Abstractions/tests/Microsoft.Extensions.Logging.Generators.Tests/LoggerMessageGeneratorParserTests.cs index 462111ad00c372..dc677e81cfc4a0 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Abstractions/tests/Microsoft.Extensions.Logging.Generators.Tests/LoggerMessageGeneratorParserTests.cs +++ b/src/libraries/Microsoft.Extensions.Logging.Abstractions/tests/Microsoft.Extensions.Logging.Generators.Tests/LoggerMessageGeneratorParserTests.cs @@ -4,10 +4,12 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using SourceGenerators.Tests; using Xunit; @@ -108,7 +110,7 @@ partial class C {{ [LoggerMessage({argumentList})] static partial void M1(ILogger logger, string foo); - + [LoggerMessage({argumentList})] static partial void M2(ILogger logger, LogLevel level, string foo); }} @@ -911,6 +913,56 @@ static partial void M1(ILogger logger) Assert.Equal(DiagnosticDescriptors.LoggingMethodHasBody.Id, diagnostics[0].Id); } + [Fact] + public async Task LanguageVersionTest() + { + string source = """ + using Microsoft.Extensions.Logging; + + internal partial class Program + { + static void Main() { } + + [LoggerMessage( + EventId = 0, + Level = LogLevel.Critical, + Message = "Could not open socket to `{hostName}`")] + static partial void CouldNotOpenSocket(ILogger logger, string hostName); + } + """; + + Assembly[]? refs = new[] { typeof(ILogger).Assembly, typeof(LoggerMessageAttribute).Assembly }; + + // Run the generator with C# 7.0 and verify that it fails. + var (diagnostics, generatedSources) = await RoslynTestUtils.RunGenerator( + new LoggerMessageGenerator(), refs, new[] { source }, includeBaseReferences: true, LanguageVersion.CSharp7).ConfigureAwait(false); + + Assert.NotEmpty(diagnostics); + Assert.Equal("SYSLIB1026", diagnostics[0].Id); + Assert.Empty(generatedSources); + + // Run the generator with C# 8.0 and verify that it succeeds. + (diagnostics, generatedSources) = await RoslynTestUtils.RunGenerator( + new LoggerMessageGenerator(), refs, new[] { source }, includeBaseReferences: true, LanguageVersion.CSharp8).ConfigureAwait(false); + + Assert.Empty(diagnostics); + Assert.Single(generatedSources); + + // Compile the generated code with C# 7.0 and verify that it fails. + CSharpParseOptions parseOptions = new CSharpParseOptions(LanguageVersion.CSharp7); + SyntaxTree syntaxTree = SyntaxFactory.ParseSyntaxTree(generatedSources[0].SourceText.ToString(), parseOptions); + var diags = syntaxTree.GetDiagnostics().ToArray(); + Assert.Equal(1, diags.Length); + // error CS8107: Feature 'nullable reference types' is not available in C# 7.0. Please use language version 8.0 or greater. + Assert.Equal("CS8107", diags[0].Id); + + // Compile the generated code with C# 8.0 and verify that it succeeds. + parseOptions = new CSharpParseOptions(LanguageVersion.CSharp8); + syntaxTree = SyntaxFactory.ParseSyntaxTree(generatedSources[0].SourceText.ToString(), parseOptions); + diags = syntaxTree.GetDiagnostics().ToArray(); + Assert.Equal(0, diags.Length); + } + private static async Task> RunGenerator( string code, bool wrap = true, diff --git a/src/libraries/Microsoft.Extensions.Logging.Configuration/src/Microsoft.Extensions.Logging.Configuration.csproj b/src/libraries/Microsoft.Extensions.Logging.Configuration/src/Microsoft.Extensions.Logging.Configuration.csproj index 545e2867e4de6c..820eb7fa062e72 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Configuration/src/Microsoft.Extensions.Logging.Configuration.csproj +++ b/src/libraries/Microsoft.Extensions.Logging.Configuration/src/Microsoft.Extensions.Logging.Configuration.csproj @@ -3,6 +3,9 @@ $(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum) true + $(Features);InterceptorsPreview + + $(InterceptorsPreviewNamespaces);Microsoft.Extensions.Configuration.Binder.SourceGeneration true true Configuration support for Microsoft.Extensions.Logging. diff --git a/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj b/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj index abc5c9d9792eef..e83340eb0eae55 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj +++ b/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj @@ -7,7 +7,12 @@ $(DefineConstants);NO_SUPPRESS_GC_TRANSITION true true + $(Features);InterceptorsPreview + + $(InterceptorsPreviewNamespaces);Microsoft.Extensions.Configuration.Binder.SourceGeneration true + + $(NoWarn);SYSLIB1100;SYSLIB1101 Console logger provider implementation for Microsoft.Extensions.Logging. diff --git a/src/libraries/Microsoft.Extensions.Logging.Console/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Logging.Console/src/PACKAGE.md new file mode 100644 index 00000000000000..69823b6d4d3faa --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Logging.Console/src/PACKAGE.md @@ -0,0 +1,106 @@ +## About + + +`Microsoft.Extensions.Logging.Console` provides a Console logger provider implementation for Microsoft.Extensions.Logging. It provides extension methods for the [ILoggingBuilder](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.iloggingbuilder) and [ILoggerProviderConfiguration](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.configuration.iloggerproviderconfiguration-1) classes. + +## Key Features + + + +* Allow logging to the console using the [Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging/) package. +* Provide extension methods for the [ILoggingBuilder](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.iloggingbuilder) and [ILoggerProviderConfiguration](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.configuration.iloggerproviderconfiguration-1) classes. + +## How to Use + + +```csharp +using System; +using Microsoft.Extensions.Logging; + +namespace ConsoleLoggerSample +{ + class Program + { + static void Main(string[] args) + { + // Create a logger factory with a console provider + using ILoggerFactory loggerFactory = LoggerFactory.Create(builder => builder.AddConsole()); + + // Create a logger with the category name of the current class + ILogger logger = loggerFactory.CreateLogger(); + + // Log some messages with different log levels and message templates + logger.LogTrace("This is a trace message."); + logger.LogDebug("This is a debug message."); + logger.LogInformation("Hello {Name}!", "World"); + logger.LogWarning("This is a warning message."); + logger.LogError("This is an error message."); + logger.LogCritical("This is a critical message."); + + // Use structured logging to capture complex data + var person = new Person { Name = "Alice", Age = 25 }; + logger.LogInformation("Created a new person: {@Person}", person); + + // Use exception logging to capture the details of an exception + try + { + throw new Exception("Something went wrong."); + } + catch (Exception ex) + { + logger.LogError(ex, "An exception occurred."); + } + + Console.WriteLine("Press any key to exit."); + Console.ReadKey(); + } + } + + // A simple class to demonstrate structured logging + class Person + { + public string Name { get; set; } + public int Age { get; set; } + } +} + +``` + +## Main Types + + + +The main types provided by this library are: + +* `ConsoleLoggerProvider` +* `ConsoleLoggerSettings` +* `ConsoleLoggerOptions` +* `ConsoleLoggerExtensions` +* `ConsoleFormatter` +* `ConsoleFormatterOptions` +* `JsonConsoleFormatterOptions` +* `SimpleConsoleFormatterOptions` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/logging) +* [Console log formatter](https://learn.microsoft.com/dotnet/core/extensions/console-log-formatter) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging) + +## Related Packages + + +[Microsoft.Extensions.Logging.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Abstractions) +[Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging) +[Microsoft.Extensions.Logging.Debug](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Debug) +[Microsoft.Extensions.Logging.EventSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventSource) +[Microsoft.Extensions.Logging.EventLog](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventLog) +[Microsoft.Extensions.Logging.TraceSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.TraceSource) + +## Feedback & Contributing + + + +Microsoft.Extensions.Logging.Console is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Logging.Debug/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Logging.Debug/src/PACKAGE.md new file mode 100644 index 00000000000000..97b378f4350c57 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Logging.Debug/src/PACKAGE.md @@ -0,0 +1,100 @@ +## About + + +`Microsoft.Extensions.Logging.Debug` provides a Debug output logger provider implementation for Microsoft.Extensions.Logging. This logger logs messages to a debugger monitor by writing messages with `System.Diagnostics.Debug.WriteLine()`. + +## Key Features + + + +* Allow logging to the debugger output. +* Provide extensions method for the [ILoggingBuilder](https://docs.microsoft.com/dotnet/api/microsoft.extensions.logging.iloggingbuilder) class to easily enable this Debug logger. + +## How to Use + + +```csharp +using System; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Debug; + +namespace DebugLoggerSample +{ + class Program + { + static void Main(string[] args) + { + // Create a logger factory with a debug provider + using ILoggerFactory loggerFactory = LoggerFactory.Create(builder => builder.AddDebug()); + + // Create a logger with the category name of the current class + ILogger logger = loggerFactory.CreateLogger(); + + // Log some messages with different log levels and message templates + logger.LogTrace("This is a trace message."); + logger.LogDebug("This is a debug message."); + logger.LogInformation("Hello {Name}!", "World"); + logger.LogWarning("This is a warning message."); + logger.LogError("This is an error message."); + logger.LogCritical("This is a critical message."); + + // Use structured logging to capture complex data + var person = new Person { Name = "Alice", Age = 25 }; + logger.LogInformation("Created a new person: {@Person}", person); + + // Use exception logging to capture the details of an exception + try + { + throw new Exception("Something went wrong."); + } + catch (Exception ex) + { + logger.LogError(ex, "An exception occurred."); + } + + Console.WriteLine("Press any key to exit."); + Console.ReadKey(); + } + } + + // A simple class to demonstrate structured logging + class Person + { + public string Name { get; set; } + public int Age { get; set; } + } +} + +``` + +## Main Types + + + +The main types provided by this library are: + +* `DebugLoggerProvider` +* `DebugLoggerFactoryExtensions` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/logging) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging) + +## Related Packages +[Microsoft.Extensions.Logging.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Abstractions) +[Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging) +[Microsoft.Extensions.Logging.Console](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Console) +[Microsoft.Extensions.Logging.EventSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventSource) +[Microsoft.Extensions.Logging.EventLog](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventLog) +[Microsoft.Extensions.Logging.TraceSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.TraceSource) + + + +## Feedback & Contributing + + + +Microsoft.Extensions.Logging.Debug is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Logging.TraceSource/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Logging.TraceSource/src/PACKAGE.md new file mode 100644 index 00000000000000..a58e190ec552b9 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Logging.TraceSource/src/PACKAGE.md @@ -0,0 +1,75 @@ +## About + + + +Implements a trace logger provider for the .NET logging infrastructre facilitating enhanced logging capabilities and trace-level diagnostics in application by writing messages to a trace listener using System.Diagnostic.TraceSource. + +## Key Features + + + +* Seamless integration with .NET logging infrastructure. +* Fine-grained control over trace messages using SourceSwitch. +* A set of builder methods to configure logging infrastructure. + +## How to Use + + + +The Microsoft.Extensions.Logging.TraceSource library provides extension methods to the logger factory and the logger builder to add a trace source with trace listeners. + +```csharp +using System.Diagnostics; +using Microsoft.Extensions.Logging; + +using var consoleTraceListener = new ConsoleTraceListener(); +using var textWriterTraceListener = new TextWriterTraceListener("/traces.txt"); +using var loggerFactory = LoggerFactory.Create(builder => +{ + builder + .AddTraceSource(new SourceSwitch("Something") { Level = SourceLevels.All }, consoleTraceListener) + .AddTraceSource(new SourceSwitch("HouseKeeping") { Level = SourceLevels.All }, textWriterTraceListener); +}); + +var logger = loggerFactory.CreateLogger(); + +logger.LogInformation("Information message."); +// Program Information: 0 : Information message. +logger.LogWarning("Warning message."); +// Program Warning: 0 : Warning message. + +var traceSource = new TraceSource("HouseKeeping", SourceLevels.All); +traceSource.Listeners.Add(consoleTraceListener); +traceSource.Listeners.Add(textWriterTraceListener); + +traceSource.TraceEvent(TraceEventType.Error, 0, "Error message."); +//HouseKeeping Error: 0 : Error message. +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Logging.TraceSource.TraceSourceLoggerProvider` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.tracesource) + +## Related Packages + + + +* Abstractions for dependency injection: [Microsoft.Extensions.DependencyInjection.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.DependencyInjection.Abstractions/) +* Default implementation of logging infrastructure: [Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging/) +* Abstractions for logging: [Microsoft.Extensions.Logging.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Abstractions/) + +## Feedback & Contributing + + + +Microsoft.Extensions.Logging.TraceSource is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Logging/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Logging/src/PACKAGE.md new file mode 100644 index 00000000000000..87c2d2b79ee29d --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Logging/src/PACKAGE.md @@ -0,0 +1,146 @@ +## About + + + +`Microsoft.Extensions.Logging` is combined with a core logging abstraction under `Microsoft.Extensions.Logging.Abstractions`. This abstraction is available in our basic built-in implementations like console, event log, and debug (Debug.WriteLine) logging. + +## Key Features + + + +* Provide concrete implementations of ILoggerFactory +* Provide extension methods for service collections, logger builder, and activity tracking +* Provide logging filtering extension methods for logger builder + +## How to Use + + +Prior to .NET 6, we only had two forms possible for doing logging, using `Microsoft.Extensions.Logging`: + +```cs +public class LoggingSample1 +{ + private ILogger _logger; + + public LoggingSample1(ILogger logger) + { + _logger = logger; + } + + public void LogMethod(string name) + { + _logger.LogInformation("Hello {name}", name); + } +} +``` + +Here are some problems with the LoggingSample1 sample using `LogInformation`, `LogWarning`, etc.: + +1. We can provide event ID through these APIs, but they are not required today. Which leads to bad usages in real systems that want to react or detect specific event issues being logged. +2. Parameters passed are processed before LogLevel checks; this leads to unnecessary code paths getting triggered even when logging is disabled for a log level. +3. It requires parsing of message string on every use to find templates to substitute. + +Because of these problems, the more efficient runtime approach recommended as best practices is to use LoggerMessage.Define APIs instead, illustrated below with LoggingSample2: + +```cs +public class LoggingSample2 +{ + private ILogger _logger; + + public LoggingSample2(ILogger logger) + { + _logger = logger; + } + + public void LogMethod(string name) + { + Log.LogName(_logger, name); + } + + private static class Log + { + private static readonly Action _logName = LoggerMessage.Define(LogLevel.Information, 0, @"Hello {name}"); + + public static void LogName(ILogger logger, string name) + { + _logName(logger, name, null!); + } + } +} +``` + +To reach a balance between performance and usability we added the compile-time logging source generator feature in .NET 6, to learn more about it and learn how to use a source generator to create log messages check out [this documentation](https://learn.microsoft.com/dotnet/core/extensions/logger-message-generator). + +```csharp + +public partial class InstanceLoggingExample +{ + private readonly ILogger _logger; + + public InstanceLoggingExample(ILogger logger) + { + _logger = logger; + } + + [LoggerMessage( + EventId = 0, + Level = LogLevel.Critical, + Message = "Could not open socket to `{hostName}`")] + public partial void CouldNotOpenSocket(string hostName); +} +``` + +#### Baggage and Tags for `ActivityTrackingOptions` + +.NET 5.0 exposed a new feature that allows configuring the logger builder with the `ActivityTrackingOption` to add the tracing context Span Id, Trace Id, Parent Id, Trace state, and Trace flags to the logging scope. The tracing context usually carried in `Activity.Current`. + +.NET 6.0 Preview 1 extended this feature to include more tracing context properties which are the Baggage and the Tags: + +```cs + var loggerFactory = LoggerFactory.Create(logging => + { + logging.Configure(options => + { + options.ActivityTrackingOptions = ActivityTrackingOptions.Tags | ActivityTrackingOptions.Baggage; + }).AddSimpleConsole(options => + { + options.IncludeScopes = true; + }); + }); +``` + +## Main Types + + + +The main types provided by this library are: + +* LoggingServiceCollectionExtensions +* LoggerFactory +* LoggerFactoryOptions +* LoggingBuilderExtensions +* ActivityTrackingOptions +* FilterLoggingBuilderExtensions + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/logging) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging) + +## Related Packages + + +[Microsoft.Extensions.Logging.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Abstractions) +[Microsoft.Extensions.Logging.Console](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Console) +[Microsoft.Extensions.Logging.Debug](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Debug) +[Microsoft.Extensions.Logging.EventSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventSource) +[Microsoft.Extensions.Logging.EventLog](https://www.nuget.org/packages/Microsoft.Extensions.Logging.EventLog) +[Microsoft.Extensions.Logging.TraceSource](https://www.nuget.org/packages/Microsoft.Extensions.Logging.TraceSource) + +## Feedback & Contributing + + + +Microsoft.Extensions.Logging is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/src/PACKAGE.md new file mode 100644 index 00000000000000..c893fd4f0e3204 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/src/PACKAGE.md @@ -0,0 +1,145 @@ +## About + +`Microsoft.Extensions.Options.ConfigurationExtensions` provides additional configuration-specific functionality related to Options. + +## Key Features + +* Extension methods for OptionsBuilder for configuration binding +* Extension methods for IServiceCollection for Options configuration +* ConfigurationChangeTokenSource for monitoring configuration changes + +## How to Use + +#### Options Configuration binding + +```csharp +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +class Program +{ + // appsettings.json contents: + // { + // "MyOptions": { + // "Setting1": "Value1", + // "Setting2": "Value2" + // } + // } + + static void Main(string[] args) + { + IConfiguration configuration = new ConfigurationBuilder() + .SetBasePath(Environment.CurrentDirectory) + .AddJsonFile("appsettings.json") + .Build(); + + IServiceCollection services = new ServiceCollection(); + + // Bind the configuration to MyOptions + services.Configure(configuration.GetSection("MyOptions")); + + IServiceProvider serviceProvider = services.BuildServiceProvider(); + + // Retrieve MyOptions using dependency injection + var myOptions = serviceProvider.GetRequiredService>().Value; + + // Access the bound configuration values + Console.WriteLine($"Setting1: {myOptions.Setting1}"); + Console.WriteLine($"Setting2: {myOptions.Setting2}"); + } +} + +public class MyOptions +{ + public string Setting1 { get; set; } + public string Setting2 { get; set; } +} + +``` + +#### Monitoring options configuration changes + +```csharp +// Assume we have a class that represents some options +public class MyOptions +{ + public string Name { get; set; } + public int Age { get; set; } +} + +// appsettings.json contents: +// { +// "MyOptions": { +// "Name": "Alice", +// "Age": 25 +// } +// } + +// Assume we have a configuration object that contains some settings +var config = new ConfigurationBuilder() + .AddJsonFile("appsettings.json") + .Build(); + +// We can use the ConfigurationChangeTokenSource to create a change token source for the options +var changeTokenSource = new ConfigurationChangeTokenSource(config.GetSection("MyOptions")); + +// We can register the change token source with the options monitor +services.AddOptions() + .Configure(options => + { + // Configure the options with the configuration values + config.GetSection("MyOptions").Bind(options); + }) + .AddChangeTokenSource(changeTokenSource); + +// Now we can inject the options monitor into any class that needs them +public class MyClass +{ + private readonly IOptionsMonitor _optionsMonitor; + + public MyClass(IOptionsMonitor optionsMonitor) + { + _optionsMonitor = optionsMonitor; + } + + public void DoSomething() + { + // Can access the current options value like this + var options = _optionsMonitor.CurrentValue; + var name = options.Name; + var age = options.Age; + // Do something with name and age + + // Can also register a callback to be notified when the options change + _optionsMonitor.OnChange(newOptions => + { + // Do something when the options change + }); + } +} + +``` + +## Main Types + +The main types provided by this library are: + +* `ConfigurationChangeTokenSource` +* `OptionsBuilderConfigurationExtensions` +* `OptionsConfigurationServiceCollectionExtensions` + +## Additional Documentation + +* [Conceptual documentation](https://learn.microsoft.com/aspnet/core/fundamentals/configuration/options) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.options) + +## Related Packages + +* [Microsoft.Extensions.Options](https://www.nuget.org/packages/Microsoft.Extensions.Options) +* [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration) +* [Microsoft.Extensions.DependencyInjection](https://www.nuget.org/packages/Microsoft.Extensions.DependencyInjection) + +## Feedback & Contributing + +Microsoft.Extensions.Options.ConfigurationExtensions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/Common/ConfigurationExtensionsTests.cs b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/Common/ConfigurationExtensionsTests.cs new file mode 100644 index 00000000000000..8c76472a2c339b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/Common/ConfigurationExtensionsTests.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.Options.ConfigurationExtensions.Tests +{ + public partial class ConfigurationExtensionsTests + { + private static IConfiguration s_emptyConfig { get; } = new ConfigurationBuilder().Build(); + + [Fact] + public void TestNullHandling_OptionsBuilderExt_Bind() + { + // Null options builder. + OptionsBuilder? optionsBuilder = null; + Assert.Throws(() => optionsBuilder!.Bind(s_emptyConfig)); + Assert.Throws(() => optionsBuilder!.Bind(s_emptyConfig, _ => { })); + + // Null configuration. + optionsBuilder = CreateOptionsBuilder(); + Assert.Throws(() => optionsBuilder.Bind(config: null!)); + Assert.Throws(() => optionsBuilder.Bind(config: null!, _ => { })); + + // Null configureBinder. + optionsBuilder.Bind(s_emptyConfig, configureBinder: null); + } + + [Fact] + public void TestNullHandling_OptionsBuilderExt_BindConfiguration() + { + // Null options builder. + string configSectionPath = "FakeSectionPath"; + OptionsBuilder? optionsBuilder = null; + Assert.Throws(() => optionsBuilder!.BindConfiguration(configSectionPath)); + + // Null config section path. + optionsBuilder = CreateOptionsBuilder(); + Assert.Throws(() => optionsBuilder.BindConfiguration(configSectionPath: null!)); + + // Null configureBinder. + optionsBuilder.BindConfiguration(configSectionPath, configureBinder: null); + } + + [Fact] + public void TestNullHandling_IServiceCollectionExt_Configure() + { + // Null services + IServiceCollection? services = null; + string name = "Name"; + Assert.Throws(() => services!.Configure(s_emptyConfig)); + Assert.Throws(() => services!.Configure(name, s_emptyConfig)); + + // Null config. + services = new ServiceCollection(); + Assert.Throws(() => services.Configure(config: null!)); + Assert.Throws(() => services.Configure(name, config: null!)); + + // Null name. + services.Configure(name: null!, s_emptyConfig); + + // Null configureBinder. + services.Configure(s_emptyConfig, configureBinder: null); + services.Configure(name, s_emptyConfig, configureBinder: null); + } + + private static OptionsBuilder CreateOptionsBuilder() + { + var services = new ServiceCollection(); + return new OptionsBuilder(services, Options.DefaultName); + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/Common/OptionsBuilderConfigurationExtensionsTests.cs b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/Common/OptionsBuilderConfigurationExtensionsTests.cs index 6d1d4a018f097e..d90122ef74d8bc 100644 --- a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/Common/OptionsBuilderConfigurationExtensionsTests.cs +++ b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/Common/OptionsBuilderConfigurationExtensionsTests.cs @@ -31,7 +31,8 @@ public static void BindConfiguration_ThrowsForNullConfigurationSectionPath() Assert.Throws("configSectionPath", () => { - optionsBuilder.BindConfiguration(configSectionPath); + optionsBuilder + .BindConfiguration(configSectionPath); }); } @@ -170,8 +171,8 @@ public static void BindConfiguration_UpdatesOptionOnConfigurationUpdate() services.AddSingleton(new ConfigurationBuilder() .Add(configSource) .Build()); - OptionsBuilder optionsBuilder = services.AddOptions(); - _ = optionsBuilder.BindConfiguration(configSectionName); + _ = services.AddOptions() + .BindConfiguration(configSectionName); using ServiceProvider serviceProvider = services.BuildServiceProvider(); var optionsMonitor = serviceProvider.GetRequiredService>(); bool updateHasRun = false; diff --git a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/SourceGenerationTests/ConfigurationExtensionsTest.Generator.cs b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/SourceGenerationTests/ConfigurationExtensionsTest.Generator.cs new file mode 100644 index 00000000000000..9b3aff8c495a77 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/SourceGenerationTests/ConfigurationExtensionsTest.Generator.cs @@ -0,0 +1,177 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.Options.ConfigurationExtensions.Tests +{ + public partial class ConfigurationExtensionsTests + { + // These are regression tests for https://github.com/dotnet/runtime/issues/90851 + // Source Generator Interceptors rely on identifying an accurate invocation + // source location (line and character positions). These tests cover newline + // and whitespace scenarios to ensure the interceptors get wired up correctly. + + [Fact] + public void TestBindingInvocationsWithNewlines_BindExtension() + { + OptionsBuilder? optionsBuilder = CreateOptionsBuilder(); + + // Newline between instance and invocation using configureBinder argument (with the dot on the first line) + optionsBuilder. + Bind(s_emptyConfig, configureBinder: null); + + // Newline between instance and invocation using configureBinder argument (with the dot on the second line) + optionsBuilder + .Bind(s_emptyConfig, configureBinder: null); + + // Newline between instance and invocation (with the dot on the first line) + optionsBuilder. + Bind(s_emptyConfig); + + // Newline between instance and invocation (with the dot on the second line) + optionsBuilder + .Bind(s_emptyConfig); + + // Newlines in every place possible + optionsBuilder + . + Bind + ( + s_emptyConfig + , + configureBinder + : + null + ) + ; + } + + [Fact] + public void TestBindingInvocationsWithNewlines_BindConfigurationExtension() + { + OptionsBuilder? optionsBuilder = CreateOptionsBuilder(); + + // Newline between instance and invocation using configureBinder argument (with the dot on the first line) + optionsBuilder. + BindConfiguration(configSectionPath: "path", + _ => { }); + + // Newline between instance and invocation using configureBinder argument (with the dot on the second line) + optionsBuilder + .BindConfiguration(configSectionPath: "path", + _ => { }); + + // Newlines between the instance and invocation and within the arguments. No indentation before invocation. + optionsBuilder. + BindConfiguration( + configSectionPath: "path", + _ => { }); + + // Newlines in every place possible + optionsBuilder + . + BindConfiguration + ( + configSectionPath + : + "path" + , + _ + => + { + } + ) + ; + } + + [Fact] + public void TestBindingInvocationsWithNewlines_ConfigureExtension() + { + OptionsBuilder? optionsBuilder = CreateOptionsBuilder(); + IServiceCollection services = new ServiceCollection(); + + // Newlines between each method call + services + .Configure(s_emptyConfig) + .AddOptions(); + + // Newlines in every place possible + services + . + Configure + < + FakeOptions + > + ( + name + : + null! + , + s_emptyConfig + ) + ; + } + + [Fact] + public void TestBindingInvocationsWithNewlines_StaticCalls() + { + OptionsBuilder? optionsBuilder = CreateOptionsBuilder(); + IServiceCollection services = new ServiceCollection(); + + // Bind: Newlines in every place possible + OptionsBuilderConfigurationExtensions + . + Bind + ( + optionsBuilder + , + s_emptyConfig + ) + ; + + // // BindConfiguration: Newlines in every place possible + OptionsBuilderConfigurationExtensions + . + BindConfiguration + ( + optionsBuilder + , + "path" + ); + + // Configure: Newlines in every place possible + OptionsConfigurationServiceCollectionExtensions + . + Configure + < + FakeOptions + > + ( + services + , + s_emptyConfig + ) + ; + } + + [Fact] + public void TestBindAndConfigureWithNamedParameters() + { + OptionsBuilder? optionsBuilder = CreateOptionsBuilder(); + IServiceCollection services = new ServiceCollection(); + + OptionsBuilderConfigurationExtensions.Bind(config: s_emptyConfig, optionsBuilder: optionsBuilder); + OptionsBuilderConfigurationExtensions.Bind(configureBinder: _ => { }, config: s_emptyConfig, optionsBuilder: optionsBuilder); + + OptionsBuilderConfigurationExtensions.BindConfiguration(configureBinder: _ => { }, configSectionPath: "path", optionsBuilder: optionsBuilder); + + OptionsConfigurationServiceCollectionExtensions.Configure(config: s_emptyConfig, services: services); + OptionsConfigurationServiceCollectionExtensions.Configure(name: "", config: s_emptyConfig, services: services); + OptionsConfigurationServiceCollectionExtensions.Configure(configureBinder: _ => { }, config: s_emptyConfig, services: services); + OptionsConfigurationServiceCollectionExtensions.Configure(name: "", configureBinder: _ => { }, config: s_emptyConfig, services: services); + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/SourceGenerationTests/Microsoft.Extensions.Options.ConfigurationExtensions.SourceGeneration.Tests.csproj b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/SourceGenerationTests/Microsoft.Extensions.Options.ConfigurationExtensions.SourceGeneration.Tests.csproj index f1843ebff94a2a..2bdacc95ff39e6 100644 --- a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/SourceGenerationTests/Microsoft.Extensions.Options.ConfigurationExtensions.SourceGeneration.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/SourceGenerationTests/Microsoft.Extensions.Options.ConfigurationExtensions.SourceGeneration.Tests.csproj @@ -2,9 +2,11 @@ enable $(NetCoreAppCurrent);$(NetFrameworkMinimum) - true $(DefineConstants);BUILDING_SOURCE_GENERATOR_TESTS;ROSLYN4_0_OR_GREATER;ROSLYN4_4_OR_GREATER - + $(Features);InterceptorsPreview + + $(InterceptorsPreviewNamespaces);Microsoft.Extensions.Configuration.Binder.SourceGeneration + true true @@ -20,10 +22,12 @@ + - + + diff --git a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/UnitTests/Microsoft.Extensions.Options.ConfigurationExtensions.UnitTests.csproj b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/UnitTests/Microsoft.Extensions.Options.ConfigurationExtensions.UnitTests.csproj index 3a5db72bf30f7d..512e0018a400c0 100644 --- a/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/UnitTests/Microsoft.Extensions.Options.ConfigurationExtensions.UnitTests.csproj +++ b/src/libraries/Microsoft.Extensions.Options.ConfigurationExtensions/tests/UnitTests/Microsoft.Extensions.Options.ConfigurationExtensions.UnitTests.csproj @@ -16,10 +16,11 @@ + - + \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs index e8f4b606882cf9..5b734edf32738c 100644 --- a/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs +++ b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs @@ -95,7 +95,8 @@ private static bool TryValidateOptions(object options, string qualifiedName, Lis foreach (PropertyInfo propertyInfo in options.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public)) { - if (propertyInfo.GetMethod is null) + // Indexers are properties which take parameters. Ignore them. + if (propertyInfo.GetMethod is null || propertyInfo.GetMethod.GetParameters().Length > 0) { continue; } diff --git a/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/PACKAGE.md new file mode 100644 index 00000000000000..368518bbedb313 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/PACKAGE.md @@ -0,0 +1,75 @@ +## About + + + +Microsoft.Extensions.Options.DataAnnotations is a library that adds extra validation functionality to configuration options using data annotations. + +It allows to apply validation rules to configuration classes to ensure they are correctly configured before the application starts running. + +This way, misconfiguration issues are catched early during the application startup rather than facing them later in production. + +## Key Features + + + +* Enables validation of configuration options using data annotations. +* Early detection of misconfiguration issues during application startup. + +## How to Use + + + +While configuring services, chain the `ValidateDataAnnotations()` and `ValidateOnStart()` methods to the `AddOptions` method for your configuration class. + +Here is a simple example demonstrating how to validate options on application startup: + +```csharp +services + .AddOptions() + .ValidateDataAnnotations() + .ValidateOnStart(); +``` + +In the configuration class, use data annotations to specify the validation rules. + +For instance, in the following `MyOptions` class, the `Name` property is marked as required: + +```csharp +using System.ComponentModel.DataAnnotations; + +public class MyOptions +{ + [Required(AllowEmptyStrings = false)] + public string Name { get; set; } +} +``` + +With this setup, an error indicating that the `Name` field is required will be thrown upon startup if it hasn't been configured. + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Options.DataAnnotationsValidateOptions` +* `Microsoft.Extensions.DependencyInjection.OptionsBuilderDataAnnotationsExtensions` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/options) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.options.dataannotationvalidateoptions-1) + +## Related Packages + + + +Core options: [Microsoft.Extensions.Options](https://www.nuget.org/packages/Microsoft.Extensions.Options/) + +## Feedback & Contributing + + + +Microsoft.Extensions.Options.DataAnnotations is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs b/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs index 474686189fb126..49562a0a128c2b 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs @@ -105,5 +105,19 @@ internal sealed class DiagDescriptors : DiagDescriptorsBase messageFormat: SR.InaccessibleValidationAttributeMessage, category: Category, defaultSeverity: DiagnosticSeverity.Info); + + public static DiagnosticDescriptor OptionsUnsupportedLanguageVersion { get; } = Make( + id: "SYSLIB1216", + title: SR.OptionsUnsupportedLanguageVersionTitle, + messageFormat: SR.OptionsUnsupportedLanguageVersionMessage, + category: Category, + defaultSeverity: DiagnosticSeverity.Error); + + public static DiagnosticDescriptor IncompatibleWithTypeForValidationAttribute { get; } = Make( + id: "SYSLIB1217", + title: SR.TypeCannotBeUsedWithTheValidationAttributeTitle, + messageFormat: SR.TypeCannotBeUsedWithTheValidationAttributeMessage, + category: Category, + defaultSeverity: DiagnosticSeverity.Warning); } } diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs b/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs index adbaa874ad7f37..41609ad4b2010a 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs @@ -3,10 +3,13 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; +using System.Text; using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; namespace Microsoft.Extensions.Options.Generators { @@ -16,34 +19,39 @@ namespace Microsoft.Extensions.Options.Generators internal sealed class Emitter : EmitterBase { private const string StaticFieldHolderClassesNamespace = "__OptionValidationStaticInstances"; + internal const string StaticGeneratedValidationAttributesClassesNamespace = "__OptionValidationGeneratedAttributes"; + internal const string StaticAttributeClassNamePrefix = "__SourceGen_"; + internal const string StaticGeneratedMaxLengthAttributeClassesName = "__SourceGen_MaxLengthAttribute"; private const string StaticListType = "global::System.Collections.Generic.List"; private const string StaticValidationResultType = "global::System.ComponentModel.DataAnnotations.ValidationResult"; private const string StaticValidationAttributeType = "global::System.ComponentModel.DataAnnotations.ValidationAttribute"; - + private const string StaticValidationContextType = "global::System.ComponentModel.DataAnnotations.ValidationContext"; private string _staticValidationAttributeHolderClassName = "__Attributes"; private string _staticValidatorHolderClassName = "__Validators"; private string _staticValidationAttributeHolderClassFQN; private string _staticValidatorHolderClassFQN; - private string _modifier; + private string _TryGetValueNullableAnnotation; + private readonly SymbolHolder _symbolHolder; + private readonly OptionsSourceGenContext _optionsSourceGenContext; + private sealed record StaticFieldInfo(string FieldTypeFQN, int FieldOrder, string FieldName, IList InstantiationLines); - public Emitter(Compilation compilation, bool emitPreamble = true) : base(emitPreamble) + public Emitter(Compilation compilation, SymbolHolder symbolHolder, OptionsSourceGenContext optionsSourceGenContext, bool emitPreamble = true) : base(emitPreamble) { - if (((CSharpCompilation)compilation).LanguageVersion >= Microsoft.CodeAnalysis.CSharp.LanguageVersion.CSharp11) - { - _modifier = "file"; - } - else + _optionsSourceGenContext = optionsSourceGenContext; + + if (!_optionsSourceGenContext.IsLangVersion11AndAbove) { - _modifier = "internal"; - string suffix = $"_{new Random().Next():X8}"; - _staticValidationAttributeHolderClassName += suffix; - _staticValidatorHolderClassName += suffix; + _staticValidationAttributeHolderClassName += _optionsSourceGenContext.Suffix; + _staticValidatorHolderClassName += _optionsSourceGenContext.Suffix; } _staticValidationAttributeHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidationAttributeHolderClassName}"; _staticValidatorHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidatorHolderClassName}"; + _TryGetValueNullableAnnotation = GetNullableAnnotationStringForTryValidateValueToUseInGeneratedCode(compilation); + + _symbolHolder = symbolHolder; } public string Emit( @@ -61,10 +69,36 @@ public string Emit( GenStaticClassWithStaticReadonlyFields(staticValidationAttributesDict.Values, StaticFieldHolderClassesNamespace, _staticValidationAttributeHolderClassName); GenStaticClassWithStaticReadonlyFields(staticValidatorsDict.Values, StaticFieldHolderClassesNamespace, _staticValidatorHolderClassName); + GenValidationAttributesClasses(); return Capture(); } + /// + /// Returns the nullable annotation string to use in the code generation according to the first parameter of + /// is nullable annotated. + /// + /// The to consider for analysis. + /// "!" if the first parameter is not nullable annotated, otherwise an empty string. + /// + /// In .NET 8.0 we have changed the nullable annotation on first parameter of the method cref="System.ComponentModel.DataAnnotations.Validator.TryValidateValue(object, ValidationContext, ICollection{ValidationResult}, IEnumerable{ValidationAttribute})"/> + /// The source generator need to detect if we need to append "!" to the first parameter of the method call when running on down-level versions. + /// + private static string GetNullableAnnotationStringForTryValidateValueToUseInGeneratedCode(Compilation compilation) + { + INamedTypeSymbol? validatorTypeSymbol = compilation.GetBestTypeByMetadataName("System.ComponentModel.DataAnnotations.Validator"); + if (validatorTypeSymbol is not null) + { + ImmutableArray members = validatorTypeSymbol.GetMembers("TryValidateValue"); + if (members.Length == 1 && members[0] is IMethodSymbol tryValidateValueMethod) + { + return tryValidateValueMethod.Parameters[0].NullableAnnotation == NullableAnnotation.NotAnnotated ? "!" : string.Empty; + } + } + + return "!"; + } + private void GenValidatorType(ValidatorType vt, ref Dictionary staticValidationAttributesDict, ref Dictionary staticValidatorsDict) { if (vt.Namespace.Length > 0) @@ -117,7 +151,7 @@ private void GenStaticClassWithStaticReadonlyFields(IEnumerable OutOpenBrace(); OutGeneratedCodeAttribute(); - OutLn($"{_modifier} static class {className}"); + OutLn($"{_optionsSourceGenContext.ClassModifier} static class {className}"); OutOpenBrace(); var staticValidationAttributes = staticFields @@ -157,11 +191,401 @@ private void GenStaticClassWithStaticReadonlyFields(IEnumerable OutCloseBrace(); } + public void EmitMaxLengthAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public {{qualifiedClassName}}(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public {{qualifiedClassName}}(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + {{linesToInsert}}else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } +"""); + } + + public void EmitMinLengthAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public {{qualifiedClassName}}(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + {{linesToInsert}}else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } +"""); + } + + public void EmitLengthAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public {{qualifiedClassName}}(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + {{linesToInsert}}else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } +"""); + } + + public void EmitCompareAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public {{qualifiedClassName}}(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override {{StaticValidationResultType}}? IsValid(object? value, {{StaticValidationContextType}} validationContext) + { + bool result = true; + + {{linesToInsert}} + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new {{StaticValidationResultType}}(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } +"""); + } + + public void EmitRangeAttribute(string modifier, string prefix, string className, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + public {{qualifiedClassName}}(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public {{qualifiedClassName}}(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public {{qualifiedClassName}}(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +"""); + } + + private string GenerateStronglyTypedCodeForLengthAttributes(HashSet data) + { + if (data.Count == 0) + { + return string.Empty; + } + + StringBuilder sb = new(); + string padding = GetPaddingString(3); + + foreach (var type in data) + { + string typeName = (string)type; + sb.AppendLine($"else if (value is {typeName})"); + sb.AppendLine($"{padding}{{"); + sb.AppendLine($"{padding} length = (({typeName})value).Count;"); + sb.AppendLine($"{padding}}}"); + sb.Append($"{padding}"); + } + + return sb.ToString(); + } + + private string GenerateStronglyTypedCodeForCompareAttribute(HashSet? data) + { + if (data is null || data.Count == 0) + { + return string.Empty; + } + + StringBuilder sb = new(); + string padding = GetPaddingString(3); + bool first = true; + + foreach (var obj in data) + { + (string type, string property) = ((string, string))obj; + sb.Append(first ? $"if " : $"{padding}else if "); + sb.AppendLine($"(validationContext.ObjectInstance is {type} && OtherProperty == \"{property}\")"); + sb.AppendLine($"{padding}{{"); + sb.AppendLine($"{padding} result = Equals(value, (({type})validationContext.ObjectInstance).{property});"); + sb.AppendLine($"{padding}}}"); + first = false; + } + + return sb.ToString(); + } + + private void GenValidationAttributesClasses() + { + if (_optionsSourceGenContext.AttributesToGenerate.Count == 0) + { + return; + } + + var attributesData = _optionsSourceGenContext.AttributesToGenerate.OrderBy(static kvp => kvp.Key, StringComparer.Ordinal).ToArray(); + + OutLn($"namespace {StaticGeneratedValidationAttributesClassesNamespace}"); + OutOpenBrace(); + + foreach (var attributeData in attributesData) + { + if (attributeData.Key == _symbolHolder.MaxLengthAttributeSymbol.Name) + { + string linesToInsert = attributeData.Value is not null ? GenerateStronglyTypedCodeForLengthAttributes((HashSet)attributeData.Value) : string.Empty; + EmitMaxLengthAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (attributeData.Key == _symbolHolder.MinLengthAttributeSymbol.Name) + { + string linesToInsert = attributeData.Value is not null ? GenerateStronglyTypedCodeForLengthAttributes((HashSet)attributeData.Value) : string.Empty; + EmitMinLengthAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (_symbolHolder.LengthAttributeSymbol is not null && attributeData.Key == _symbolHolder.LengthAttributeSymbol.Name) + { + string linesToInsert = attributeData.Value is not null ? GenerateStronglyTypedCodeForLengthAttributes((HashSet)attributeData.Value) : string.Empty; + EmitLengthAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (attributeData.Key == _symbolHolder.CompareAttributeSymbol.Name && attributeData.Value is not null) + { + string linesToInsert = GenerateStronglyTypedCodeForCompareAttribute((HashSet)attributeData.Value); + EmitCompareAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert: linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (attributeData.Key == _symbolHolder.RangeAttributeSymbol.Name) + { + EmitRangeAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, _optionsSourceGenContext.Suffix); + } + } + + OutCloseBrace(); + } + private void GenModelSelfValidationIfNecessary(ValidatedModel modelToValidate) { if (modelToValidate.SelfValidates) { - OutLn($"builder.AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context));"); + OutLn($"(builder ??= new()).AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context));"); OutLn(); } } @@ -180,11 +604,18 @@ private void GenModelValidationMethod( OutLn($"/// Validation result."); OutGeneratedCodeAttribute(); + if (_symbolHolder.UnconditionalSuppressMessageAttributeSymbol is not null) + { + // We disable the warning on `new ValidationContext(object)` usage as we use it in a safe way that not require executing the reflection code. + // This is done by initializing the DisplayName in the context which is the part trigger reflection if it is not initialized. + OutLn($"[System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage(\"Trimming\", \"IL2026:RequiresUnreferencedCode\","); + OutLn($" Justification = \"The created ValidationContext object is used in a way that never call reflection\")]"); + } + OutLn($"public {(makeStatic ? "static " : string.Empty)}global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, {modelToValidate.Name} options)"); OutOpenBrace(); - OutLn($"var baseName = (string.IsNullOrEmpty(name) ? \"{modelToValidate.SimpleName}\" : name) + \".\";"); - OutLn($"var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder();"); - OutLn($"var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options);"); + OutLn($"global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null;"); + OutLn($"var context = new {StaticValidationContextType}(options);"); int capacity = modelToValidate.MembersToValidate.Max(static vm => vm.ValidationAttributes.Count); if (capacity > 0) @@ -199,33 +630,33 @@ private void GenModelValidationMethod( { if (vm.ValidationAttributes.Count > 0) { - GenMemberValidation(vm, ref staticValidationAttributesDict, cleanListsBeforeUse); + GenMemberValidation(vm, modelToValidate.SimpleName, ref staticValidationAttributesDict, cleanListsBeforeUse); cleanListsBeforeUse = true; OutLn(); } if (vm.TransValidatorType is not null) { - GenTransitiveValidation(vm, ref staticValidatorsDict); + GenTransitiveValidation(vm, modelToValidate.SimpleName, ref staticValidatorsDict); OutLn(); } if (vm.EnumerationValidatorType is not null) { - GenEnumerationValidation(vm, ref staticValidatorsDict); + GenEnumerationValidation(vm, modelToValidate.SimpleName, ref staticValidatorsDict); OutLn(); } } GenModelSelfValidationIfNecessary(modelToValidate); - OutLn($"return builder.Build();"); + OutLn($"return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build();"); OutCloseBrace(); } - private void GenMemberValidation(ValidatedMember vm, ref Dictionary staticValidationAttributesDict, bool cleanListsBeforeUse) + private void GenMemberValidation(ValidatedMember vm, string modelName, ref Dictionary staticValidationAttributesDict, bool cleanListsBeforeUse) { OutLn($"context.MemberName = \"{vm.Name}\";"); - OutLn($"context.DisplayName = baseName + \"{vm.Name}\";"); + OutLn($"context.DisplayName = string.IsNullOrEmpty(name) ? \"{modelName}.{vm.Name}\" : $\"{{name}}.{vm.Name}\";"); if (cleanListsBeforeUse) { @@ -239,9 +670,9 @@ private void GenMemberValidation(ValidatedMember vm, ref Dictionary staticValidatorsDict) + private void GenTransitiveValidation(ValidatedMember vm, string modelName, ref Dictionary staticValidatorsDict) { string callSequence; if (vm.TransValidateTypeIsSynthetic) @@ -321,20 +752,22 @@ private void GenTransitiveValidation(ValidatedMember vm, ref Dictionary staticValidatorsDict) + private void GenEnumerationValidation(ValidatedMember vm, string modelName, ref Dictionary staticValidatorsDict) { var valueAccess = (vm.IsValueType && vm.IsNullable) ? ".Value" : string.Empty; var enumeratedValueAccess = (vm.EnumeratedIsNullable && vm.EnumeratedIsValueType) ? ".Value" : string.Empty; @@ -365,14 +798,16 @@ private void GenEnumerationValidation(ValidatedMember vm, ref Dictionary types, SourceProductionContext context) { + if (types.Length == 0) + { + return; + } + if (!SymbolLoader.TryLoad(compilation, out var symbolHolder)) { // Not eligible compilation return; } - var parser = new Parser(compilation, context.ReportDiagnostic, symbolHolder!, context.CancellationToken); + OptionsSourceGenContext optionsSourceGenContext = new(compilation); + + var parser = new Parser(compilation, context.ReportDiagnostic, symbolHolder!, optionsSourceGenContext, context.CancellationToken); var validatorTypes = parser.GetValidatorTypes(types); if (validatorTypes.Count > 0) { - var emitter = new Emitter(compilation); + var emitter = new Emitter(compilation, symbolHolder!, optionsSourceGenContext); var result = emitter.Emit(validatorTypes, context.CancellationToken); context.AddSource("Validators.g.cs", SourceText.From(result, Encoding.UTF8)); diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj b/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj index 41ceaf6739c333..f5bad279371755 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj +++ b/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj @@ -20,6 +20,7 @@ + @@ -29,6 +30,7 @@ + diff --git a/src/libraries/Microsoft.Extensions.Options/gen/OptionsSourceGenContext.cs b/src/libraries/Microsoft.Extensions.Options/gen/OptionsSourceGenContext.cs new file mode 100644 index 00000000000000..8da3e317769627 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/gen/OptionsSourceGenContext.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.Versioning; + +namespace Microsoft.Extensions.Options.Generators +{ + internal sealed class OptionsSourceGenContext + { + public OptionsSourceGenContext(Compilation compilation) + { + IsLangVersion11AndAbove = ((CSharpCompilation)compilation).LanguageVersion >= Microsoft.CodeAnalysis.CSharp.LanguageVersion.CSharp11; + ClassModifier = IsLangVersion11AndAbove ? "file" : "internal"; + Suffix = IsLangVersion11AndAbove ? "" : $"_{GetNonRandomizedHashCode(compilation.SourceModule.Name):X8}"; + } + + internal string Suffix { get; } + internal string ClassModifier { get; } + internal bool IsLangVersion11AndAbove { get; } + internal Dictionary?> AttributesToGenerate { get; set; } = new Dictionary?>(); + + internal void EnsureTrackingAttribute(string attributeName, bool createValue, out HashSet? value) + { + bool exist = AttributesToGenerate.TryGetValue(attributeName, out value); + if (value is null) + { + if (createValue) + { + value = new HashSet(); + } + + if (!exist || createValue) + { + AttributesToGenerate[attributeName] = value; + } + } + } + + internal static bool IsConvertibleBasicType(ITypeSymbol typeSymbol) + { + return typeSymbol.SpecialType switch + { + SpecialType.System_Boolean => true, + SpecialType.System_Byte => true, + SpecialType.System_Char => true, + SpecialType.System_DateTime => true, + SpecialType.System_Decimal => true, + SpecialType.System_Double => true, + SpecialType.System_Int16 => true, + SpecialType.System_Int32 => true, + SpecialType.System_Int64 => true, + SpecialType.System_SByte => true, + SpecialType.System_Single => true, + SpecialType.System_UInt16 => true, + SpecialType.System_UInt32 => true, + SpecialType.System_UInt64 => true, + SpecialType.System_String => true, + _ => false, + }; + } + + /// + /// Returns a non-randomized hash code for the given string. + /// We always return a positive value. + /// + internal static int GetNonRandomizedHashCode(string s) + { + uint result = 2166136261u; + foreach (char c in s) + { + result = (c ^ result) * 16777619; + } + + return Math.Abs((int)result); + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs b/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs index 07ea7de51bf9a4..47cb71c3411cde 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Globalization; using System.Linq; using System.Text; @@ -24,6 +25,7 @@ internal sealed class Parser private readonly Compilation _compilation; private readonly Action _reportDiagnostic; private readonly SymbolHolder _symbolHolder; + private readonly OptionsSourceGenContext _optionsSourceGenContext; private readonly Dictionary _synthesizedValidators = new(SymbolEqualityComparer.Default); private readonly HashSet _visitedModelTypes = new(SymbolEqualityComparer.Default); @@ -31,12 +33,14 @@ public Parser( Compilation compilation, Action reportDiagnostic, SymbolHolder symbolHolder, + OptionsSourceGenContext optionsSourceGenContext, CancellationToken cancellationToken) { _compilation = compilation; _cancellationToken = cancellationToken; _reportDiagnostic = reportDiagnostic; _symbolHolder = symbolHolder; + _optionsSourceGenContext = optionsSourceGenContext; } public IReadOnlyList GetValidatorTypes(IEnumerable<(TypeDeclarationSyntax TypeSyntax, SemanticModel SemanticModel)> classes) @@ -143,6 +147,13 @@ public IReadOnlyList GetValidatorTypes(IEnumerable<(TypeDeclarati results.AddRange(_synthesizedValidators.Values); _synthesizedValidators.Clear(); + if (results.Count > 0 && _compilation is CSharpCompilation { LanguageVersion : LanguageVersion version and < LanguageVersion.CSharp8 }) + { + // we only support C# 8.0 and above + Diag(DiagDescriptors.OptionsUnsupportedLanguageVersion, null, version.ToDisplayString(), LanguageVersion.CSharp8.ToDisplayString()); + return new List(); + } + return results; } @@ -233,6 +244,13 @@ private static bool HasOpenGenerics(ITypeSymbol type, out string genericType) type = ((INamedTypeSymbol)type).TypeArguments[0]; } + // Check first if the type is IEnumerable interface + if (SymbolEqualityComparer.Default.Equals(type.OriginalDefinition, _symbolHolder.GenericIEnumerableSymbol)) + { + return ((INamedTypeSymbol)type).TypeArguments[0]; + } + + // Check first if the type implement IEnumerable interface foreach (var implementingInterface in type.AllInterfaces) { if (SymbolEqualityComparer.Default.Equals(implementingInterface.OriginalDefinition, _compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T))) @@ -273,7 +291,7 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s ? memberLocation : lowerLocationInCompilation; - var memberInfo = GetMemberInfo(member, speculate, location, validatorType); + var memberInfo = GetMemberInfo(member, speculate, location, modelType, validatorType); if (memberInfo is not null) { if (member.DeclaredAccessibility != Accessibility.Public) @@ -289,7 +307,7 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s return membersToValidate; } - private ValidatedMember? GetMemberInfo(ISymbol member, bool speculate, Location location, ITypeSymbol validatorType) + private ValidatedMember? GetMemberInfo(ISymbol member, bool speculate, Location location, ITypeSymbol modelType, ITypeSymbol validatorType) { ITypeSymbol memberType; switch (member) @@ -310,7 +328,7 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s break; */ default: - // we only care about properties and fields + // we only care about properties return null; } @@ -452,17 +470,58 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s continue; } - var validationAttr = new ValidationAttributeInfo(attributeType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + string attributeFullQualifiedName = attributeType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + if (SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.MaxLengthAttributeSymbol) || + SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.MinLengthAttributeSymbol) || + (_symbolHolder.LengthAttributeSymbol is not null && SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.LengthAttributeSymbol))) + { + if (!LengthBasedAttributeIsTrackedForSubstitution(memberType, location, attributeType, ref attributeFullQualifiedName)) + { + continue; + } + } + else if (SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.CompareAttributeSymbol)) + { + TrackCompareAttributeForSubstitution(attribute, modelType, ref attributeFullQualifiedName); + } + else if (SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.RangeAttributeSymbol)) + { + TrackRangeAttributeForSubstitution(attribute, memberType, ref attributeFullQualifiedName); + } + + var validationAttr = new ValidationAttributeInfo(attributeFullQualifiedName); validationAttrs.Add(validationAttr); - foreach (var constructorArgument in attribute.ConstructorArguments) + ImmutableArray parameters = attribute.AttributeConstructor?.Parameters ?? ImmutableArray.Empty; + bool lastParameterDeclaredWithParamsKeyword = parameters.Length > 0 && parameters[parameters.Length - 1].IsParams; + + ImmutableArray arguments = attribute.ConstructorArguments; + + for (int i = 0; i < arguments.Length; i++) { - validationAttr.ConstructorArguments.Add(GetArgumentExpression(constructorArgument.Type!, constructorArgument.Value)); + TypedConstant argument = arguments[i]; + if (argument.Kind == TypedConstantKind.Array) + { + bool isParams = lastParameterDeclaredWithParamsKeyword && i == arguments.Length - 1; + validationAttr.ConstructorArguments.Add(GetArrayArgumentExpression(argument.Values, isParams)); + } + else + { + validationAttr.ConstructorArguments.Add(GetArgumentExpression(argument.Type!, argument.Value)); + } } foreach (var namedArgument in attribute.NamedArguments) { - validationAttr.Properties.Add(namedArgument.Key, GetArgumentExpression(namedArgument.Value.Type!, namedArgument.Value.Value)); + if (namedArgument.Value.Kind == TypedConstantKind.Array) + { + bool isParams = lastParameterDeclaredWithParamsKeyword && namedArgument.Key == parameters[parameters.Length - 1].Name; + validationAttr.Properties.Add(namedArgument.Key, GetArrayArgumentExpression(namedArgument.Value.Values, isParams)); + } + else + { + validationAttr.Properties.Add(namedArgument.Key, GetArgumentExpression(namedArgument.Value.Type!, namedArgument.Value.Value)); + } } } } @@ -530,6 +589,79 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s return null; } + private bool LengthBasedAttributeIsTrackedForSubstitution(ITypeSymbol memberType, Location location, ITypeSymbol attributeType, ref string attributeFullQualifiedName) + { + if (memberType.SpecialType == SpecialType.System_String || ConvertTo(memberType, _symbolHolder.ICollectionSymbol)) + { + _optionsSourceGenContext.EnsureTrackingAttribute(attributeType.Name, createValue: false, out _); + } + else if (ParserUtilities.TypeHasProperty(memberType, "Count", SpecialType.System_Int32)) + { + _optionsSourceGenContext.EnsureTrackingAttribute(attributeType.Name, createValue: true, out HashSet? trackedTypeList); + trackedTypeList!.Add(memberType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + else + { + Diag(DiagDescriptors.IncompatibleWithTypeForValidationAttribute, location, attributeType.Name, memberType.Name); + return false; + } + + attributeFullQualifiedName = $"{Emitter.StaticGeneratedValidationAttributesClassesNamespace}.{Emitter.StaticAttributeClassNamePrefix}{_optionsSourceGenContext.Suffix}_{attributeType.Name}"; + return true; + } + + private void TrackCompareAttributeForSubstitution(AttributeData attribute, ITypeSymbol modelType, ref string attributeFullQualifiedName) + { + ImmutableArray constructorParameters = attribute.AttributeConstructor?.Parameters ?? ImmutableArray.Empty; + if (constructorParameters.Length == 1 && constructorParameters[0].Name == "otherProperty" && constructorParameters[0].Type.SpecialType == SpecialType.System_String) + { + _optionsSourceGenContext.EnsureTrackingAttribute(attribute.AttributeClass!.Name, createValue: true, out HashSet? trackedTypeList); + trackedTypeList!.Add((modelType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), (string)attribute.ConstructorArguments[0].Value!)); + attributeFullQualifiedName = $"{Emitter.StaticGeneratedValidationAttributesClassesNamespace}.{Emitter.StaticAttributeClassNamePrefix}{_optionsSourceGenContext.Suffix}_{attribute.AttributeClass!.Name}"; + } + } + + private void TrackRangeAttributeForSubstitution(AttributeData attribute, ITypeSymbol memberType, ref string attributeFullQualifiedName) + { + ImmutableArray constructorParameters = attribute.AttributeConstructor?.Parameters ?? ImmutableArray.Empty; + SpecialType argumentSpecialType = SpecialType.None; + if (constructorParameters.Length == 2) + { + argumentSpecialType = constructorParameters[0].Type.SpecialType; + } + else if (constructorParameters.Length == 3) + { + object? argumentValue = null; + for (int i = 0; i < constructorParameters.Length; i++) + { + if (constructorParameters[i].Name == "type") + { + argumentValue = attribute.ConstructorArguments[i].Value; + break; + } + } + + if (argumentValue is INamedTypeSymbol namedTypeSymbol && OptionsSourceGenContext.IsConvertibleBasicType(namedTypeSymbol)) + { + argumentSpecialType = namedTypeSymbol.SpecialType; + } + } + + ITypeSymbol typeSymbol = memberType; + if (typeSymbol.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + typeSymbol = ((INamedTypeSymbol)typeSymbol).TypeArguments[0]; + } + + if (argumentSpecialType != SpecialType.None && + OptionsSourceGenContext.IsConvertibleBasicType(typeSymbol) && + (constructorParameters.Length != 3 || typeSymbol.SpecialType == argumentSpecialType)) // When type is provided as a parameter, it has to match the property type. + { + _optionsSourceGenContext.EnsureTrackingAttribute(attribute.AttributeClass!.Name, createValue: false, out _); + attributeFullQualifiedName = $"{Emitter.StaticGeneratedValidationAttributesClassesNamespace}.{Emitter.StaticAttributeClassNamePrefix}{_optionsSourceGenContext.Suffix}_{attribute.AttributeClass!.Name}"; + } + } + private string? AddSynthesizedValidator(ITypeSymbol modelType, ISymbol member, Location location, ITypeSymbol validatorType) { var mt = modelType.WithNullableAnnotation(NullableAnnotation.None); @@ -623,6 +755,32 @@ private bool CanValidate(ITypeSymbol validatorType, ISymbol modelType) return false; } + private string GetArrayArgumentExpression(ImmutableArray value, bool isParams) + { + var sb = new StringBuilder(); + if (!isParams) + { + sb.Append("new[] { "); + } + + for (int i = 0; i < value.Length; i++) + { + sb.Append(GetArgumentExpression(value[i].Type!, value[i].Value)); + + if (i < value.Length - 1) + { + sb.Append(", "); + } + } + + if (!isParams) + { + sb.Append(" }"); + } + + return sb.ToString(); + } + private string GetArgumentExpression(ITypeSymbol type, object? value) { if (value == null) diff --git a/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs b/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs index d79ad4cccb653d..0b63cc90c800ed 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs @@ -68,6 +68,41 @@ internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol inte return false; } + internal static bool TypeHasProperty(ITypeSymbol typeSymbol, string propertyName, SpecialType returnType) + { + ITypeSymbol? type = typeSymbol; + do + { + if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + type = ((INamedTypeSymbol)type).TypeArguments[0]; // extract the T from a Nullable + } + + if (type.GetMembers(propertyName).OfType().Any(property => + property.Type.SpecialType == returnType && property.DeclaredAccessibility == Accessibility.Public && + property.Kind == SymbolKind.Property && !property.IsStatic && property.GetMethod != null && property.Parameters.IsEmpty)) + { + return true; + } + + type = type.BaseType; + } while (type is not null && type.SpecialType != SpecialType.System_Object); + + // When we have an interface type, we need to check all the interfaces that it extends. + // Like IList extends ICollection where the property we're looking for is defined. + foreach (var interfaceType in typeSymbol.AllInterfaces) + { + if (interfaceType.GetMembers(propertyName).OfType().Any(property => + property.Type.SpecialType == returnType && property.Kind == SymbolKind.Property && + !property.IsStatic && property.GetMethod != null && property.Parameters.IsEmpty)) + { + return true; + } + } + + return false; + } + // Check if parameter has either simplified (i.e. "int?") or explicit (Nullable) nullable type declaration: internal static bool IsNullableOfT(this ITypeSymbol type) => type.SpecialType == SpecialType.System_Nullable_T || type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T; diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx index ef074f147f915e..7100030eecf132 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx @@ -207,4 +207,16 @@ Validation attribute on the member is inaccessible from the validator type.. + + C# language version not supported by the source generator. + + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf index fadfdec9e3b7ad..c8490d951761cb 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf @@ -122,6 +122,16 @@ Pro atributy ValidateObjectMembersAttribute nebo ValidateEnumeratedItemsAttribute byl specifikovaný typ validátoru s hodnotou null. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + Zdrojový generátor ověřování možností není k dispozici v jazyce C#{0}. Použijte prosím jazykovou verzi {1} nebo vyšší. + + + + C# language version not supported by the source generator. + Zdrojový generátor nepodporuje jazykovou verzi jazyka C#. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. Typ {0} obsahuje validační anotace, ale člen {1} neurčuje [ValidateEnumeratedItems], což může být přehlédnutí. @@ -142,6 +152,16 @@ U člena potenciálně chybí přenositelné ověření. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Ověřovací atribut {0} by měl být použit pouze na vlastnosti typu string, array nebo ICollection. Použití s typem {1} může vést k selháním modulu runtime. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Ověřovací atribut se vztahuje pouze na vlastnosti typu string, array nebo ICollection; nelze použít s jinými typy. + + Validator type {0} doesn't have a parameterless constructor. Typ validátoru {0} nemá konstruktor bez parametrů. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf index badb9618ac043b..eb7a7f423a3e70 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf @@ -122,6 +122,16 @@ Für die Attribute "ValidateObjectMembersAttribute" oder "ValidateEnumeratedItemsAttribute" wurde ein NULL-Validierungssteuerelementtyp angegeben. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + Der Optionsvalidierungsquellgenerator ist in C# {0} nicht verfügbar. Verwenden Sie die Sprachversion {1} oder höher. + + + + C# language version not supported by the source generator. + Die C#-Sprachversion wird vom Quellgenerator nicht unterstützt. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. Der Typ "{0}" weist Validierungsanmerkungen auf, der Member "{1}" gibt jedoch keine [ValidateEnumeratedItems] an, die eine Vorhersage darstellen könnten. @@ -142,6 +152,16 @@ Dem Member fehlt möglicherweise die transitive Validierung. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Das Validierungsattribut {0} sollte nur auf Eigenschaften vom Typ "string", "array" oder "ICollection" angewendet werden. Die Verwendung mit dem Typ {1} kann zu Laufzeitfehlern führen. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Das Validierungsattribut gilt nur für Eigenschaften vom Typ "string", "array" oder "ICollection"; es kann nicht mit anderen Typen verwendet werden. + + Validator type {0} doesn't have a parameterless constructor. Der Validierungssteuerelementtyp "{0}" hat keinen parameterlosen Konstruktor. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf index 9637c745a081e0..19c69bb8c864f9 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf @@ -122,6 +122,16 @@ Se especificó un tipo de validador nulo para los atributos “ValidateObjectMembersAttribute” o “ValidateEnumeratedItemsAttribute”. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + El generador de origen de validación de opciones no está disponible en C# {0}. Use la versión de idioma {1} o superior. + + + + C# language version not supported by the source generator. + La versión del idioma C# no es compatible con el generador de origen. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. El tipo {0} tiene anotaciones de validación, pero el miembro {1} no especifica [ValidateEnumeratedItems], lo que podría ser un error. @@ -142,6 +152,16 @@ Posiblemente falta la validación transitiva en el miembro. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + El atributo de validación {0} solo se debe aplicar a propiedades de tipo cadena, matriz o ICollection. Si la usa con el tipo {1}, podrían producirse errores en tiempo de ejecución. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + El atributo de validación solo es aplicable a propiedades de tipo cadena, matriz o ICollection; no se puede usar con otros tipos. + + Validator type {0} doesn't have a parameterless constructor. El tipo de validador {0} no tiene un constructor sin parámetros. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf index 93817f52c7ddf7..651ab8b303719a 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf @@ -122,6 +122,16 @@ Type de validateur Null spécifié pour les attributs 'ValidateObjectMembersAttribute' ou 'ValidateEnumeratedItemsAttribute'. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + Le générateur de source de validation d’options n'est pas disponible en C# « {0} ». Veuillez utiliser la version linguistique {1} ou supérieure. + + + + C# language version not supported by the source generator. + Version du langage C# non prise en charge par le générateur de source. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. Le type {0} a des annotations de validation, mais le membre {1} ne spécifie pas [ValidateEnumeratedItems] qui peut être une supervision. @@ -142,6 +152,16 @@ Le membre n’a peut-être pas de validation transitive. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + L’attribut de validation {0} doit uniquement être appliqué aux propriétés de type chaîne, tableau ou ICollection. Son utilisation avec le type {1} peut entraîner des échecs d’exécution. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + L’attribut de validation s’applique uniquement aux propriétés de type chaîne, tableau ou ICollection ; il ne peut pas être utilisé avec d’autres types. + + Validator type {0} doesn't have a parameterless constructor. Le type de validateur {0} n’a pas de constructeur sans paramètre. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf index b80d8e27bcec22..575f60469e48dc 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf @@ -122,6 +122,16 @@ Tipo di validator Null specificato per gli attributi 'ValidateObjectMembersAttribute' o 'ValidateEnumeratedItemsAttribute'. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + Il generatore di origine di convalida delle opzioni non è disponibile in C# {0}. Usare la versione del linguaggio {1} o successiva. + + + + C# language version not supported by the source generator. + Versione del linguaggio C# non supportata dal generatore di origine. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. Il tipo {0} include annotazioni di convalida, ma il membro {1} non specifica [ValidateEnumeratedItems] che potrebbe essere una supervisione. @@ -142,6 +152,16 @@ Il membro potrebbe non avere una convalida transitiva. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + L'attributo {0} di convalida deve essere applicato solo alle proprietà di tipo stringa, matrice o ICollection. L'uso con il tipo {1} potrebbe causare errori di runtime. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + L'attributo di convalida è applicabile solo alle proprietà di tipo stringa, matrice o ICollection; non può essere usato con altri tipi. + + Validator type {0} doesn't have a parameterless constructor. Il tipo di convalida {0} non dispone di un costruttore senza parametri. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf index b9b44dd4afd5e5..8251ac27c6da15 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf @@ -122,6 +122,16 @@ 'ValidateObjectMembersAttribute' 属性または 'ValidateEnumeratedItemsAttribute' 属性に対して NULL のバリデーター型が指定されています。 + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + オプション検証ソース ジェネレーターは、C# {0} では使用できません。言語バージョン {1} 以降を使用してください。 + + + + C# language version not supported by the source generator. + ソース ジェネレーターでサポートされていない C# 言語バージョン。 + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. 型 {0} には検証の注釈がありますが、メンバー {1} では [ValidateEnumeratedItem] が指定されていません。これは誤りである可能性があります。 @@ -142,6 +152,16 @@ メンバーに推移性の検証がない可能性があります。 + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 検証属性 {0} は、型文字列、配列、または ICollection のプロパティにしか適用できません。型 {1} と共に使用すると、ランタイム エラーが発生する可能性があります。 + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 検証属性は、型文字列、配列、または ICollection のプロパティにのみ適用でき、他の型では使用できません。 + + Validator type {0} doesn't have a parameterless constructor. バリデーター型 {0} にパラメーターなしのコンストラクターがありません。 diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf index 45819c9f750a5a..4c8ce2865d1070 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf @@ -122,6 +122,16 @@ 'ValidateObjectMembersAttribute' 또는 'ValidateEnumeratedItemsAttribute' 특성에 Null 유효성 검사기 형식이 지정되었습니다. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + 옵션 유효성 검사 원본 생성기는 C# '{0}'에서 사용할 수 없습니다. {1} 이상의 언어 버전을 사용하세요. + + + + C# language version not supported by the source generator. + 원본 생성기에서 지원되지 않는 C# 언어 버전입니다. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. 형식 {0}은(는) 유효성 검사 주석이 있지만 멤버 {1}은(는) 참조할 수 있는 [ValidateEnumeratedItems]을(를) 지정하지 않습니다. @@ -142,6 +152,16 @@ 멤버에 전이적 유효성 검사가 누락되었을 수 있습니다. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 유효성 검사 특성 {0}은(는) 문자열, 배열 또는 ICollection 형식의 속성에만 적용해야 합니다. {1} 형식과 함께 사용하면 런타임 오류가 발생할 수 있습니다. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 유효성 검사 특성은 문자열, 배열 또는 ICollection 형식의 속성에만 적용할 수 있습니다. 다른 형식과 함께 사용할 수 없습니다. + + Validator type {0} doesn't have a parameterless constructor. 유효성 검사기 형식 {0}은(는) 매개 변수가 없는 생성자가 없습니다. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf index 0b57a0ac378924..e07e568d131874 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf @@ -122,6 +122,16 @@ Określono typ modułu sprawdzania poprawności o wartości null dla atrybutu „ValidateObjectMembersAttribute” lub „ValidateEnumeratedItemsAttribute”. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + Generator źródła sprawdzania poprawności opcji nie jest dostępny w języku C# {0}. Użyj wersji językowej {1} lub nowszej. + + + + C# language version not supported by the source generator. + Wersja języka C# nie jest obsługiwana przez generator źródła. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. Typ {0} ma adnotacje walidacji, ale element członkowski {1} nie określa atrybutu [ValidateEnumeratedItems], co może być przeoczeniem. @@ -142,6 +152,16 @@ W przypadku elementu członkowskiego może potencjalnie brakować weryfikacji przechodniej. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Atrybut sprawdzania poprawności {0} powinien być stosowany tylko do właściwości typu ciąg, tablica lub ICollection. Użycie go z typem {1} może prowadzić do błędów środowiska uruchomieniowego. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Atrybut sprawdzania poprawności ma zastosowanie tylko do właściwości typu ciąg, tablica lub ICollection; nie można go używać z innymi typami. + + Validator type {0} doesn't have a parameterless constructor. Typ modułu sprawdzania poprawności {0} nie ma konstruktora bez parametrów. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf index 4afa1d6763db8b..d422309a40995b 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf @@ -122,6 +122,16 @@ Tipo de validador nulo especificado para os atributos "ValidateObjectMembersAttribute" ou "ValidateEnumeratedItemsAttribute". + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + O gerador de fonte de validação de opções não está disponível em C# {0}. Use a versão do idioma {1} ou superior. + + + + C# language version not supported by the source generator. + Versão da linguagem C# não suportada pelo gerador de origem. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. O tipo {0} tem anotações de validação, mas o membro {1} não especifica [ValidateEnumeratedItems], o que pode ser uma desatenção. @@ -142,6 +152,16 @@ Membro potencialmente ausente na validação transitiva. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + O atributo de validação {0} deve ser aplicado somente às propriedades do tipo cadeia de caracteres, matriz ou ICollection. Usá-lo com o tipo {1} pode levar a falhas de runtime. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + O atributo de validação só é aplicável às propriedades do tipo cadeia de caracteres, matriz ou ICollection; não pode ser usado com outros tipos. + + Validator type {0} doesn't have a parameterless constructor. O tipo de validador {0} não tem um construtor sem parâmetros. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf index 02fb57d5497df9..bd3a0be55eb270 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf @@ -122,6 +122,16 @@ Для атрибутов ValidateObjectMembersAttribute или ValidateEnumeratedItemsAttribute указан тип проверяющего элемента управления NULL. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + Генератор исходного кода проверки параметров недоступен в C# {0}. Используйте языковую версию {1} или более позднюю. + + + + C# language version not supported by the source generator. + Версия языка C# не поддерживается генератором исходного кода. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. Тип {0} содержит примечания проверки, но элемент {1} не указывает [ValidateEnumeratedItems], который может быть задан. @@ -142,6 +152,16 @@ Возможно, в элементе отсутствует транзитивная проверка. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Атрибут проверки {0} следует применять только к свойствам строки типа, массива или ICollection. Использование его с типом {1} может привести к сбоям во время выполнения. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Атрибут проверки применим только к свойствам строки типа, массива или ICollection; его нельзя использовать с другими типами. + + Validator type {0} doesn't have a parameterless constructor. Тип проверяющего элемента управления {0} не имеет конструктора без параметров. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf index 6743baba13c804..e478a77cd3c238 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf @@ -122,6 +122,16 @@ 'ValidateObjectMembersAttribute' veya 'ValidateEnumeratedItemsAttribute' öznitelikleri için null doğrulayıcı türü belirtildi. + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + Seçenek doğrulaması kaynak oluşturucusu C# {0} sürümünde kullanılamıyor. Lütfen dil sürümü {1} veya üstü bir sürümü kullanın. + + + + C# language version not supported by the source generator. + C# dil sürümü kaynak oluşturucu tarafından desteklenmiyor. + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. {0} türünde doğrulama ek açıklamaları var, ancak {1} üyesi gözden kaçmış olabilecek [ValidateEnumeratedItems] belirtmiyor. @@ -142,6 +152,16 @@ Üyede geçişli doğrulama eksik olabilir. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Doğrulama özniteliği {0} yalnızca string, dizi veya ICollection türündeki özelliklere uygulanmalıdır. {1} türüyle kullanılması çalışma zamanı hatalarına neden olabilir. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Doğrulama özniteliği yalnızca string, dizi veya ICollection türündeki özelliklere uygulanabilir; diğer türlerle kullanılamaz. + + Validator type {0} doesn't have a parameterless constructor. {0} doğrulayıcı türü parametresiz bir oluşturucuya sahip değil. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf index cd68a3a773332b..76a26db32a673c 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf @@ -122,6 +122,16 @@ 为“ValidateObjectMembersAttribute”或“ValidateEnumeratedItemsAttribute”属性指定的验证程序类型为 Null。 + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + 选项验证源生成器在 C#“{0}”中不可用。请使用{1}或更高版本的语言版本。 + + + + C# language version not supported by the source generator. + 源生成器不支持 C# 语言版本。 + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. 类型 {0} 具有验证注释,但成员 {1} 未指定 [ValidateEnumeratedItems],这可能是一种监督。 @@ -142,6 +152,16 @@ 成员可能缺少可传递验证。 + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 验证特性 {0} 只能应用于字符串、数组或 ICollection 类型的属性。将它与 {1} 类型一起使用可能会导致运行时故障。 + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 验证特性仅适用于字符串、数组或 ICollection 类型的属性;它不能与其他类型一起使用。 + + Validator type {0} doesn't have a parameterless constructor. 验证程序类型 {0} 没有无参数构造函数。 diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf index f21a748ed078e7..9997092e3be6f3 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf @@ -122,6 +122,16 @@ 為 'ValidateObjectMembersAttribute' 或 'ValidateEnumeratedItemsAttribute' 屬性指定的 Null 驗證程式類型。 + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + 選項驗證來源產生器在 C# {0} 中不可用。請使用語言版本 {1} 或更新版本。 + + + + C# language version not supported by the source generator. + 來源產生器不支援 C# 語言版本。 + + Type {0} has validation annotations, but member {1} doesn't specify [ValidateEnumeratedItems] which could be an oversight. 成員 {0} 具備驗證註釋,但成員 {1} 未指定 [ValidateObjectMembers] (可能為監督)。 @@ -142,6 +152,16 @@ 成員可能遺漏轉移的驗證。 + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 驗證屬性 {0} 只能套用至類型字串、陣列或 ICollection 的屬性。搭配 {1} 類型使用可能會導致執行階段失敗。 + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 驗證屬性只適用於類型字串、陣列或 ICollection 的屬性;無法與其他類型搭配使用。 + + Validator type {0} doesn't have a parameterless constructor. 驗證程式類型 {0} 沒有無參數建構函式。 diff --git a/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs b/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs index c78106e1bc4a27..3447a07d398305 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs @@ -11,10 +11,18 @@ namespace Microsoft.Extensions.Options.Generators internal sealed record class SymbolHolder( INamedTypeSymbol OptionsValidatorSymbol, INamedTypeSymbol ValidationAttributeSymbol, + INamedTypeSymbol MaxLengthAttributeSymbol, + INamedTypeSymbol MinLengthAttributeSymbol, + INamedTypeSymbol CompareAttributeSymbol, + INamedTypeSymbol? LengthAttributeSymbol, + INamedTypeSymbol? UnconditionalSuppressMessageAttributeSymbol, + INamedTypeSymbol RangeAttributeSymbol, + INamedTypeSymbol ICollectionSymbol, INamedTypeSymbol DataTypeAttributeSymbol, INamedTypeSymbol ValidateOptionsSymbol, INamedTypeSymbol IValidatableObjectSymbol, + INamedTypeSymbol GenericIEnumerableSymbol, INamedTypeSymbol TypeSymbol, - INamedTypeSymbol? ValidateObjectMembersAttributeSymbol, - INamedTypeSymbol? ValidateEnumeratedItemsAttributeSymbol); + INamedTypeSymbol ValidateObjectMembersAttributeSymbol, + INamedTypeSymbol ValidateEnumeratedItemsAttributeSymbol); } diff --git a/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs b/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs index 6f805e91a05858..ea556228929756 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs @@ -9,41 +9,69 @@ internal static class SymbolLoader { public const string OptionsValidatorAttribute = "Microsoft.Extensions.Options.OptionsValidatorAttribute"; internal const string ValidationAttribute = "System.ComponentModel.DataAnnotations.ValidationAttribute"; + internal const string MaxLengthAttribute = "System.ComponentModel.DataAnnotations.MaxLengthAttribute"; + internal const string MinLengthAttribute = "System.ComponentModel.DataAnnotations.MinLengthAttribute"; + internal const string CompareAttribute = "System.ComponentModel.DataAnnotations.CompareAttribute"; + internal const string LengthAttribute = "System.ComponentModel.DataAnnotations.LengthAttribute"; + internal const string RangeAttribute = "System.ComponentModel.DataAnnotations.RangeAttribute"; + internal const string ICollectionType = "System.Collections.ICollection"; internal const string DataTypeAttribute = "System.ComponentModel.DataAnnotations.DataTypeAttribute"; internal const string IValidatableObjectType = "System.ComponentModel.DataAnnotations.IValidatableObject"; internal const string IValidateOptionsType = "Microsoft.Extensions.Options.IValidateOptions`1"; internal const string TypeOfType = "System.Type"; internal const string ValidateObjectMembersAttribute = "Microsoft.Extensions.Options.ValidateObjectMembersAttribute"; internal const string ValidateEnumeratedItemsAttribute = "Microsoft.Extensions.Options.ValidateEnumeratedItemsAttribute"; + internal const string GenericIEnumerableType = "System.Collections.Generic.IEnumerable`1"; + internal const string UnconditionalSuppressMessageAttributeType = "System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute"; public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHolder) { - INamedTypeSymbol? GetSymbol(string metadataName, bool optional = false) - { - var symbol = compilation.GetTypeByMetadataName(metadataName); - if (symbol == null && !optional) - { - return null; - } - - return symbol; - } + INamedTypeSymbol? GetSymbol(string metadataName) => compilation.GetTypeByMetadataName(metadataName); // required var optionsValidatorSymbol = GetSymbol(OptionsValidatorAttribute); var validationAttributeSymbol = GetSymbol(ValidationAttribute); + var maxLengthAttributeSymbol = GetSymbol(MaxLengthAttribute); + var minLengthAttributeSymbol = GetSymbol(MinLengthAttribute); + var compareAttributeSymbol = GetSymbol(CompareAttribute); + var lengthAttributeSymbol = GetSymbol(LengthAttribute); + var rangeAttributeSymbol = GetSymbol(RangeAttribute); + var iCollectionSymbol = GetSymbol(ICollectionType); var dataTypeAttributeSymbol = GetSymbol(DataTypeAttribute); var ivalidatableObjectSymbol = GetSymbol(IValidatableObjectType); var validateOptionsSymbol = GetSymbol(IValidateOptionsType); + var genericIEnumerableSymbol = GetSymbol(GenericIEnumerableType); var typeSymbol = GetSymbol(TypeOfType); + var validateObjectMembersAttribute = GetSymbol(ValidateObjectMembersAttribute); + var validateEnumeratedItemsAttribute = GetSymbol(ValidateEnumeratedItemsAttribute); + var unconditionalSuppressMessageAttributeSymbol = GetSymbol(UnconditionalSuppressMessageAttributeType); + if (unconditionalSuppressMessageAttributeSymbol is not null) + { + var containingAssemblyName = unconditionalSuppressMessageAttributeSymbol.ContainingAssembly.Identity.Name; + if (!containingAssemblyName.Equals("System.Private.CoreLib", System.StringComparison.OrdinalIgnoreCase) && + !containingAssemblyName.Equals("System.Runtime", System.StringComparison.OrdinalIgnoreCase)) + { + // The compilation returns UnconditionalSuppressMessageAttribute symbol even if the attribute is not available like the case when running on .NET Framework. + // We need to make sure that the attribute is really available by checking the containing assembly which in .NET Core will be either System.Private.CoreLib or System.Runtime. + unconditionalSuppressMessageAttributeSymbol = null; + } + } #pragma warning disable S1067 // Expressions should not be too complex if (optionsValidatorSymbol == null || validationAttributeSymbol == null || + maxLengthAttributeSymbol == null || + minLengthAttributeSymbol == null || + compareAttributeSymbol == null || + rangeAttributeSymbol == null || + iCollectionSymbol == null || dataTypeAttributeSymbol == null || ivalidatableObjectSymbol == null || validateOptionsSymbol == null || - typeSymbol == null) + genericIEnumerableSymbol == null || + typeSymbol == null || + validateObjectMembersAttribute == null || + validateEnumeratedItemsAttribute == null) { symbolHolder = default; return false; @@ -53,14 +81,20 @@ public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHold symbolHolder = new( optionsValidatorSymbol, validationAttributeSymbol, + maxLengthAttributeSymbol, + minLengthAttributeSymbol, + compareAttributeSymbol, + lengthAttributeSymbol, + unconditionalSuppressMessageAttributeSymbol, + rangeAttributeSymbol, + iCollectionSymbol, dataTypeAttributeSymbol, validateOptionsSymbol, ivalidatableObjectSymbol, + genericIEnumerableSymbol, typeSymbol, - - // optional - GetSymbol(ValidateObjectMembersAttribute, optional: true), - GetSymbol(ValidateEnumeratedItemsAttribute, optional: true)); + validateObjectMembersAttribute, + validateEnumeratedItemsAttribute); return true; } diff --git a/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj b/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj index 7225edf84ba537..c7ea3e00049e3e 100644 --- a/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj +++ b/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj @@ -28,12 +28,13 @@ - + diff --git a/src/libraries/Microsoft.Extensions.Options/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Options/src/PACKAGE.md new file mode 100644 index 00000000000000..ee0cc2ab6f99ee --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/src/PACKAGE.md @@ -0,0 +1,170 @@ +## About +`Microsoft.Extensions.Options` provides a strongly typed way of specifying and accessing settings using dependency injection and acts as a bridge between configuration, DI, and higher level libraries. This library is the glue for how an app developer uses DI to configure the behavior of a library like HttpClient Factory. This also enables user to get a strongly-typed view of their configuration. + +Within this package, you'll find an options validation source generator that generates exceptionally efficient and optimized code for validating options. + +## Key Features + +* Offer the IValidateOptions interface for the validation of options, along with several generic ValidateOptions classes that implement this interface. +* OptionsBuilder to configure options. +* Provide extension methods for service collections and options builder to register options and validate options. +* Supply a set of generic ConfigureNamedOptions classes that implement the IConfigureNamedOptions interface for configuring named options. +* Provide a source generator that generates validation code for options. +* Options caching, managing and monitoring. + +## How to Use + +#### Options validation example + +```C# +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddControllersWithViews(); + +// Load the configuration and validate it +builder.Services.AddOptions() + .Bind(builder.Configuration.GetSection(MyConfigOptions.MyConfig)) + .ValidateDataAnnotations(); +var app = builder.Build(); + + +// Declare the option class to validate +public class MyConfigOptions +{ + public const string MyConfig = "MyConfig"; + + [RegularExpression(@"^[a-zA-Z''-'\s]{1,40}$")] + public string Key1 { get; set; } + [Range(0, 1000, + ErrorMessage = "Value for {0} must be between {1} and {2}.")] + public int Key2 { get; set; } + public int Key3 { get; set; } +} +``` + +#### Using IValidateOptions to validate options + +```C# +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddControllersWithViews(); + +// Configuration to validate +builder.Services.Configure(builder.Configuration.GetSection( + MyConfigOptions.MyConfig)); + +// OPtions validation through the DI container +builder.Services.AddSingleton, MyConfigValidation>(); + +var app = builder.Build(); + +public class MyConfigValidation : IValidateOptions +{ + public MyConfigOptions _config { get; private set; } + + public MyConfigValidation(IConfiguration config) + { + _config = config.GetSection(MyConfigOptions.MyConfig) + .Get(); + } + + public ValidateOptionsResult Validate(string name, MyConfigOptions options) + { + string? vor = null; + var rx = new Regex(@"^[a-zA-Z''-'\s]{1,40}$"); + var match = rx.Match(options.Key1!); + + if (string.IsNullOrEmpty(match.Value)) + { + vor = $"{options.Key1} doesn't match RegEx \n"; + } + + if ( options.Key2 < 0 || options.Key2 > 1000) + { + vor = $"{options.Key2} doesn't match Range 0 - 1000 \n"; + } + + if (_config.Key2 != default) + { + if(_config.Key3 <= _config.Key2) + { + vor += "Key3 must be > than Key2."; + } + } + + if (vor != null) + { + return ValidateOptionsResult.Fail(vor); + } + + return ValidateOptionsResult.Success; + } +} + +``` + +#### Options Validation Source Generator Example + +```C# +using System; +using System.ComponentModel.DataAnnotations; +using Microsoft.Extensions.Options; + +public class MyConfigOptions +{ + [RegularExpression(@"^[a-zA-Z''-'\s]{1,40}$")] + public string Key1 { get; set; } + + [Range(0, 1000, + ErrorMessage = "Value for {0} must be between {1} and {2}.")] + public int Key2 { get; set; } + public int Key3 { get; set; } +} + +[OptionsValidator] +public partial class MyConfigValidation : IValidateOptions +{ + // Source generator will automatically provide the implementation of IValidateOptions + // Then you can add the validation to the DI Container using the following code: + // + // builder.Services.AddSingleton, MyConfigValidation>(); + // builder.Services.AddOptions() + // .Bind(builder.Configuration.GetSection(MyConfigOptions.MyConfig)) + // .ValidateDataAnnotations(); +} + +``` + +## Main Types + +The main types provided by this library are: + +* `IOptions`, `IOptionsFactory`, and `IOptionsMonitor` +* `IValidateOptions` and `ValidateOptions` +* `OptionsBuilder`, `OptionsFactory`, `OptionsMonitor`, and `OptionsManager` +* `OptionsServiceCollectionExtensions` +* `OptionsValidatorAttribute` + +## Additional Documentation + +* [Conceptual documentation](https://learn.microsoft.com/aspnet/core/fundamentals/configuration/options) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.options) + +## Related Packages + +[Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging) +[Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration) + +## Feedback & Contributing + +Microsoft.Extensions.Options is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/DataAnnotationAttributesWithParams.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/DataAnnotationAttributesWithParams.g.cs new file mode 100644 index 00000000000000..c51c551222f42f --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/DataAnnotationAttributesWithParams.g.cs @@ -0,0 +1,135 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace Test +{ + partial class MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Test.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P1" : $"{name}.P1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute( + (int)10, + (int)20); + + internal static readonly global::System.ComponentModel.DataAnnotations.AllowedValuesAttribute A3 = new global::System.ComponentModel.DataAnnotations.AllowedValuesAttribute( + (int)10, (int)20, (int)30); + + internal static readonly global::System.ComponentModel.DataAnnotations.DeniedValuesAttribute A4 = new global::System.ComponentModel.DataAnnotations.DeniedValuesAttribute( + "One", "Ten", "Hundred"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netcore.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netcore.g.cs new file mode 100644 index 00000000000000..2c5af12c5b5f24 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netcore.g.cs @@ -0,0 +1,175 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace HelloWorld +{ + partial struct MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val1" : $"{name}.Val1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val2" : $"{name}.Val2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netfx.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netfx.g.cs new file mode 100644 index 00000000000000..9dc3ded5bd4624 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netfx.g.cs @@ -0,0 +1,173 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace HelloWorld +{ + partial struct MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val1" : $"{name}.Val1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val2" : $"{name}.Val2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang10.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang10.g.cs new file mode 100644 index 00000000000000..cc9864a2619c45 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang10.g.cs @@ -0,0 +1,471 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P0"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P0" : $"{name}.P0"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P0, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P1" : $"{name}.P1"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A5); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Attributes_2C497155 + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_LengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_LengthAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Validators_2C497155 + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + internal class __SourceGen__2C497155_CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__2C497155_CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__2C497155_LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__2C497155_MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__2C497155_MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__2C497155_MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__2C497155_RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__2C497155_RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__2C497155_RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang11.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang11.g.cs new file mode 100644 index 00000000000000..2a33e51b0b6175 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang11.g.cs @@ -0,0 +1,471 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P0"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P0" : $"{name}.P0"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P0, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P1" : $"{name}.P1"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A5); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + file class __SourceGen__CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang10.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang10.g.cs new file mode 100644 index 00000000000000..7f5eb90a202815 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang10.g.cs @@ -0,0 +1,386 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Attributes_2C497155 + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Validators_2C497155 + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + internal class __SourceGen__2C497155_CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__2C497155_CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__2C497155_MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__2C497155_MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__2C497155_MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__2C497155_RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__2C497155_RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__2C497155_RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang11.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang11.g.cs new file mode 100644 index 00000000000000..3ab56e21320a07 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang11.g.cs @@ -0,0 +1,386 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + file class __SourceGen__CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs new file mode 100644 index 00000000000000..1cd942fab0f1bb --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs @@ -0,0 +1,252 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace Test +{ + partial class MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Test.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P1" : $"{name}.P1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute( + (int)10, + (int)20); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)4); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs new file mode 100644 index 00000000000000..603680a9ec732c --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs @@ -0,0 +1,177 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace Test +{ + partial class MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Test.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P2" : $"{name}.P2"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)4); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs index b243ee50f361d9..623251707f87ba 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs @@ -52,78 +52,15 @@ public partial struct MyOptionsValidator : IValidateOptions } """; - string generatedSource = """ - - // - #nullable enable - #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 - namespace HelloWorld -{ - partial struct MyOptionsValidator - { - /// - /// Validates a specific named options instance (or all when is ). - /// - /// The name of the options instance being validated. - /// The options instance. - /// Validation result. - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options) - { - var baseName = (string.IsNullOrEmpty(name) ? "MyOptions" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); - var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); - var validationResults = new global::System.Collections.Generic.List(); - var validationAttributes = new global::System.Collections.Generic.List(1); - - context.MemberName = "Val1"; - context.DisplayName = baseName + "Val1"; - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes)) - { - builder.AddResults(validationResults); - } - - context.MemberName = "Val2"; - context.DisplayName = baseName + "Val2"; - validationResults.Clear(); - validationAttributes.Clear(); - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes)) - { - builder.AddResults(validationResults); - } - - return builder.Build(); - } - } -} -namespace __OptionValidationStaticInstances -{ - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - file static class __Attributes - { - internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); - - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A2 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( - (int)1, - (int)3); - } -} -namespace __OptionValidationStaticInstances -{ - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - file static class __Validators - { - } -} - -"""; - var (diagnostics, generatedSources) = await RunGeneratorOnOptionsSource(source); Assert.Empty(diagnostics); _ = Assert.Single(generatedSources); +#if NETCOREAPP + string generatedSource = File.ReadAllText(@"Baselines/EmitterWithCustomValidator.netcore.g.cs"); +#else + string generatedSource = File.ReadAllText(@"Baselines/EmitterWithCustomValidator.netfx.g.cs"); +#endif // NETCOREAPP Assert.Equal(generatedSource.Replace("\r\n", "\n"), generatedSources[0].SourceText.ToString().Replace("\r\n", "\n")); } @@ -1444,7 +1381,7 @@ internal sealed partial class ExtOptionsValidator : IValidateOptions Assert.Single(diagnostics); Assert.Equal(DiagDescriptors.InaccessibleValidationAttribute.Id, diagnostics[0].Id); string generatedSource = generatedSources[0].SourceText.ToString(); - Assert.Contains("global::System.ComponentModel.DataAnnotations.RangeAttribute", generatedSource); + Assert.Contains("__OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute", generatedSource); Assert.Contains("global::System.ComponentModel.DataAnnotations.RequiredAttribute", generatedSource); Assert.DoesNotContain("Timeout", generatedSource); @@ -1582,18 +1519,107 @@ public partial class FirstValidator : IValidateOptions Assert.Equal(DiagDescriptors.NotEnumerableType.Id, diagnostics[0].Id); } - private static CSharpCompilation CreateCompilationForOptionsSource(string assemblyName, string source, string? refAssemblyPath = null) + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public async Task LanguageVersionTest() + { + string source = """ + using System; + using System.ComponentModel.DataAnnotations; + using Microsoft.Extensions.Options; + + public class FirstModel + { + [Required] + public string? P1 { get; set; } + } + + [OptionsValidator] + public partial class FirstModelValidator : IValidateOptions + { + } + """; + + Assembly [] refAssemblies = new [] + { + Assembly.GetAssembly(typeof(RequiredAttribute)), + Assembly.GetAssembly(typeof(OptionsValidatorAttribute)), + Assembly.GetAssembly(typeof(IValidateOptions)), + }; + + // Run the generator with C# 7.0 and verify that it fails. + var (diagnostics, generatedSources) = await RoslynTestUtils.RunGenerator( + new OptionsValidatorGenerator(), refAssemblies.ToArray(), new[] { source }, includeBaseReferences: true, LanguageVersion.CSharp7).ConfigureAwait(false); + + Assert.NotEmpty(diagnostics); + Assert.Equal("SYSLIB1216", diagnostics[0].Id); + Assert.Empty(generatedSources); + + // Run the generator with C# 8.0 and verify that it succeeds. + (diagnostics, generatedSources) = await RoslynTestUtils.RunGenerator( + new OptionsValidatorGenerator(), refAssemblies.ToArray(), new[] { source }, includeBaseReferences: true, LanguageVersion.CSharp8).ConfigureAwait(false); + + Assert.Empty(diagnostics); + Assert.Single(generatedSources); + + // Compile the generated code with C# 7.0 and verify that it fails. + CSharpParseOptions parseOptions = new CSharpParseOptions(LanguageVersion.CSharp7); + SyntaxTree syntaxTree = SyntaxFactory.ParseSyntaxTree(generatedSources[0].SourceText.ToString(), parseOptions); + var diags = syntaxTree.GetDiagnostics().ToArray(); + Assert.Equal(1, diags.Length); + // error CS8107: Feature 'nullable reference types' is not available in C# 7.0. Please use language version 8.0 or greater. + Assert.Equal("CS8107", diags[0].Id); + + // Compile the generated code with C# 8.0 and verify that it succeeds. + parseOptions = new CSharpParseOptions(LanguageVersion.CSharp8); + syntaxTree = SyntaxFactory.ParseSyntaxTree(generatedSources[0].SourceText.ToString(), parseOptions); + diags = syntaxTree.GetDiagnostics().ToArray(); + Assert.Equal(0, diags.Length); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser), nameof(PlatformDetection.IsNetCore))] + public async Task DataAnnotationAttributesWithParams() + { + var (diagnostics, generatedSources) = await RunGenerator(@""" + public class MyOptions + { + [Required] + public string P1 { get; set; } + + [Length(10, 20)] + public string P2 { get; set; } + + [AllowedValues(10, 20, 30)] + public int P3 { get; set; } + + [DeniedValues(""One"", ""Ten"", ""Hundred"")] + public string P4 { get; set; } + } + + [OptionsValidator] + public partial class MyOptionsValidator : IValidateOptions + { + } + """); + + Assert.Empty(diagnostics); + Assert.Single(generatedSources); + + string generatedSource = File.ReadAllText(@"Baselines/DataAnnotationAttributesWithParams.g.cs"); + Assert.Equal(generatedSource.Replace("\r\n", "\n"), generatedSources[0].SourceText.ToString().Replace("\r\n", "\n")); + } + + private static CSharpCompilation CreateCompilationForOptionsSource(string assemblyName, string source, string? refAssemblyPath = null, LanguageVersion languageVersion = LanguageVersion.Default) { // Ensure the generated source compiles var compilation = CSharpCompilation - .Create(Path.GetRandomFileName()+".dll", options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) + .Create($"{assemblyName}.dll", options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) .AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(a => a.GetName().Name == "System.Runtime").Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(string).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(RequiredAttribute).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(OptionsValidatorAttribute).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(IValidateOptions).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(System.CodeDom.Compiler.GeneratedCodeAttribute).Assembly.Location)) - .AddSyntaxTrees(CSharpSyntaxTree.ParseText(source)); + .AddSyntaxTrees(CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(languageVersion))); if (refAssemblyPath is not null) { @@ -1620,7 +1646,7 @@ private static CSharpCompilation CreateCompilationForOptionsSource(string assemb refAssemblies.Add(refAssembly); } - return await RoslynTestUtils.RunGenerator(new Generator(), refAssemblies.ToArray(), new List { source }, includeBaseReferences: true, languageVersion).ConfigureAwait(false); + return await RoslynTestUtils.RunGenerator(new OptionsValidatorGenerator(), refAssemblies.ToArray(), new List { source }, includeBaseReferences: true, languageVersion).ConfigureAwait(false); } private static async Task<(IReadOnlyList diagnostics, ImmutableArray generatedSources)> RunGenerator( @@ -1677,9 +1703,158 @@ private static CSharpCompilation CreateCompilationForOptionsSource(string assemb assemblies.Add(Assembly.GetAssembly(typeof(Microsoft.Extensions.Options.ValidateObjectMembersAttribute))!); } - var result = await RoslynTestUtils.RunGenerator(new Generator(), assemblies.ToArray(), new[] { text }) + var result = await RoslynTestUtils.RunGenerator(new OptionsValidatorGenerator(), assemblies.ToArray(), new[] { text }) .ConfigureAwait(false); return result; } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + [InlineData(LanguageVersion.CSharp10)] + [InlineData(LanguageVersion.CSharp11)] + public async Task GeneratedAttributesTest(LanguageVersion languageVersion) + { + +#if NETCOREAPP + string lengthAttribute = $$""" + [LengthAttribute(1, 3)] + public string? P0 { get; set; } + + [LengthAttribute(1, 3)] + public FakeCount? P1 { get; set; } + + [LengthAttribute(1, 3)] + public FakeCountChild? P2 { get; set; } + """; +#else +string lengthAttribute = ""; +#endif //NETCOREAPP + + string source = $$""" + using System.Collections.Generic; + using Microsoft.Extensions.Options; + using System.ComponentModel.DataAnnotations; + + #nullable enable + + namespace ValidationTest + { + public class FakeCount + { + public FakeCount(int count) { Count = count; } + public int Count { get; } + } + public class FakeCountChild : FakeCount + { + public FakeCountChild(int count) : base(count) { } + } + + public class OptionsUsingGeneratedAttributes + { + {{lengthAttribute}} + + [RangeAttribute(1, 3)] + public int P3 { get; set; } + + [MinLengthAttribute(5)] + public string? P4 { get; set; } + + [MaxLengthAttribute(5)] + public string? P5 { get; set; } + + [CompareAttribute("P5")] + public string? P6 { get; set; } + + [MinLengthAttribute(5)] + public FakeCount? P7 { get; set; } + + [MinLengthAttribute(5)] + public FakeCountChild? P8 { get; set; } + + [MaxLengthAttribute(5)] + public FakeCount? P9 { get; set; } + + [MaxLengthAttribute(5)] + public FakeCountChild? P10 { get; set; } + + [MinLengthAttribute(5)] + public List? P11 { get; set; } + + [MaxLengthAttribute(5)] + public List? P12 { get; set; } + } + + [OptionsValidator] + public sealed partial class OptionsUsingGeneratedAttributesValidator : IValidateOptions + { + } + } + """; + + var (diagnostics, generatedSources) = await RunGeneratorOnOptionsSource(source, null, languageVersion); + Assert.Empty(diagnostics); + Assert.Single(generatedSources); + + string emittedSource = generatedSources[0].SourceText.ToString(); + SyntaxTree syntaxTree = SyntaxFactory.ParseSyntaxTree(emittedSource, new CSharpParseOptions(languageVersion)); + var diags = syntaxTree.GetDiagnostics().ToArray(); + Assert.Empty(diags); + +#if NETCOREAPP + string generatedSource = File.ReadAllText(languageVersion == LanguageVersion.CSharp10 ? @"Baselines/GeneratedAttributesTest.netcore.lang10.g.cs" : @"Baselines/GeneratedAttributesTest.netcore.lang11.g.cs"); +#else + string generatedSource = File.ReadAllText(languageVersion == LanguageVersion.CSharp10 ? @"Baselines/GeneratedAttributesTest.netfx.lang10.g.cs" : @"Baselines/GeneratedAttributesTest.netfx.lang11.g.cs"); +#endif // NET8_0_OR_GREATER + Assert.Equal(generatedSource.Replace("\r\n", "\n"), emittedSource.Replace("\r\n", "\n")); + + CSharpCompilation compilation = CreateCompilationForOptionsSource(Path.GetRandomFileName(), source + emittedSource, refAssemblyPath: null, languageVersion); + var emitResult = compilation.Emit(new MemoryStream()); + + Assert.True(emitResult.Success); + // Console.WriteLine(emittedSource); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public async Task UsingInterfaceAsPropertyTypeForLengthAttributesTests() + { + var (diagnostics, generatedSources) = await RunGenerator(@""" + using System.Collections.Generic; + + public class MyOptions + { + [Length(10, 20)] + public IList P1 { get; set; } + + [MinLength(4)] + public IList P2 { get; set; } + + [MaxLength(5)] + public IList P3 { get; set; } + + [Length(10, 20)] + public ICollection P4 { get; set; } + + [MinLength(4)] + public ICollection P5 { get; set; } + + [MaxLength(5)] + public ICollection P6 { get; set; } + } + + [OptionsValidator] + public partial class MyOptionsValidator : IValidateOptions + { + } + """); + + Assert.Empty(diagnostics); + Assert.Single(generatedSources); + +#if NETCOREAPP + string generatedSource = File.ReadAllText(@"Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs"); +#else + string generatedSource = File.ReadAllText(@"Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs"); +#endif // NETCOREAPP + Assert.Equal(generatedSource.Replace("\r\n", "\n"), generatedSources[0].SourceText.ToString().Replace("\r\n", "\n")); + } } diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj index d76cbd45302f91..f3a5f33b4a2b45 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj @@ -14,6 +14,7 @@ + @@ -35,6 +36,12 @@ OutputItemType="Analyzer" ReferenceOutputAssembly="true" SetTargetFramework="TargetFramework=netstandard2.0"/> + + + PreserveNewest + + + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs index 1ac4618014ba92..4c701e4b9f498f 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations; +using System.Globalization; using System.Linq; using System.Threading.Tasks; using Xunit; @@ -25,10 +26,15 @@ public void TestValidationSuccessResults() { Tall = 10, Id = "1", - Children = new() + Children1 = new() { - new ChildOptions() { Name = "C1" }, - new ChildOptions() { Name = "C2" } + new ChildOptions() { Name = "C1-1" }, + new ChildOptions() { Name = "C1-2" } + }, + Children2 = new List() + { + new ChildOptions() { Name = "C2-1" }, + new ChildOptions() { Name = "C2-2" } }, NestedList = new() { @@ -126,12 +132,19 @@ public void TestValidationWithEnumeration() { Tall = 10, Id = "1", - Children = new() + Children1 = new() { new ChildOptions(), new ChildOptions(), new ChildOptions() - } + }, + Children2 = new List() + { + new ChildOptions(), + new ChildOptions(), + new ChildOptions() + }, + } }; @@ -142,9 +155,12 @@ public void TestValidationWithEnumeration() Assert.True(result1.Failed); Assert.Equal(new List { - "Name: The MyOptions.Nested.Children[0].Name field is required.", - "Name: The MyOptions.Nested.Children[1].Name field is required.", - "Name: The MyOptions.Nested.Children[2].Name field is required.", + "Name: The MyOptions.Nested.Children1[0].Name field is required.", + "Name: The MyOptions.Nested.Children1[1].Name field is required.", + "Name: The MyOptions.Nested.Children1[2].Name field is required.", + "Name: The MyOptions.Nested.Children2[0].Name field is required.", + "Name: The MyOptions.Nested.Children2[1].Name field is required.", + "Name: The MyOptions.Nested.Children2[2].Name field is required.", }, result1.Failures); @@ -152,13 +168,40 @@ public void TestValidationWithEnumeration() Assert.True(result2.Failed); Assert.Equal(new List { - "DataAnnotation validation failed for 'MyOptions.Nested.Children[0]' members: 'Name' with the error: 'The Name field is required.'.", - "DataAnnotation validation failed for 'MyOptions.Nested.Children[1]' members: 'Name' with the error: 'The Name field is required.'.", - "DataAnnotation validation failed for 'MyOptions.Nested.Children[2]' members: 'Name' with the error: 'The Name field is required.'.", + "DataAnnotation validation failed for 'MyOptions.Nested.Children1[0]' members: 'Name' with the error: 'The Name field is required.'.", + "DataAnnotation validation failed for 'MyOptions.Nested.Children1[1]' members: 'Name' with the error: 'The Name field is required.'.", + "DataAnnotation validation failed for 'MyOptions.Nested.Children1[2]' members: 'Name' with the error: 'The Name field is required.'.", + "DataAnnotation validation failed for 'MyOptions.Nested.Children2[0]' members: 'Name' with the error: 'The Name field is required.'.", + "DataAnnotation validation failed for 'MyOptions.Nested.Children2[1]' members: 'Name' with the error: 'The Name field is required.'.", + "DataAnnotation validation failed for 'MyOptions.Nested.Children2[2]' members: 'Name' with the error: 'The Name field is required.'.", }, result2.Failures); } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public void TestObjectsWithIndexerProperties() + { + DataAnnotationValidateOptions dataAnnotationValidateOptions1 = new("MyDictionaryOptions"); + MyDictionaryOptionsOptionsValidator sourceGenOptionsValidator1 = new(); + + var options1 = new MyDictionaryOptions(); + ValidateOptionsResult result1 = sourceGenOptionsValidator1.Validate("MyDictionaryOptions", options1); + ValidateOptionsResult result2 = dataAnnotationValidateOptions1.Validate("MyDictionaryOptions", options1); + + Assert.True(result1.Succeeded); + Assert.True(result2.Succeeded); + + DataAnnotationValidateOptions> dataAnnotationValidateOptions2 = new("MyListOptions"); + MyListOptionsOptionsValidator sourceGenOptionsValidator2 = new(); + + var options2 = new MyListOptions() { Prop = "test" }; + result1 = sourceGenOptionsValidator2.Validate("MyListOptions", options2); + result2 = dataAnnotationValidateOptions2.Validate("MyListOptions", options2); + + Assert.True(result1.Succeeded); + Assert.True(result2.Succeeded); + } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] public void TestValidationWithCyclicReferences() { @@ -193,6 +236,284 @@ public void TestValidationWithCyclicReferences() ValidateOptionsResult result2 = dataAnnotationValidateOptions.Validate("MyOptions", options); Assert.True(result1.Succeeded); } + +#if NET8_0_OR_GREATER + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public void TestNewDataAnnotationFailures() + { + NewAttributesValidator sourceGenValidator = new(); + + OptionsUsingNewAttributes validOptions = new() + { + P1 = "123456", P2 = 2, P3 = 4, P4 = "c", P5 = "d" + }; + + ValidateOptionsResult result = sourceGenValidator.Validate("OptionsUsingNewAttributes", validOptions); + Assert.True(result.Succeeded); + + OptionsUsingNewAttributes invalidOptions = new() + { + P1 = "123", P2 = 4, P3 = 1, P4 = "e", P5 = "c" + }; + + result = sourceGenValidator.Validate("OptionsUsingNewAttributes", invalidOptions); + + Assert.Equal(new []{ + "P1: The field OptionsUsingNewAttributes.P1 must be a string or collection type with a minimum length of '5' and maximum length of '10'.", + "P2: The OptionsUsingNewAttributes.P2 field does not equal any of the values specified in AllowedValuesAttribute.", + "P3: The OptionsUsingNewAttributes.P3 field equals one of the values specified in DeniedValuesAttribute.", + "P4: The OptionsUsingNewAttributes.P4 field does not equal any of the values specified in AllowedValuesAttribute.", + "P5: The OptionsUsingNewAttributes.P5 field equals one of the values specified in DeniedValuesAttribute." + }, result.Failures); + } +#endif // NET8_0_OR_GREATER + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public void TestCustomGeneratedAttributes() + { + OptionsUsingGeneratedAttributes noFailures = new OptionsUsingGeneratedAttributes() + { +#if NET8_0_OR_GREATER + P0 = "123", + P11 = new DateTime(2023, 2, 1), + P12 = 6, + P13 = 9, + P14 = new List() { "1", "2" }, + P15 = new FakeCount(5), + P16 = new FakeCountChild(5), + P17 = new int[] { 1, 2 }, + P18 = new List() { "1", "2", "3" }, + P19 = new FakeCount(3), + P20 = new FakeCountChild(3), + P23 = new List() { "1", "2", "3", "4" }, + P24 = new FakeCount(4), + P25 = new FakeCountChild(4), + P27 = new List { "1", "2" }, + P28 = new HashSet { "1", "2" }, + P29 = new List { "1", "2", "3" }, + P30 = new HashSet { "1", "2", "3" }, + P31 = new List { 1, 2, 3, 4 }, + P32 = new HashSet { 1, 2, 3, 4 }, +#endif // NET8_0_OR_GREATER + P1 = 2, + P2 = "12345", + P3 = "12345", + P4 = "12345", + P5 = 4, + P6 = 4, + P7 = 15, + P8 = 15, + P9 = 2.5m, + P10 = 14.0, + P21 = new int[] { 1, 2, 3 }, + P22 = new int[] { 1, 2, 3, 4 }, + P26 = 14.0, + }; + List results = new(); + Assert.True(Validator.TryValidateObject(noFailures, new ValidationContext(noFailures), results, true)); + + OptionsUsingGeneratedAttributesValidator validator = new(); + Assert.True(validator.Validate("OptionsUsingGeneratedAttributes", noFailures).Succeeded); + + OptionsUsingGeneratedAttributes failing = new OptionsUsingGeneratedAttributes() + { +#if NET8_0_OR_GREATER + P0 = "", + P11 = new DateTime(2023, 1, 1), + P12 = 5, + P13 = 10, + P14 = new List() { "1" }, + P15 = new FakeCount(1), + P16 = new FakeCountChild(11), + P17 = new int[] { 1 }, + P18 = new List() { "1", "2" }, + P19 = new FakeCount(2), + P20 = new FakeCountChild(1), + P23 = new List() { "1", "2", "3", "4", "5" }, + P24 = new FakeCount(5), + P25 = new FakeCountChild(5), + P27 = new List { "1" }, + P28 = new HashSet { "1" }, + P29 = new List { "1", "2" }, + P30 = new HashSet { "1", "2" }, + P31 = new List { 1, 2, 3, 4, 5 }, + P32 = new HashSet { 1, 2, 3, 4, 5 }, +#endif // NET8_0_OR_GREATER + P1 = 4, + P2 = "1234", + P3 = "123456", + P4 = "12345", + P5 = 10, + P6 = 10, + P7 = 5, + P8 = 5, + P9 = 4.0m, + P10 = 20.0, + P21 = new int[] { 1, 2 }, + P22 = new int[] { 1, 2, 3, 4, 5 }, + P26 = 20.0, + }; + + Assert.False(Validator.TryValidateObject(failing, new ValidationContext(failing), results, true)); + + ValidateOptionsResult generatorResult = validator.Validate("OptionsUsingGeneratedAttributes", failing); + Assert.True(generatorResult.Failed); + + Assert.Equal(new [] { +#if NET8_0_OR_GREATER + "P0: The field OptionsUsingGeneratedAttributes.P0 must be a string or collection type with a minimum length of '1' and maximum length of '3'.", + string.Format(CultureInfo.CurrentCulture, "P11: The field OptionsUsingGeneratedAttributes.P11 must be between {0} and {1}.", new DateTime(2023, 1, 30), new DateTime(2023, 12, 30)), + "P12: The field OptionsUsingGeneratedAttributes.P12 must be between 5 exclusive and 10.", + "P13: The field OptionsUsingGeneratedAttributes.P13 must be between 5 and 10 exclusive.", + "P14: The field OptionsUsingGeneratedAttributes.P14 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P15: The field OptionsUsingGeneratedAttributes.P15 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P16: The field OptionsUsingGeneratedAttributes.P16 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P17: The field OptionsUsingGeneratedAttributes.P17 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P18: The field OptionsUsingGeneratedAttributes.P18 must be a string or array type with a minimum length of '3'.", + "P19: The field OptionsUsingGeneratedAttributes.P19 must be a string or array type with a minimum length of '3'.", + "P20: The field OptionsUsingGeneratedAttributes.P20 must be a string or array type with a minimum length of '3'.", + "P23: The field OptionsUsingGeneratedAttributes.P23 must be a string or array type with a maximum length of '4'.", + "P24: The field OptionsUsingGeneratedAttributes.P24 must be a string or array type with a maximum length of '4'.", + "P25: The field OptionsUsingGeneratedAttributes.P25 must be a string or array type with a maximum length of '4'.", + "P27: The field OptionsUsingGeneratedAttributes.P27 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P28: The field OptionsUsingGeneratedAttributes.P28 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P29: The field OptionsUsingGeneratedAttributes.P29 must be a string or array type with a minimum length of '3'.", + "P30: The field OptionsUsingGeneratedAttributes.P30 must be a string or array type with a minimum length of '3'.", + "P31: The field OptionsUsingGeneratedAttributes.P31 must be a string or array type with a maximum length of '4'.", + "P32: The field OptionsUsingGeneratedAttributes.P32 must be a string or array type with a maximum length of '4'.", +#endif // NET8_0_OR_GREATER + "P1: The field OptionsUsingGeneratedAttributes.P1 must be between 1 and 3.", + "P2: The field OptionsUsingGeneratedAttributes.P2 must be a string or array type with a minimum length of '5'.", + "P3: The field OptionsUsingGeneratedAttributes.P3 must be a string or array type with a maximum length of '5'.", + "P4: 'OptionsUsingGeneratedAttributes.P4' and 'P2' do not match.", + "P5: The field OptionsUsingGeneratedAttributes.P5 must be between 2 and 8.", + "P6: The field OptionsUsingGeneratedAttributes.P6 must be between 2 and 8.", + "P7: The field OptionsUsingGeneratedAttributes.P7 must be between 10 and 20.", + "P8: The field OptionsUsingGeneratedAttributes.P8 must be between 10 and 20.", + "P9: The field OptionsUsingGeneratedAttributes.P9 must be between 1.5 and 3.14.", + "P10: The field OptionsUsingGeneratedAttributes.P10 must be between 12.4 and 16.5.", + "P21: The field OptionsUsingGeneratedAttributes.P21 must be a string or array type with a minimum length of '3'.", + "P22: The field OptionsUsingGeneratedAttributes.P22 must be a string or array type with a maximum length of '4'.", + "P26: The field OptionsUsingGeneratedAttributes.P26 must be between 12.4 and 16.5.", + }, generatorResult.Failures); + + Assert.Equal(results.Count(), generatorResult.Failures.Count()); + } + } + + public class FakeCount(int count) { public int Count { get { return count; } } } + public class FakeCountChild(int count) : FakeCount(count) { } + + public class OptionsUsingGeneratedAttributes + { +#if NET8_0_OR_GREATER + [LengthAttribute(1, 3)] + public string? P0 { get; set; } + + [RangeAttribute(typeof(DateTime), "01/30/2023", "12/30/2023", ParseLimitsInInvariantCulture = true, ConvertValueInInvariantCulture = true)] + public DateTime P11 { get; set; } + + [RangeAttribute(5, 10, MinimumIsExclusive = true)] + public int P12 { get; set; } + + [RangeAttribute(5, 10, MaximumIsExclusive = true)] + public int P13 { get; set; } + + [LengthAttribute(2, 10)] + public List P14 { get; set; } + + [LengthAttribute(2, 10)] + public FakeCount P15 { get; set; } + + [LengthAttribute(2, 10)] + public FakeCountChild P16 { get; set; } + + [LengthAttribute(2, 10)] + public int[] P17 { get; set; } + + // Although MinLength and MaxLength attributes defined in NETFX but the implementation there has a bug which can produce exception like the following when using types like List: + // System.InvalidCastException : Unable to cast object of type 'System.Collections.Generic.List`1[System.String]' to type 'System.Array'. + + [MinLengthAttribute(3)] + public List P18 { get; set; } + + [MinLengthAttribute(3)] + public FakeCount P19 { get; set; } + + [MinLengthAttribute(3)] + public FakeCountChild P20 { get; set; } + + [MaxLengthAttribute(4)] + public List P23 { get; set; } + + [MaxLengthAttribute(4)] + public FakeCount P24 { get; set; } + + [MaxLengthAttribute(4)] + public FakeCountChild P25 { get; set; } + + [LengthAttribute(2, 10)] + public IList P27 { get; set; } + + [LengthAttribute(2, 10)] + public ICollection P28 { get; set; } + + [MinLengthAttribute(3)] + public IList P29 { get; set; } + + [MinLengthAttribute(3)] + public ICollection P30 { get; set; } + + [MaxLengthAttribute(4)] + public IList P31 { get; set; } + + [MaxLengthAttribute(4)] + public ICollection P32 { get; set; } +#endif // NET8_0_OR_GREATER + + [RangeAttribute(1, 3)] + public int P1 { get; set; } + + [MinLengthAttribute(5)] + public string? P2 { get; set; } + + [MaxLengthAttribute(5)] + public string? P3 { get; set; } + + [CompareAttribute("P2")] + public string? P4 { get; set; } + + [RangeAttribute(typeof(byte), "2", "8")] + public byte P5 { get; set; } + + [RangeAttribute(typeof(sbyte), "2", "8")] + public sbyte P6 { get; set; } + + [RangeAttribute(typeof(short), "10", "20")] + public short P7 { get; set; } + + [RangeAttribute(typeof(ulong), "10", "20")] + public ulong P8 { get; set; } + + [RangeAttribute(typeof(decimal), "1.5", "3.14")] + public decimal P9 { get; set; } + + [RangeAttribute(typeof(double), "12.40", "16.50")] + public double P10 { get; set; } + + [MinLengthAttribute(3)] + public int[] P21 { get; set; } + + [MaxLengthAttribute(4)] + public int[] P22 { get; set; } + + [RangeAttribute(typeof(double), "12.40", "16.50")] + public double? P26 { get; set; } + } + + [OptionsValidator] + public partial class OptionsUsingGeneratedAttributesValidator : IValidateOptions + { } public class MyOptions @@ -219,7 +540,10 @@ public class NestedOptions public string? Id { get; set; } [ValidateEnumeratedItems] - public List? Children { get; set; } + public List? Children1 { get; set; } + + [ValidateEnumeratedItems] + public IEnumerable? Children2 { get; set; } #pragma warning disable SYSLIB1211 // Source gen does static analysis for circular reference. We need to disable it for this test. [ValidateEnumeratedItems] @@ -249,4 +573,36 @@ public struct MyOptionsStruct public partial class MySourceGenOptionsValidator : IValidateOptions { } -} \ No newline at end of file + + public class MyDictionaryOptions : Dictionary { [Required] public string Prop { get; set; } = "test"; } + [OptionsValidator] public partial class MyDictionaryOptionsOptionsValidator : IValidateOptions { } + + public class MyListOptions : List { [Required] public T Prop { get; set; } = default; } + [OptionsValidator] public partial class MyListOptionsOptionsValidator : IValidateOptions> { } + +#if NET8_0_OR_GREATER + public class OptionsUsingNewAttributes + { + [Length(5, 10)] + public string P1 { get; set; } + + [AllowedValues(1, 2, 3)] + public int P2 { get; set; } + + [DeniedValues(1, 2, 3)] + public int P3 { get; set; } + + [AllowedValues(new object?[] { "a", "b", "c" })] + public string P4 { get; set; } + + [DeniedValues(new object?[] { "a", "b", "c" })] + public string P5 { get; set; } + } + + [OptionsValidator] + public partial class NewAttributesValidator : IValidateOptions + { + } +#endif // NET8_0_OR_GREATER + +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Resources/Strings.resx index c6d9021ea99672..90e5b01ed3b1b7 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Resources/Strings.resx @@ -210,4 +210,16 @@ Validation attribute on the member is inaccessible from the validator type.. - \ No newline at end of file + + C# language version not supported by the source generator. + + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs index a467ed2dd6c007..c487888c9f16bb 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs @@ -12,24 +12,25 @@ internal sealed partial class __ThirdModelNoNamespaceValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ThirdModelNoNamespace options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModelNoNamespace" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModelNoNamespace.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } partial class FirstValidatorNoNamespace @@ -41,34 +42,35 @@ partial class FirstValidatorNoNamespace /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FirstModelNoNamespace options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModelNoNamespace" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModelNoNamespace.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V1.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V1.Validate(string.IsNullOrEmpty(name) ? "FirstModelNoNamespace.P2" : $"{name}.P2", options.P2)); } if (options.P3 is not null) { - builder.AddResult(global::__ThirdModelNoNamespaceValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::__ThirdModelNoNamespaceValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModelNoNamespace.P3" : $"{name}.P3", options.P3)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } partial class SecondValidatorNoNamespace @@ -80,24 +82,25 @@ partial class SecondValidatorNoNamespace /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::SecondModelNoNamespace options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModelNoNamespace" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModelNoNamespace.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } namespace CustomAttr @@ -111,33 +114,34 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::CustomAttr.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "P2"; - context.DisplayName = baseName + "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -153,24 +157,25 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -186,23 +191,24 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Value"; - context.DisplayName = baseName + "Value"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.Value" : $"{name}.Value"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A5); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Value!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Value, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -217,10 +223,11 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); if (options.P1 is not null) @@ -230,11 +237,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P1[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P1[{count}]" : $"{name}.P1[{count}]", o)); } else { - builder.AddError(baseName + $"P1[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P1[{count}] is null" : $"{name}.P1[{count}] is null"); } count++; } @@ -247,11 +254,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V2.Validate(baseName + $"P2[{count}]", o)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V2.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P2[{count}]" : $"{name}.P2[{count}]", o)); } else { - builder.AddError(baseName + $"P2[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P2[{count}] is null" : $"{name}.P2[{count}] is null"); } count++; } @@ -264,7 +271,7 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P3[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P3[{count}]" : $"{name}.P3[{count}]", o)); } count++; } @@ -275,7 +282,7 @@ partial struct FirstValidator var count = 0; foreach (var o in options.P4) { - builder.AddResult(global::Enumeration.__ThirdModelValidator__.Validate(baseName + $"P4[{count++}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P4[{count++}] is null" : $"{name}.P4[{count++}] is null", o)); } } @@ -286,7 +293,7 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__ThirdModelValidator__.Validate(baseName + $"P5[{count}]", o.Value)); + (builder ??= new()).AddResult(global::Enumeration.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P5[{count}]" : $"{name}.P5[{count}]", o.Value)); } count++; } @@ -299,7 +306,7 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__ThirdModelValidator__.Validate(baseName + $"P51[{count}]", o.Value)); + (builder ??= new()).AddResult(global::Enumeration.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P51[{count}]" : $"{name}.P51[{count}]", o.Value)); } count++; } @@ -312,11 +319,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P6[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P6[{count}]" : $"{name}.P6[{count}]", o)); } else { - builder.AddError(baseName + $"P6[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P6[{count}] is null" : $"{name}.P6[{count}] is null"); } count++; } @@ -328,11 +335,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P7[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P7[{count}]" : $"{name}.P7[{count}]", o)); } else { - builder.AddError(baseName + $"P7[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P7[{count}] is null" : $"{name}.P7[{count}] is null"); } count++; } @@ -345,17 +352,17 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P8[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P8[{count}]" : $"{name}.P8[{count}]", o)); } else { - builder.AddError(baseName + $"P8[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P8[{count}] is null" : $"{name}.P8[{count}] is null"); } count++; } } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -370,24 +377,25 @@ partial struct SecondValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -402,24 +410,25 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FileScopedNamespace.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -434,23 +443,24 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FunnyStrings.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A6); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -466,24 +476,25 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Generics.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -498,29 +509,30 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Generics.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P3 is not null) { - builder.AddResult(global::Generics.__SecondModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::Generics.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -535,29 +547,30 @@ partial struct MultiValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::MultiModelValidator.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V3.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V3.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } /// /// Validates a specific named options instance (or all when is ). @@ -566,24 +579,25 @@ partial struct MultiValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::MultiModelValidator.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P3"; - context.DisplayName = baseName + "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P3" : $"{name}.P3"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -599,24 +613,25 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -633,24 +648,25 @@ partial record struct FifthValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -670,36 +686,37 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V4.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V4.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } - builder.AddResult(global::Nested.__ThirdModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::Nested.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); if (options.P4 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V5.Validate(baseName + "P4", options.P4)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V5.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P4" : $"{name}.P4", options.P4)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -718,24 +735,25 @@ partial struct FourthValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -755,24 +773,25 @@ partial struct SecondValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -793,24 +812,25 @@ partial struct ThirdValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -827,24 +847,25 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RandomMembers.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -860,24 +881,25 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -892,36 +914,37 @@ partial record struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V6.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V6.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } if (options.P3 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V7.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V7.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); } - builder.AddResult(global::RecordTypes.__ThirdModelValidator__.Validate(baseName + "P4", options.P4)); + (builder ??= new()).AddResult(global::RecordTypes.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P4" : $"{name}.P4", options.P4)); - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -936,24 +959,25 @@ partial record struct SecondValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -968,24 +992,25 @@ partial record class ThirdValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1001,28 +1026,29 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P4 is not null) { - builder.AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(baseName + "P4", options.P4)); + (builder ??= new()).AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4", options.P4)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1038,24 +1064,25 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1070,58 +1097,59 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P1 is not null) { - builder.AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(baseName + "P1", options.P1)); + (builder ??= new()).AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1", options.P1)); } context.MemberName = "P2"; - context.DisplayName = baseName + "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } context.MemberName = "P3"; - context.DisplayName = baseName + "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P3 is not null) { - builder.AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1136,25 +1164,26 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::SelfValidation.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - builder.AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context)); + (builder ??= new()).AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context)); - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1170,23 +1199,24 @@ internal sealed partial class __RangeAttributeModelDoubleValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDouble options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A7); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1202,23 +1232,24 @@ internal sealed partial class __RequiredAttributeModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RequiredAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "RequiredAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RequiredAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1234,38 +1265,39 @@ internal sealed partial class __TypeWithoutOptionsValidatorValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.TypeWithoutOptionsValidator options) { - var baseName = (string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val1"; - context.DisplayName = baseName + "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator.Val1" : $"{name}.Val1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val2"; - context.DisplayName = baseName + "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator.Val2" : $"{name}.Val2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A8); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.YetAnotherComplexVal is not null) { - builder.AddResult(global::TestClasses.OptionsValidation.__RangeAttributeModelDoubleValidator__.Validate(baseName + "YetAnotherComplexVal", options.YetAnotherComplexVal)); + (builder ??= new()).AddResult(global::TestClasses.OptionsValidation.__RangeAttributeModelDoubleValidator__.Validate(string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator.YetAnotherComplexVal" : $"{name}.YetAnotherComplexVal", options.YetAnotherComplexVal)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1280,33 +1312,34 @@ partial class AttributePropertyModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.AttributePropertyModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "AttributePropertyModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val1"; - context.DisplayName = baseName + "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "AttributePropertyModel.Val1" : $"{name}.Val1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A9); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val2"; - context.DisplayName = baseName + "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "AttributePropertyModel.Val2" : $"{name}.Val2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A10); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1321,23 +1354,24 @@ partial class ComplexModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.ComplexModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ComplexModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); if (options.ComplexVal is not null) { - builder.AddResult(global::TestClasses.OptionsValidation.__RequiredAttributeModelValidator__.Validate(baseName + "ComplexVal", options.ComplexVal)); + (builder ??= new()).AddResult(global::TestClasses.OptionsValidation.__RequiredAttributeModelValidator__.Validate(string.IsNullOrEmpty(name) ? "ComplexModel.ComplexVal" : $"{name}.ComplexVal", options.ComplexVal)); } if (options.ValWithoutOptionsValidator is not null) { - builder.AddResult(global::TestClasses.OptionsValidation.__TypeWithoutOptionsValidatorValidator__.Validate(baseName + "ValWithoutOptionsValidator", options.ValWithoutOptionsValidator)); + (builder ??= new()).AddResult(global::TestClasses.OptionsValidation.__TypeWithoutOptionsValidatorValidator__.Validate(string.IsNullOrEmpty(name) ? "ComplexModel.ValWithoutOptionsValidator" : $"{name}.ValWithoutOptionsValidator", options.ValWithoutOptionsValidator)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1352,23 +1386,24 @@ partial class CustomTypeCustomValidationAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.CustomTypeCustomValidationAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "CustomTypeCustomValidationAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "CustomTypeCustomValidationAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A11); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1383,23 +1418,24 @@ partial class CustomValidationAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.CustomValidationAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "CustomValidationAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "CustomValidationAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A12); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1414,23 +1450,24 @@ partial class DataTypeAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.DataTypeAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "DataTypeAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DataTypeAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A13); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1445,43 +1482,44 @@ partial class DerivedModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.DerivedModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "DerivedModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "DerivedVal"; - context.DisplayName = baseName + "DerivedVal"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DerivedModel.DerivedVal" : $"{name}.DerivedVal"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "VirtualValWithAttr"; - context.DisplayName = baseName + "VirtualValWithAttr"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DerivedModel.VirtualValWithAttr" : $"{name}.VirtualValWithAttr"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithAttr!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithAttr, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DerivedModel.Val" : $"{name}.Val"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1496,23 +1534,24 @@ partial class EmailAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.EmailAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "EmailAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "EmailAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A14); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1527,43 +1566,44 @@ partial class LeafModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.LeafModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "LeafModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "VirtualValWithoutAttr"; - context.DisplayName = baseName + "VirtualValWithoutAttr"; + context.DisplayName = string.IsNullOrEmpty(name) ? "LeafModel.VirtualValWithoutAttr" : $"{name}.VirtualValWithoutAttr"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithoutAttr!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithoutAttr, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "DerivedVal"; - context.DisplayName = baseName + "DerivedVal"; + context.DisplayName = string.IsNullOrEmpty(name) ? "LeafModel.DerivedVal" : $"{name}.DerivedVal"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "LeafModel.Val" : $"{name}.Val"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1578,54 +1618,55 @@ partial class MultipleAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.MultipleAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "MultipleAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "Val1"; - context.DisplayName = baseName + "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val1" : $"{name}.Val1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A15); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val2"; - context.DisplayName = baseName + "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val2" : $"{name}.Val2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A16); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val3"; - context.DisplayName = baseName + "Val3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val3" : $"{name}.Val3"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A17); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val3!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val3, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val4"; - context.DisplayName = baseName + "Val4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val4" : $"{name}.Val4"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A18); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1640,23 +1681,24 @@ partial class RangeAttributeModelDateValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDate options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelDate" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelDate.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A19); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1671,23 +1713,24 @@ partial class RangeAttributeModelDoubleValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDouble options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A7); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1702,23 +1745,24 @@ partial class RangeAttributeModelIntValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelInt options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelInt" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelInt.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A16); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1733,23 +1777,24 @@ partial class RegularExpressionAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RegularExpressionAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "RegularExpressionAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RegularExpressionAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A20); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1764,23 +1809,24 @@ partial class RequiredAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RequiredAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "RequiredAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RequiredAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1796,24 +1842,25 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValueTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1828,36 +1875,37 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValueTypes.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::ValueTypes.__SecondModelValidator__.Validate(baseName + "P2", options.P2.Value)); + (builder ??= new()).AddResult(global::ValueTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2.Value)); } - builder.AddResult(global::ValueTypes.__SecondModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::ValueTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); if (options.P4 is not null) { - builder.AddResult(global::ValueTypes.__SecondModelValidator__.Validate(baseName + "P4", options.P4.Value)); + (builder ??= new()).AddResult(global::ValueTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P4" : $"{name}.P4", options.P4.Value)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1868,7 +1916,7 @@ file static class __Attributes { internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); - internal static readonly global::System.ComponentModel.DataAnnotations.MinLengthAttribute A2 = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( (int)5); internal static readonly global::CustomAttr.CustomAttribute A3 = new global::CustomAttr.CustomAttribute( @@ -1881,30 +1929,30 @@ file static class __Attributes false, "X"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A5 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)0, (int)10); internal static readonly global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute A6 = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute( "\"\r\n\\\\"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A7 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A7 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (double)0.5, (double)0.9); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A8 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A8 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( typeof(global::System.DateTime), "1/2/2004", "3/4/2004"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A9 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A9 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { ErrorMessage = "ErrorMessage" }; - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A10 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A10 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { @@ -1928,19 +1976,19 @@ file static class __Attributes internal static readonly global::System.ComponentModel.DataAnnotations.DataTypeAttribute A15 = new global::System.ComponentModel.DataAnnotations.DataTypeAttribute( (global::System.ComponentModel.DataAnnotations.DataType)11); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A16 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A16 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A17 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A17 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)3, (int)5); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A18 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A18 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)5, (int)9); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A19 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A19 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( typeof(global::System.DateTime), "1/2/2004", "3/4/2004") @@ -1972,3 +2020,150 @@ file static class __Validators internal static readonly global::RecordTypes.ThirdValidator V7 = new global::RecordTypes.ThirdValidator(); } } +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs index 9c68710f2a5ec2..7e998cea22cddf 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs @@ -14,22 +14,21 @@ internal sealed partial class __ThirdModelNoNamespaceValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ThirdModelNoNamespace options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModelNoNamespace" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModelNoNamespace.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } partial class FirstValidatorNoNamespace @@ -43,32 +42,31 @@ partial class FirstValidatorNoNamespace [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FirstModelNoNamespace options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModelNoNamespace" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModelNoNamespace.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V1.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V1.Validate(string.IsNullOrEmpty(name) ? "FirstModelNoNamespace.P2" : $"{name}.P2", options.P2)); } if (options.P3 is not null) { - builder.AddResult(global::__ThirdModelNoNamespaceValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::__ThirdModelNoNamespaceValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModelNoNamespace.P3" : $"{name}.P3", options.P3)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } partial class SecondValidatorNoNamespace @@ -82,22 +80,21 @@ partial class SecondValidatorNoNamespace [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::SecondModelNoNamespace options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModelNoNamespace" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModelNoNamespace.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } namespace CustomAttr @@ -113,31 +110,30 @@ partial class FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::CustomAttr.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "P2"; - context.DisplayName = baseName + "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -155,22 +151,21 @@ internal sealed partial class __SecondModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -188,21 +183,20 @@ internal sealed partial class __ThirdModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Value"; - context.DisplayName = baseName + "Value"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.Value" : $"{name}.Value"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A5); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Value!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Value, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -219,8 +213,7 @@ partial struct FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); if (options.P1 is not null) @@ -230,11 +223,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P1[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P1[{count}]" : $"{name}.P1[{count}]", o)); } else { - builder.AddError(baseName + $"P1[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P1[{count}] is null" : $"{name}.P1[{count}] is null"); } count++; } @@ -247,11 +240,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V2.Validate(baseName + $"P2[{count}]", o)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V2.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P2[{count}]" : $"{name}.P2[{count}]", o)); } else { - builder.AddError(baseName + $"P2[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P2[{count}] is null" : $"{name}.P2[{count}] is null"); } count++; } @@ -264,7 +257,7 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P3[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P3[{count}]" : $"{name}.P3[{count}]", o)); } count++; } @@ -275,7 +268,7 @@ partial struct FirstValidator var count = 0; foreach (var o in options.P4) { - builder.AddResult(global::Enumeration.__ThirdModelValidator__.Validate(baseName + $"P4[{count++}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P4[{count++}] is null" : $"{name}.P4[{count++}] is null", o)); } } @@ -286,7 +279,7 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__ThirdModelValidator__.Validate(baseName + $"P5[{count}]", o.Value)); + (builder ??= new()).AddResult(global::Enumeration.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P5[{count}]" : $"{name}.P5[{count}]", o.Value)); } count++; } @@ -299,7 +292,7 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__ThirdModelValidator__.Validate(baseName + $"P51[{count}]", o.Value)); + (builder ??= new()).AddResult(global::Enumeration.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P51[{count}]" : $"{name}.P51[{count}]", o.Value)); } count++; } @@ -312,11 +305,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P6[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P6[{count}]" : $"{name}.P6[{count}]", o)); } else { - builder.AddError(baseName + $"P6[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P6[{count}] is null" : $"{name}.P6[{count}] is null"); } count++; } @@ -328,11 +321,11 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P7[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P7[{count}]" : $"{name}.P7[{count}]", o)); } else { - builder.AddError(baseName + $"P7[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P7[{count}] is null" : $"{name}.P7[{count}] is null"); } count++; } @@ -345,17 +338,17 @@ partial struct FirstValidator { if (o is not null) { - builder.AddResult(global::Enumeration.__SecondModelValidator__.Validate(baseName + $"P8[{count}]", o)); + (builder ??= new()).AddResult(global::Enumeration.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? $"FirstModel.P8[{count}]" : $"{name}.P8[{count}]", o)); } else { - builder.AddError(baseName + $"P8[{count}] is null"); + (builder ??= new()).AddError(string.IsNullOrEmpty(name) ? $"FirstModel.P8[{count}] is null" : $"{name}.P8[{count}] is null"); } count++; } } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -372,22 +365,21 @@ partial struct SecondValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -404,22 +396,21 @@ partial struct FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FileScopedNamespace.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -436,21 +427,20 @@ partial struct FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FunnyStrings.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A6); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -468,22 +458,21 @@ internal sealed partial class __SecondModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Generics.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -500,27 +489,26 @@ partial class FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Generics.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P3 is not null) { - builder.AddResult(global::Generics.__SecondModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::Generics.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -537,27 +525,26 @@ partial struct MultiValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::MultiModelValidator.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V3.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V3.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } /// /// Validates a specific named options instance (or all when is ). @@ -568,22 +555,21 @@ partial struct MultiValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::MultiModelValidator.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P3"; - context.DisplayName = baseName + "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P3" : $"{name}.P3"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -601,22 +587,21 @@ internal sealed partial class __ThirdModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -635,22 +620,21 @@ partial record struct FifthValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -672,34 +656,33 @@ partial struct FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V4.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V4.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } - builder.AddResult(global::Nested.__ThirdModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::Nested.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); if (options.P4 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V5.Validate(baseName + "P4", options.P4)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V5.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P4" : $"{name}.P4", options.P4)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -720,22 +703,21 @@ partial struct FourthValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -757,22 +739,21 @@ partial struct SecondValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -795,22 +776,21 @@ partial struct ThirdValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -829,22 +809,21 @@ partial class FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RandomMembers.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -862,22 +841,21 @@ internal sealed partial class __ThirdModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P6"; - context.DisplayName = baseName + "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.P6" : $"{name}.P6"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -894,34 +872,33 @@ partial record struct FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V6.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V6.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } if (options.P3 is not null) { - builder.AddResult(global::__OptionValidationStaticInstances.__Validators.V7.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::__OptionValidationStaticInstances.__Validators.V7.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); } - builder.AddResult(global::RecordTypes.__ThirdModelValidator__.Validate(baseName + "P4", options.P4)); + (builder ??= new()).AddResult(global::RecordTypes.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P4" : $"{name}.P4", options.P4)); - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -938,22 +915,21 @@ partial record struct SecondValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -970,22 +946,21 @@ partial record class ThirdValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1003,26 +978,25 @@ internal sealed partial class __SecondModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P4 is not null) { - builder.AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(baseName + "P4", options.P4)); + (builder ??= new()).AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4", options.P4)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1040,22 +1014,21 @@ internal sealed partial class __ThirdModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.ThirdModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ThirdModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P5"; - context.DisplayName = baseName + "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "ThirdModel.P5" : $"{name}.P5"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1072,56 +1045,55 @@ partial class FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P1 is not null) { - builder.AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(baseName + "P1", options.P1)); + (builder ??= new()).AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1", options.P1)); } context.MemberName = "P2"; - context.DisplayName = baseName + "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(baseName + "P2", options.P2)); + (builder ??= new()).AddResult(global::RepeatedTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2)); } context.MemberName = "P3"; - context.DisplayName = baseName + "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P3 is not null) { - builder.AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::RepeatedTypes.__ThirdModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1138,23 +1110,22 @@ partial struct FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::SelfValidation.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - builder.AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context)); + (builder ??= new()).AddResults(((global::System.ComponentModel.DataAnnotations.IValidatableObject)options).Validate(context)); - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1172,21 +1143,20 @@ internal sealed partial class __RangeAttributeModelDoubleValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDouble options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A7); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1204,21 +1174,20 @@ internal sealed partial class __RequiredAttributeModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RequiredAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "RequiredAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RequiredAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1236,36 +1205,35 @@ internal sealed partial class __TypeWithoutOptionsValidatorValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.TypeWithoutOptionsValidator options) { - var baseName = (string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val1"; - context.DisplayName = baseName + "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator.Val1" : $"{name}.Val1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val2"; - context.DisplayName = baseName + "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator.Val2" : $"{name}.Val2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A8); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.YetAnotherComplexVal is not null) { - builder.AddResult(global::TestClasses.OptionsValidation.__RangeAttributeModelDoubleValidator__.Validate(baseName + "YetAnotherComplexVal", options.YetAnotherComplexVal)); + (builder ??= new()).AddResult(global::TestClasses.OptionsValidation.__RangeAttributeModelDoubleValidator__.Validate(string.IsNullOrEmpty(name) ? "TypeWithoutOptionsValidator.YetAnotherComplexVal" : $"{name}.YetAnotherComplexVal", options.YetAnotherComplexVal)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1282,31 +1250,30 @@ partial class AttributePropertyModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.AttributePropertyModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "AttributePropertyModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val1"; - context.DisplayName = baseName + "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "AttributePropertyModel.Val1" : $"{name}.Val1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A9); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val2"; - context.DisplayName = baseName + "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "AttributePropertyModel.Val2" : $"{name}.Val2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A10); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1323,21 +1290,20 @@ partial class ComplexModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.ComplexModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "ComplexModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); if (options.ComplexVal is not null) { - builder.AddResult(global::TestClasses.OptionsValidation.__RequiredAttributeModelValidator__.Validate(baseName + "ComplexVal", options.ComplexVal)); + (builder ??= new()).AddResult(global::TestClasses.OptionsValidation.__RequiredAttributeModelValidator__.Validate(string.IsNullOrEmpty(name) ? "ComplexModel.ComplexVal" : $"{name}.ComplexVal", options.ComplexVal)); } if (options.ValWithoutOptionsValidator is not null) { - builder.AddResult(global::TestClasses.OptionsValidation.__TypeWithoutOptionsValidatorValidator__.Validate(baseName + "ValWithoutOptionsValidator", options.ValWithoutOptionsValidator)); + (builder ??= new()).AddResult(global::TestClasses.OptionsValidation.__TypeWithoutOptionsValidatorValidator__.Validate(string.IsNullOrEmpty(name) ? "ComplexModel.ValWithoutOptionsValidator" : $"{name}.ValWithoutOptionsValidator", options.ValWithoutOptionsValidator)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1354,21 +1320,20 @@ partial class CustomTypeCustomValidationAttributeModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.CustomTypeCustomValidationAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "CustomTypeCustomValidationAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "CustomTypeCustomValidationAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A11); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1385,21 +1350,20 @@ partial class CustomValidationAttributeModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.CustomValidationAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "CustomValidationAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "CustomValidationAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A12); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1416,21 +1380,20 @@ partial class DataTypeAttributeModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.DataTypeAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "DataTypeAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DataTypeAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A13); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1447,41 +1410,40 @@ partial class DerivedModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.DerivedModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "DerivedModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "DerivedVal"; - context.DisplayName = baseName + "DerivedVal"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DerivedModel.DerivedVal" : $"{name}.DerivedVal"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "VirtualValWithAttr"; - context.DisplayName = baseName + "VirtualValWithAttr"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DerivedModel.VirtualValWithAttr" : $"{name}.VirtualValWithAttr"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithAttr!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithAttr, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "DerivedModel.Val" : $"{name}.Val"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1498,21 +1460,20 @@ partial class EmailAttributeModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.EmailAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "EmailAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "EmailAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A14); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1529,41 +1490,40 @@ partial class LeafModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.LeafModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "LeafModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "VirtualValWithoutAttr"; - context.DisplayName = baseName + "VirtualValWithoutAttr"; + context.DisplayName = string.IsNullOrEmpty(name) ? "LeafModel.VirtualValWithoutAttr" : $"{name}.VirtualValWithoutAttr"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithoutAttr!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.VirtualValWithoutAttr, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "DerivedVal"; - context.DisplayName = baseName + "DerivedVal"; + context.DisplayName = string.IsNullOrEmpty(name) ? "LeafModel.DerivedVal" : $"{name}.DerivedVal"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.DerivedVal, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "LeafModel.Val" : $"{name}.Val"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1580,52 +1540,51 @@ partial class MultipleAttributeModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.MultipleAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "MultipleAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "Val1"; - context.DisplayName = baseName + "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val1" : $"{name}.Val1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A15); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val2"; - context.DisplayName = baseName + "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val2" : $"{name}.Val2"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A16); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val3"; - context.DisplayName = baseName + "Val3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val3" : $"{name}.Val3"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A17); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val3!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val3, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } context.MemberName = "Val4"; - context.DisplayName = baseName + "Val4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MultipleAttributeModel.Val4" : $"{name}.Val4"; validationResults.Clear(); validationAttributes.Clear(); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A18); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1642,21 +1601,20 @@ partial class RangeAttributeModelDateValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDate options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelDate" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelDate.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A8); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1673,21 +1631,20 @@ partial class RangeAttributeModelDoubleValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDouble options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelDouble.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A7); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1704,21 +1661,20 @@ partial class RangeAttributeModelIntValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelInt options) { - var baseName = (string.IsNullOrEmpty(name) ? "RangeAttributeModelInt" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RangeAttributeModelInt.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A16); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1735,21 +1691,20 @@ partial class RegularExpressionAttributeModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RegularExpressionAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "RegularExpressionAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RegularExpressionAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A19); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1766,21 +1721,20 @@ partial class RequiredAttributeModelValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RequiredAttributeModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "RequiredAttributeModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(1); context.MemberName = "Val"; - context.DisplayName = baseName + "Val"; + context.DisplayName = string.IsNullOrEmpty(name) ? "RequiredAttributeModel.Val" : $"{name}.Val"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1798,22 +1752,21 @@ internal sealed partial class __SecondModelValidator__ [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValueTypes.SecondModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "SecondModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P4"; - context.DisplayName = baseName + "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "SecondModel.P4" : $"{name}.P4"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1830,34 +1783,33 @@ partial struct FirstValidator [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValueTypes.FirstModel options) { - var baseName = (string.IsNullOrEmpty(name) ? "FirstModel" : name) + "."; - var builder = new global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder(); + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); var validationResults = new global::System.Collections.Generic.List(); var validationAttributes = new global::System.Collections.Generic.List(2); context.MemberName = "P1"; - context.DisplayName = baseName + "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "FirstModel.P1" : $"{name}.P1"; validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1!, context, validationResults, validationAttributes)) + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) { - builder.AddResults(validationResults); + (builder ??= new()).AddResults(validationResults); } if (options.P2 is not null) { - builder.AddResult(global::ValueTypes.__SecondModelValidator__.Validate(baseName + "P2", options.P2.Value)); + (builder ??= new()).AddResult(global::ValueTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P2" : $"{name}.P2", options.P2.Value)); } - builder.AddResult(global::ValueTypes.__SecondModelValidator__.Validate(baseName + "P3", options.P3)); + (builder ??= new()).AddResult(global::ValueTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P3" : $"{name}.P3", options.P3)); if (options.P4 is not null) { - builder.AddResult(global::ValueTypes.__SecondModelValidator__.Validate(baseName + "P4", options.P4.Value)); + (builder ??= new()).AddResult(global::ValueTypes.__SecondModelValidator__.Validate(string.IsNullOrEmpty(name) ? "FirstModel.P4" : $"{name}.P4", options.P4.Value)); } - return builder.Build(); + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); } } } @@ -1868,7 +1820,7 @@ file static class __Attributes { internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); - internal static readonly global::System.ComponentModel.DataAnnotations.MinLengthAttribute A2 = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( (int)5); internal static readonly global::CustomAttr.CustomAttribute A3 = new global::CustomAttr.CustomAttribute( @@ -1881,30 +1833,30 @@ file static class __Attributes false, "X"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A5 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)0, (int)10); internal static readonly global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute A6 = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute( "\"\r\n\\\\"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A7 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A7 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (double)0.5, (double)0.9); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A8 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A8 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( typeof(global::System.DateTime), "1/2/2004", "3/4/2004"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A9 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A9 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { ErrorMessage = "ErrorMessage" }; - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A10 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A10 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { @@ -1928,15 +1880,15 @@ file static class __Attributes internal static readonly global::System.ComponentModel.DataAnnotations.DataTypeAttribute A15 = new global::System.ComponentModel.DataAnnotations.DataTypeAttribute( (global::System.ComponentModel.DataAnnotations.DataType)11); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A16 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A16 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A17 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A17 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)3, (int)5); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A18 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A18 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)5, (int)9); @@ -1964,3 +1916,150 @@ file static class __Validators internal static readonly global::RecordTypes.ThirdValidator V7 = new global::RecordTypes.ThirdValidator(); } } +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/EmitterTests.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/EmitterTests.cs index fe3e007e0464d8..af91cab872c88f 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/EmitterTests.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/EmitterTests.cs @@ -32,7 +32,7 @@ public async Task TestEmitter() } var (d, r) = await RoslynTestUtils.RunGenerator( - new Generator(), + new OptionsValidatorGenerator(), new[] { Assembly.GetAssembly(typeof(RequiredAttribute))!, diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx index c6d9021ea99672..90e5b01ed3b1b7 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx @@ -210,4 +210,16 @@ Validation attribute on the member is inaccessible from the validator type.. - \ No newline at end of file + + C# language version not supported by the source generator. + + + The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs index 870b27bbe9e100..7a39d8a8810fd3 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs @@ -4,6 +4,8 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; class Program { @@ -37,6 +39,22 @@ optionsC is null || return -1; } + LocalOptionsValidator localOptionsValidator = new LocalOptionsValidator(); + OptionsUsingValidationAttributes optionsUsingValidationAttributes = new OptionsUsingValidationAttributes + { + P1 = "12345", + P2 = new List { "1234", "12345" }, + P3 = "123456", + P4 = "12345", + P5 = 7 + }; + + ValidateOptionsResult result = localOptionsValidator.Validate("", optionsUsingValidationAttributes); + if (result.Failed) + { + return -2; + } + return 100; } @@ -76,3 +94,29 @@ private class OptionsD public string OptionString { get; set; } } } + +public class OptionsUsingValidationAttributes +{ + [Required] + [MinLength(5)] + public string P1 { get; set; } + + [Required] + [MaxLength(5)] + public List P2 { get; set; } + + [Length(2, 8)] + public string P3 { get; set; } + + [Compare("P1")] + public string P4 { get; set; } + + [Range(1, 10, MinimumIsExclusive = true, MaximumIsExclusive = true)] + public int P5 { get; set; } +} + +[OptionsValidator] +public partial class LocalOptionsValidator : IValidateOptions +{ +} + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj index 669ac862ad7b16..15b6dc0a6ea0e2 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj +++ b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj @@ -7,10 +7,15 @@ Microsoft.Extensions.DependencyInjection - + + + + <_additionalProjectReference Include="<ProjectReference Include="$(LibrariesProjectRoot)Microsoft.Extensions.Options\gen\Microsoft.Extensions.Options.SourceGeneration.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="true" SetTargetFramework="TargetFramework=netstandard2.0" />" /> + + diff --git a/src/libraries/Microsoft.Extensions.Primitives/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Primitives/src/PACKAGE.md new file mode 100644 index 00000000000000..432abfa969f057 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Primitives/src/PACKAGE.md @@ -0,0 +1,109 @@ +## About + +`Microsoft.Extensions.Primitives` contains isolated types that are used in many places within console or ASP.NET Core applications using framework extensions. + +## Key Features + +* IChangeToken: An interface that represents a token that can notify when a change occurs. This can be used to trigger actions or invalidate caches when something changes. For example, the configuration and file providers libraries use this interface to reload settings or files when they are modified. +* StringValues: A struct that represents a single string or an array of strings. This can be used to efficiently store and manipulate multiple values that are logically a single value. For example, the HTTP headers and query strings libraries use this struct to handle multiple values for the same key. +* StringSegment: A struct that represents a substring of another string. This can be used to avoid allocating new strings when performing operations on parts of a string. For example, the configuration and logging libraries use this struct to parse and format strings. + +## How to Use + +#### IChangeToken with configuration example + +```C# +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Primitives; +using System; + +class Program +{ + static void Main(string[] args) + { + // Create a configuration builder + var configurationBuilder = new ConfigurationBuilder() + .SetBasePath(Environment.CurrentDirectory) + // appsettings.json expected to have the following contents: + // { + // "SomeKey": "SomeValue" + // } + .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true); + + // Build the configuration + IConfiguration configuration = configurationBuilder.Build(); + + // Create a change token for the configuration + IChangeToken changeToken = configuration.GetReloadToken(); + + // Attach a change callback + IDisposable changeTokenRegistration = changeToken.RegisterChangeCallback(state => + { + Console.WriteLine("Configuration changed!"); + IConfigurationRoot root = (IConfigurationRoot)state; + var someValue = root["SomeKey"]; // Access the updated configuration value + Console.WriteLine($"New value of SomeKey: {someValue}"); + }, configuration); + + // go and update the value of the key SomeKey in appsettings.json. + // The change callback will be invoked when the file is saved. + Console.WriteLine("Listening for configuration changes. Press any key to exit."); + Console.ReadKey(); + + // Clean up the change token registration when no longer needed + changeTokenRegistration.Dispose(); + } +} +``` +#### StringValues example + +```C# +using System; +using Microsoft.Extensions.Primitives; + +namespace StringValuesSample +{ + class Program + { + static void Main(string[] args) + { + // Create a StringValues object from a single string or an array of strings + StringValues single = "Hello"; + StringValues multiple = new string[] { "Hello", "World" }; + + // Use the implicit conversion to string or the ToString method to get the values + Console.WriteLine($"Single: {single}"); // Single: Hello + Console.WriteLine($"Multiple: {multiple}"); // Multiple: Hello,World + + // Use the indexer, the Count property, and the IsNullOrEmpty method to access the values + Console.WriteLine($"Multiple[1]: {multiple[1]}"); // Multiple[1]: World + Console.WriteLine($"Single.Count: {single.Count}"); // Single.Count: 1 + Console.WriteLine($"Multiple.IsNullOrEmpty: {StringValues.IsNullOrEmpty(multiple)}"); // Multiple.IsNullOrEmpty: False + + // Use the Equals method or the == operator to compare two StringValues objects + Console.WriteLine($"single == \"Hello\": {single == "Hello"}"); // single == "Hello": True + Console.WriteLine($"multiple == \"Hello\": {multiple == "Hello"}"); // multiple == "Hello": False + } + } +} +``` +## Main Types + +The main types provided by this library are: + +* `IChangeToken` +* `StringValues` +* `StringSegment` + +## Additional Documentation + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/primitives) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.primitives) + +## Related Packages + +* [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration) + +## Feedback & Contributing + +Microsoft.Extensions.Primitives is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj b/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj index 611ac6d7d33814..732c102bab245b 100644 --- a/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj +++ b/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj @@ -34,5 +34,8 @@ + diff --git a/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/NewLateBinding.vb b/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/NewLateBinding.vb index de0bde5e73a2f0..5d2f624b7a6b93 100644 --- a/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/NewLateBinding.vb +++ b/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/NewLateBinding.vb @@ -47,6 +47,10 @@ Namespace Microsoft.VisualBasic.CompilerServices baseReference = New Container(Instance) End If + If baseReference.IsCOMObject AndAlso Not baseReference.IsWindowsRuntimeObject Then + Return LateBinding.InternalLateCall(Instance, Type, MemberName, Arguments, ArgumentNames, CopyBack, IgnoreReturn) + End If + Dim idmop As IDynamicMetaObjectProvider = IDOUtils.TryCastToIDMOP(Instance) If idmop IsNot Nothing AndAlso TypeArguments Is NoTypeArguments Then Return IDOBinder.IDOCall(idmop, MemberName, Arguments, ArgumentNames, CopyBack, IgnoreReturn) @@ -139,7 +143,7 @@ Namespace Microsoft.VisualBasic.CompilerServices ' LateCallInvokeDefault is used to optionally invoke the default action on a call target. ' If the arguments are non-empty, then it isn't optional, and is treated ' as an error if there is no default action. - ' Currently we can get here only in the process of execution of NewLateBinding.LateCall. + ' Currently we can get here only in the process of execution of NewLateBinding.LateCall. @@ -155,7 +159,7 @@ Namespace Microsoft.VisualBasic.CompilerServices ' LateGetInvokeDefault is used to optionally invoke the default action. ' If the arguments are non-empty, then it isn't optional, and is treated ' as an error if there is no default action. - ' Currently we can get here only in the process of execution of NewLateBinding.LateGet. + ' Currently we can get here only in the process of execution of NewLateBinding.LateGet. @@ -167,7 +171,7 @@ Namespace Microsoft.VisualBasic.CompilerServices ' According to a comment in VBGetBinder.FallbackInvoke, this function is called when ' "The DLR was able to resolve o.member, but not o.member(args)" - ' When NewLateBinding.LateGet is evaluating similar expression itself, it never tries to invoke default action + ' When NewLateBinding.LateGet is evaluating similar expression itself, it never tries to invoke default action ' if arguments are not empty. It simply returns result of evaluating o.member. I believe, it makes sense ' to follow the same logic here. I.e., if there are no arguments, simply return the instance unless it is an IDO. @@ -278,6 +282,9 @@ Namespace Microsoft.VisualBasic.CompilerServices If argumentNames Is Nothing Then argumentNames = NoArgumentNames Dim baseReference As Container = New Container(instance) + If baseReference.IsCOMObject AndAlso Not baseReference.IsWindowsRuntimeObject Then + Return LateBinding.LateIndexGet(instance, arguments, argumentNames) + End If 'An r-value expression o(a) has two possible forms: ' 1: o(a) array lookup--where o is an array object and a is a set of indices @@ -372,6 +379,10 @@ Namespace Microsoft.VisualBasic.CompilerServices baseReference = New Container(Instance) End If + If baseReference.IsCOMObject AndAlso Not baseReference.IsWindowsRuntimeObject Then + Return LateBinding.LateGet(Instance, Type, MemberName, Arguments, ArgumentNames, CopyBack) + End If + Dim invocationFlags As BindingFlags = BindingFlagsInvokeMethod Or BindingFlagsGetProperty Dim idmop As IDynamicMetaObjectProvider = IDOUtils.TryCastToIDMOP(Instance) @@ -653,6 +664,10 @@ Namespace Microsoft.VisualBasic.CompilerServices End If Dim methodName As String = "" + If baseReference.IsCOMObject AndAlso Not baseReference.IsWindowsRuntimeObject Then + LateBinding.LateIndexSetComplex(instance, arguments, argumentNames, optimisticSet, rValueBase) + Return + End If Dim invocationFlags As BindingFlags = BindingFlagsSetProperty @@ -927,6 +942,18 @@ Namespace Microsoft.VisualBasic.CompilerServices baseReference = New Container(Instance) End If + If baseReference.IsCOMObject AndAlso Not baseReference.IsWindowsRuntimeObject Then + Try + LateBinding.InternalLateSet(Instance, Type, MemberName, Arguments, ArgumentNames, OptimisticSet, CallType) + If RValueBase And Type.IsValueType Then + Throw New Exception(Utils.GetResourceString(SR.RValueBaseForValueType, baseReference.VBFriendlyName, baseReference.VBFriendlyName)) + End If + Return + Catch ex As MissingMemberException When OptimisticSet + Return + End Try + End If + Dim invocationFlags As BindingFlags ' If we have a IDO that implements TryGetMember for a property but not TrySetMember then we could land up diff --git a/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/Symbols.vb b/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/Symbols.vb index 8301ee324af3fc..44c694595b35bb 100644 --- a/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/Symbols.vb +++ b/src/libraries/Microsoft.VisualBasic.Core/src/Microsoft/VisualBasic/CompilerServices/Symbols.vb @@ -823,6 +823,12 @@ Namespace Microsoft.VisualBasic.CompilerServices End Get End Property + Friend ReadOnly Property IsCOMObject() As Boolean + Get + Return _type.IsCOMObject + End Get + End Property + Friend ReadOnly Property VBFriendlyName() As String Get Return Utils.VBFriendlyName(_type, _instance) diff --git a/src/libraries/System.Collections.Immutable/src/PACKAGE.md b/src/libraries/System.Collections.Immutable/src/PACKAGE.md index cf4995b4cad9ca..0ca0b161aa448d 100644 --- a/src/libraries/System.Collections.Immutable/src/PACKAGE.md +++ b/src/libraries/System.Collections.Immutable/src/PACKAGE.md @@ -1,10 +1,72 @@ ## About + + This package provides collections that are thread safe and guaranteed to never change their contents, also known as immutable collections. Like strings, any methods that perform modifications will not change the existing instance but instead return a new instance. For efficiency reasons, the implementation uses a sharing mechanism to ensure that newly created instances share as much data as possible with the previous instance while ensuring that operations have a predictable time complexity. The `System.Collections.Immutable` library is built-in as part of the shared framework in .NET Runtime. The package can be installed when you need to use it in other target frameworks. -For more information, see the documentation: +## How to Use + + + +```C# +using System.Collections.Immutable; + +// Create immutable set of strings +ImmutableHashSet colors = ImmutableHashSet.Create("Red", "Green", "Blue"); + +// Create a new set by adding and removing items from the original set +ImmutableHashSet colorsModified = colors.Remove("Red").Add("Orange"); + +foreach (string s in colorsModified) +{ + Console.WriteLine(s); +} + +/* Example output: + Blue + Green + Orange + */ + ``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Collections.Immutable.ImmutableArray` +* `System.Collections.Immutable.ImmutableArray` +* `System.Collections.Immutable.ImmutableDictionary` +* `System.Collections.Immutable.ImmutableDictionary` +* `System.Collections.Immutable.ImmutableHashSet` +* `System.Collections.Immutable.ImmutableHashSet` +* `System.Collections.Immutable.ImmutableList` +* `System.Collections.Immutable.ImmutableList` +* `System.Collections.Immutable.ImmutableQueue` +* `System.Collections.Immutable.ImmutableQueue` +* `System.Collections.Immutable.ImmutableSortedDictionary` +* `System.Collections.Immutable.ImmutableSortedDictionary` +* `System.Collections.Immutable.ImmutableSortedSet` +* `System.Collections.Immutable.ImmutableSortedSet` +* `System.Collections.Immutable.ImmutableStack` +* `System.Collections.Immutable.ImmutableStack` +* `System.Collections.Frozen.FrozenDictionary` +* `System.Collections.Frozen.FrozenDictionary` +* `System.Collections.Frozen.FrozenSet` +* `System.Collections.Frozen.FrozenSet` + +## Additional Documentation + + - [Collections and Data Structures](https://docs.microsoft.com/dotnet/standard/collections/) -- [System.Collections.Immutable API reference](https://docs.microsoft.com/dotnet/api/system.collections.immutable) +- [API documentation](https://docs.microsoft.com/dotnet/api/system.collections.immutable) + +## Feedback & Contributing + + + +System.Collections.Immutable is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenDictionary.cs b/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenDictionary.cs index 0fab5139dc6351..e4fbdcef00b3c6 100644 --- a/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenDictionary.cs +++ b/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenDictionary.cs @@ -41,7 +41,7 @@ public static FrozenDictionary ToFrozenDictionary(th public static FrozenDictionary ToFrozenDictionary( this IEnumerable source, Func keySelector, IEqualityComparer? comparer = null) where TKey : notnull => - CreateFromDictionary(source.ToDictionary(keySelector, comparer)); + source.ToDictionary(keySelector, comparer).ToFrozenDictionary(comparer); /// Creates a from an according to specified key selector and element selector functions. /// The type of the elements of . @@ -55,7 +55,7 @@ public static FrozenDictionary ToFrozenDictionary( public static FrozenDictionary ToFrozenDictionary( this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer? comparer = null) where TKey : notnull => - CreateFromDictionary(source.ToDictionary(keySelector, elementSelector, comparer)); + source.ToDictionary(keySelector, elementSelector, comparer).ToFrozenDictionary(comparer); /// /// Extracts from the source either an existing instance or a @@ -113,6 +113,8 @@ public static FrozenDictionary ToFrozenDictionary CreateFromDictionary(Dictionary source) where TKey : notnull { + Debug.Assert(source.Count > 0, "Empty sources should have been filtered out by caller"); + IEqualityComparer comparer = source.Comparer; // Optimize for value types when the default comparer is being used. In such a case, the implementation diff --git a/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenSet.cs b/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenSet.cs index ed472ff80666ae..8c315f214fe03c 100644 --- a/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenSet.cs +++ b/src/libraries/System.Collections.Immutable/src/System/Collections/Frozen/FrozenSet.cs @@ -59,6 +59,8 @@ public static FrozenSet ToFrozenSet(this IEnumerable source, IEqualityC private static FrozenSet CreateFromSet(HashSet source) { + Debug.Assert(source.Count > 0, "Empty sources should have been filtered out by caller"); + IEqualityComparer comparer = source.Comparer; // Optimize for value types when the default comparer is being used. In such a case, the implementation diff --git a/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenDictionaryTests.cs b/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenDictionaryTests.cs index a1f6adc0c782bd..13e695ef0a791a 100644 --- a/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenDictionaryTests.cs +++ b/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenDictionaryTests.cs @@ -86,23 +86,35 @@ public void NullSource_ThrowsException() [Fact] public void EmptySource_ProducedFrozenDictionaryEmpty() { - Assert.Same(FrozenDictionary.Empty, new Dictionary().ToFrozenDictionary()); - Assert.Same(FrozenDictionary.Empty, Enumerable.Empty>().ToFrozenDictionary()); - Assert.Same(FrozenDictionary.Empty, Array.Empty>().ToFrozenDictionary()); - Assert.Same(FrozenDictionary.Empty, new List>().ToFrozenDictionary()); + IEnumerable>[] sources = new[] + { + new Dictionary(), + Enumerable.Empty>(), + Array.Empty>(), + new List>() + }; - foreach (IEqualityComparer comparer in new IEqualityComparer[] { null, EqualityComparer.Default }) + foreach (IEnumerable> source in sources) { - Assert.Same(FrozenDictionary.Empty, new Dictionary().ToFrozenDictionary(comparer)); - Assert.Same(FrozenDictionary.Empty, Enumerable.Empty>().ToFrozenDictionary(comparer)); - Assert.Same(FrozenDictionary.Empty, Array.Empty>().ToFrozenDictionary(comparer)); - Assert.Same(FrozenDictionary.Empty, new List>().ToFrozenDictionary(comparer)); - } + Assert.Same(FrozenDictionary.Empty, source.ToFrozenDictionary()); + Assert.Same(FrozenDictionary>.Empty, source.ToFrozenDictionary(s => s.Key)); + Assert.Same(FrozenDictionary.Empty, source.ToFrozenDictionary(s => s.Key, s => s.Value)); - Assert.NotSame(FrozenDictionary.Empty, new Dictionary().ToFrozenDictionary(NonDefaultEqualityComparer.Instance)); - Assert.NotSame(FrozenDictionary.Empty, Enumerable.Empty>().ToFrozenDictionary(NonDefaultEqualityComparer.Instance)); - Assert.NotSame(FrozenDictionary.Empty, Array.Empty>().ToFrozenDictionary(NonDefaultEqualityComparer.Instance)); - Assert.NotSame(FrozenDictionary.Empty, new List>().ToFrozenDictionary(NonDefaultEqualityComparer.Instance)); + foreach (IEqualityComparer comparer in new IEqualityComparer[] { null, EqualityComparer.Default }) + { + Assert.Same(FrozenDictionary.Empty, source.ToFrozenDictionary(comparer)); + Assert.Same(FrozenDictionary>.Empty, source.ToFrozenDictionary(s => s.Key, comparer)); + Assert.Same(FrozenDictionary.Empty, source.ToFrozenDictionary(s => s.Key, s => s.Value, comparer)); + } + + Assert.NotSame(FrozenDictionary.Empty, source.ToFrozenDictionary(NonDefaultEqualityComparer.Instance)); + Assert.NotSame(FrozenDictionary>.Empty, source.ToFrozenDictionary(s => s.Key, NonDefaultEqualityComparer.Instance)); + Assert.NotSame(FrozenDictionary.Empty, source.ToFrozenDictionary(s => s.Key, s => s.Value, NonDefaultEqualityComparer.Instance)); + + Assert.Equal(0, source.ToFrozenDictionary(NonDefaultEqualityComparer.Instance).Count); + Assert.Equal(0, source.ToFrozenDictionary(s => s.Key, NonDefaultEqualityComparer.Instance).Count); + Assert.Equal(0, source.ToFrozenDictionary(s => s.Key, s => s.Value, NonDefaultEqualityComparer.Instance).Count); + } } [Fact] diff --git a/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenSetTests.cs b/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenSetTests.cs index f9bca34ae07060..8465f66c4c3747 100644 --- a/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenSetTests.cs +++ b/src/libraries/System.Collections.Immutable/tests/Frozen/FrozenSetTests.cs @@ -63,23 +63,24 @@ public void NullSource_ThrowsException() [Fact] public void EmptySource_ProducedFrozenSetEmpty() { - Assert.Same(FrozenSet.Empty, new List().ToFrozenSet()); - Assert.Same(FrozenSet.Empty, Enumerable.Empty().ToFrozenSet()); - Assert.Same(FrozenSet.Empty, Array.Empty().ToFrozenSet()); - Assert.Same(FrozenSet.Empty, new List().ToFrozenSet()); + IEnumerable[] sources = new[] + { + new List(), + Enumerable.Empty(), + Array.Empty(), + }; - foreach (IEqualityComparer comparer in new IEqualityComparer[] { null, EqualityComparer.Default }) + foreach (IEnumerable source in sources) { - Assert.Same(FrozenSet.Empty, new List().ToFrozenSet(comparer)); - Assert.Same(FrozenSet.Empty, Enumerable.Empty().ToFrozenSet(comparer)); - Assert.Same(FrozenSet.Empty, Array.Empty().ToFrozenSet(comparer)); - Assert.Same(FrozenSet.Empty, new List().ToFrozenSet(comparer)); - } + Assert.Same(FrozenSet.Empty, source.ToFrozenSet()); + + foreach (IEqualityComparer comparer in new IEqualityComparer[] { null, EqualityComparer.Default }) + { + Assert.Same(FrozenSet.Empty, source.ToFrozenSet(comparer)); + } - Assert.NotSame(FrozenSet.Empty, new List().ToFrozenSet(NonDefaultEqualityComparer.Instance)); - Assert.NotSame(FrozenSet.Empty, Enumerable.Empty().ToFrozenSet(NonDefaultEqualityComparer.Instance)); - Assert.NotSame(FrozenSet.Empty, Array.Empty().ToFrozenSet(NonDefaultEqualityComparer.Instance)); - Assert.NotSame(FrozenSet.Empty, new List().ToFrozenSet(NonDefaultEqualityComparer.Instance)); + Assert.NotSame(FrozenSet.Empty, source.ToFrozenSet(NonDefaultEqualityComparer.Instance)); + } } [Fact] diff --git a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs index c7de89b28629b5..0bd6a70a8f2a57 100644 --- a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs @@ -287,7 +287,7 @@ void trimAndEnsureCapacity() private static int GetUnderlyingBufferCapacity(PriorityQueue queue) { - FieldInfo nodesField = queue.GetType().GetField("_nodes", BindingFlags.NonPublic | BindingFlags.Instance); + FieldInfo nodesField = typeof(PriorityQueue).GetField("_nodes", BindingFlags.NonPublic | BindingFlags.Instance); Assert.NotNull(nodesField); var nodes = ((TElement Element, TPriority Priority)[])nodesField.GetValue(queue); return nodes.Length; diff --git a/src/libraries/System.ComponentModel.Annotations/ref/System.ComponentModel.Annotations.cs b/src/libraries/System.ComponentModel.Annotations/ref/System.ComponentModel.Annotations.cs index 542845070d0c38..c9e82b8cece0c7 100644 --- a/src/libraries/System.ComponentModel.Annotations/ref/System.ComponentModel.Annotations.cs +++ b/src/libraries/System.ComponentModel.Annotations/ref/System.ComponentModel.Annotations.cs @@ -387,14 +387,14 @@ public static partial class Validator public static bool TryValidateObject(object instance, System.ComponentModel.DataAnnotations.ValidationContext validationContext, System.Collections.Generic.ICollection? validationResults, bool validateAllProperties) { throw null; } [System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute("The Type of validationContext.ObjectType cannot be statically discovered.")] public static bool TryValidateProperty(object? value, System.ComponentModel.DataAnnotations.ValidationContext validationContext, System.Collections.Generic.ICollection? validationResults) { throw null; } - public static bool TryValidateValue(object value, System.ComponentModel.DataAnnotations.ValidationContext validationContext, System.Collections.Generic.ICollection? validationResults, System.Collections.Generic.IEnumerable validationAttributes) { throw null; } + public static bool TryValidateValue(object? value, System.ComponentModel.DataAnnotations.ValidationContext validationContext, System.Collections.Generic.ICollection? validationResults, System.Collections.Generic.IEnumerable validationAttributes) { throw null; } [System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute("The Type of instance cannot be statically discovered and the Type's properties can be trimmed.")] public static void ValidateObject(object instance, System.ComponentModel.DataAnnotations.ValidationContext validationContext) { } [System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute("The Type of instance cannot be statically discovered and the Type's properties can be trimmed.")] public static void ValidateObject(object instance, System.ComponentModel.DataAnnotations.ValidationContext validationContext, bool validateAllProperties) { } [System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute("The Type of validationContext.ObjectType cannot be statically discovered.")] public static void ValidateProperty(object? value, System.ComponentModel.DataAnnotations.ValidationContext validationContext) { } - public static void ValidateValue(object value, System.ComponentModel.DataAnnotations.ValidationContext validationContext, System.Collections.Generic.IEnumerable validationAttributes) { } + public static void ValidateValue(object? value, System.ComponentModel.DataAnnotations.ValidationContext validationContext, System.Collections.Generic.IEnumerable validationAttributes) { } } } namespace System.ComponentModel.DataAnnotations.Schema diff --git a/src/libraries/System.ComponentModel.Annotations/src/System/ComponentModel/DataAnnotations/Validator.cs b/src/libraries/System.ComponentModel.Annotations/src/System/ComponentModel/DataAnnotations/Validator.cs index 3913a1df4c3fe7..d6719673011117 100644 --- a/src/libraries/System.ComponentModel.Annotations/src/System/ComponentModel/DataAnnotations/Validator.cs +++ b/src/libraries/System.ComponentModel.Annotations/src/System/ComponentModel/DataAnnotations/Validator.cs @@ -171,7 +171,7 @@ public static bool TryValidateObject(object instance, ValidationContext validati /// then all validators will be evaluated. /// /// - /// The value to test. It cannot be null. + /// The value to test. /// /// Describes the object being validated and provides services and context for the /// validators. @@ -182,7 +182,7 @@ public static bool TryValidateObject(object instance, ValidationContext validati /// against. /// /// true if the object is valid, false if any validation errors are encountered. - public static bool TryValidateValue(object value, ValidationContext validationContext, + public static bool TryValidateValue(object? value, ValidationContext validationContext, ICollection? validationResults, IEnumerable validationAttributes) { ArgumentNullException.ThrowIfNull(validationAttributes); @@ -303,12 +303,12 @@ public static void ValidateObject(object instance, ValidationContext validationC /// first. /// /// - /// The value to test. It cannot be null. + /// The value to test. /// Describes the object being tested. /// The list of s to validate against this instance. /// When is null. /// When is found to be invalid. - public static void ValidateValue(object value, ValidationContext validationContext, + public static void ValidateValue(object? value, ValidationContext validationContext, IEnumerable validationAttributes) { ArgumentNullException.ThrowIfNull(validationContext); diff --git a/src/libraries/System.ComponentModel.TypeConverter/src/System/ComponentModel/TypeDescriptor.cs b/src/libraries/System.ComponentModel.TypeConverter/src/System/ComponentModel/TypeDescriptor.cs index 874fa6e2e7d03e..2efd53e6223ae7 100644 --- a/src/libraries/System.ComponentModel.TypeConverter/src/System/ComponentModel/TypeDescriptor.cs +++ b/src/libraries/System.ComponentModel.TypeConverter/src/System/ComponentModel/TypeDescriptor.cs @@ -1542,7 +1542,7 @@ private static TypeDescriptionNode NodeFor(object instance, bool createDelegator { type = ComObjectType; } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + else if (OperatingSystem.IsWindows() && ComWrappers.TryGetComInstance(instance, out nint unknown)) { // ComObjectType uses the Windows Forms provided ComNativeDescriptor. It currently has hard Win32 diff --git a/src/libraries/System.Configuration.ConfigurationManager/src/PACKAGE.md b/src/libraries/System.Configuration.ConfigurationManager/src/PACKAGE.md index 8448be2423b11b..992b7669426993 100644 --- a/src/libraries/System.Configuration.ConfigurationManager/src/PACKAGE.md +++ b/src/libraries/System.Configuration.ConfigurationManager/src/PACKAGE.md @@ -1,15 +1,12 @@ ## About -Provides types that support using XML configuration files (`app.config`). This package exists only to support migrating existing .NET Framework code that already uses System.Configuration. When writing new code, use another configuration system instead, such as [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). + -For more information, see the documentation: +Provides types that support using XML configuration files (`app.config`). This package exists only to support migrating existing .NET Framework code that already uses System.Configuration. When writing new code, use another configuration system instead, such as [Microsoft.Extensions.Configuration](https://www.nuget.org/packages/Microsoft.Extensions.Configuration/). -- [Configure apps by using configuration files](https://docs.microsoft.com/dotnet/framework/configure-apps/) -- [System.Configuration namespace](https://docs.microsoft.com/dotnet/api/system.configuration) -- [System.Configuration.Configuration](https://docs.microsoft.com/dotnet/api/system.configuration.configuration) -- [System.Configuration.ConfigurationManager](https://docs.microsoft.com/dotnet/api/system.configuration.configurationmanager) +## How to Use -## Example + The following example shows how to read and modify the application configuration settings. @@ -63,3 +60,27 @@ To run this example, include an `app.config` file with the following content in ``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Configuration.Configuration` +* `System.Configuration.ConfigurationManager` + +## Additional Documentation + + + +* [Configure apps by using configuration files](https://docs.microsoft.com/dotnet/framework/configure-apps/) +* [System.Configuration namespace](https://docs.microsoft.com/dotnet/api/system.configuration) +* [System.Configuration.Configuration](https://docs.microsoft.com/dotnet/api/system.configuration.configuration) +* [System.Configuration.ConfigurationManager](https://docs.microsoft.com/dotnet/api/system.configuration.configurationmanager) + +## Feedback & Contributing + + + +System.Configuration.ConfigurationManager is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Data.Odbc/src/PACKAGE.md b/src/libraries/System.Data.Odbc/src/PACKAGE.md new file mode 100644 index 00000000000000..a45dd52364c18e --- /dev/null +++ b/src/libraries/System.Data.Odbc/src/PACKAGE.md @@ -0,0 +1,48 @@ +## About + +This package implements a data provider for ODBC data sources. + +## Key Features + +Allows access to ODBC data sources. + +## How to Use + +This is a basic example of retrieving the results of a query using an [OdbcDataReader](https://learn.microsoft.com/dotnet/api/system.data.odbc.odbcdatareader). For examples of using an [OdbcDataAdapter](https://learn.microsoft.com/dotnet/api/system.data.odbc.odbcdataadapter), and of updating an ODBC data source, please see the documentation. + +```cs +using System.Data.Odbc; + +string connectionString = ""; +string queryString = "SELECT DISTINCT CustomerID FROM Orders"; + +using OdbcConnection connection = new OdbcConnection(connectionString); +using OdbcCommand command = new OdbcCommand(queryString, connection); + +connection.Open(); +using OdbcDataReader reader = command.ExecuteReader(); + +while (reader.Read()) +{ + Console.WriteLine("CustomerID={0}", reader[0]); +} +``` + +## Main Types + +* [OdbcConnection](https://learn.microsoft.com/dotnet/api/system.data.odbc.odbcconnection) represents a connection to an ODBC data source. +* [OdbcCommand](https://learn.microsoft.com/dotnet/api/system.data.odbc.odbccommand) represents an SQL statement or stored procedure to execute against an ODBC data source.. +* [OdbcDataReader](https://learn.microsoft.com/dotnet/api/system.data.odbc.odbcdatareader) provides a way of reading a forward-only stream of data rows from an ODBC data source. +* [OdbcDataAdapter](https://learn.microsoft.com/dotnet/api/system.data.odbc.odbcdataadapter) represents a set of data commands and a database connection that are used to fill a [DataSet](https://learn.microsoft.com/dotnet/api/system.data.dataset) and update the ODBC data source. + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.data.odbc) + +## Related Packages + +System.Data.OleDb is a similar package for accessing OLE DB data sources. + +## Feedback & Contributing + +System.Data.Odbc is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports are welcome at [the GitHub repository](https://github.com/dotnet/runtime). This package is considered complete and we only consider low-risk, high-impact fixes that are necessary to maintain or improve quality. \ No newline at end of file diff --git a/src/libraries/System.Data.OleDb/src/PACKAGE.md b/src/libraries/System.Data.OleDb/src/PACKAGE.md new file mode 100644 index 00000000000000..6e81daa638b2eb --- /dev/null +++ b/src/libraries/System.Data.OleDb/src/PACKAGE.md @@ -0,0 +1,48 @@ +## About + +This package implements a data provider for OLE DB data sources. + +## Key Features + +Allows access to legacy OLE DB data sources. + +## How to Use + +This is a basic example of retrieving the results of a query using an [OleDbDataReader](https://learn.microsoft.com/dotnet/api/system.data.oledb.oledbdatareader). For examples of using an [OleDbDataAdapter](https://learn.microsoft.com/dotnet/api/system.data.oledb.oledbdataadapter), and of updating an OLE DB data source, please see the documentation. + +```cs +using System.Data.OleDb; + +string connectionString = ""; +string queryString = "SELECT OrderID, CustomerID FROM Orders"; + +using OleDbConnection connection = new OleDbConnection(connectionString); +using OleDbCommand command = new OleDbCommand(queryString, connection); + +connection.Open(); +using OleDbDataReader reader = command.ExecuteReader(); + +while (reader.Read()) +{ + Console.WriteLine(reader.GetInt32(0) + ", " + reader.GetString(1)); +} +``` + +## Main Types + +* [OleDbConnection](https://learn.microsoft.com/dotnet/api/system.data.oledb.oledbconnection) represents an open connection to an OLE DB data source. +* [OleDbCommand](https://learn.microsoft.com/dotnet/api/system.data.oledb.oledbcommand) represents an SQL statement or stored procedure to execute against an OLE DB data source. +* [OleDbDataReader](https://learn.microsoft.com/dotnet/api/system.data.oledb.oledbdatareader) provides a way of reading a forward-only stream of data rows from an OLE DB data source. +* [OleDbDataAdapter](https://learn.microsoft.com/dotnet/api/system.data.oledb.oledbdataadapter) represents a set of data commands and a database connection that are used to fill a `DataSet` and update the OLE DB data source. + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.data.oledb) + +## Related Packages + +System.Data.Odbc is a similar package for accessing ODBC data sources. + +## Feedback & Contributing + +System.Data.OleDb is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports are welcome at [the GitHub repository](https://github.com/dotnet/runtime). This package is considered complete and we only consider low-risk, high-impact fixes that are necessary to maintain or improve quality. \ No newline at end of file diff --git a/src/libraries/System.Diagnostics.DiagnosticSource/src/ILLink/ILLink.Substitutions.Shared.xml b/src/libraries/System.Diagnostics.DiagnosticSource/src/ILLink/ILLink.Substitutions.Shared.xml new file mode 100644 index 00000000000000..b67ac8623c402a --- /dev/null +++ b/src/libraries/System.Diagnostics.DiagnosticSource/src/ILLink/ILLink.Substitutions.Shared.xml @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/src/libraries/System.Diagnostics.DiagnosticSource/src/System.Diagnostics.DiagnosticSource.csproj b/src/libraries/System.Diagnostics.DiagnosticSource/src/System.Diagnostics.DiagnosticSource.csproj index 24a23478b7c0c2..49b7834a9de36c 100644 --- a/src/libraries/System.Diagnostics.DiagnosticSource/src/System.Diagnostics.DiagnosticSource.csproj +++ b/src/libraries/System.Diagnostics.DiagnosticSource/src/System.Diagnostics.DiagnosticSource.csproj @@ -20,6 +20,10 @@ System.Diagnostics.DiagnosticSource true + + + + diff --git a/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Instrument.cs b/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Instrument.cs index 0cb50f523f0282..c1e2fcb199b3a5 100644 --- a/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Instrument.cs +++ b/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Instrument.cs @@ -64,6 +64,12 @@ protected Instrument(Meter meter, string name, string? unit, string? description /// protected void Publish() { + // All instruments call Publish when they are created. We don't want to publish the instrument if the Meter is not supported. + if (!Meter.IsSupported) + { + return; + } + List? allListeners = null; lock (Instrument.SyncObject) { diff --git a/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Meter.cs b/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Meter.cs index a5722aa7b6c4dc..60314b1d5b4c00 100644 --- a/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Meter.cs +++ b/src/libraries/System.Diagnostics.DiagnosticSource/src/System/Diagnostics/Metrics/Meter.cs @@ -17,6 +17,11 @@ public class Meter : IDisposable private Dictionary> _nonObservableInstrumentsCache = new(); internal bool Disposed { get; private set; } + internal static bool IsSupported { get; } = InitializeIsSupported(); + + private static bool InitializeIsSupported() => + AppContext.TryGetSwitch("System.Diagnostics.Metrics.Meter.IsSupported", out bool isSupported) ? isSupported : true; + /// /// Initialize a new instance of the Meter using the . /// @@ -77,6 +82,11 @@ private void Initialize(string name, string? version, IEnumerable - /// A delegate to represent the Meterlistener callbacks used in measurements recording operation. + /// A delegate to represent the MeterListener callbacks used in measurements recording operation. /// public delegate void MeasurementCallback(Instrument instrument, T measurement, ReadOnlySpan> tags, object? state) where T : struct; @@ -56,6 +56,11 @@ public MeterListener() { } /// A state object which will be passed back to the callback getting measurements events. public void EnableMeasurementEvents(Instrument instrument, object? state = null) { + if (!Meter.IsSupported) + { + return; + } + bool oldStateStored = false; bool enabled = false; object? oldState = null; @@ -92,6 +97,11 @@ public void EnableMeasurementEvents(Instrument instrument, object? state = null) /// The state object originally passed to method. public object? DisableMeasurementEvents(Instrument instrument) { + if (!Meter.IsSupported) + { + return default; + } + object? state = null; lock (Instrument.SyncObject) { @@ -114,6 +124,11 @@ public void EnableMeasurementEvents(Instrument instrument, object? state = null) /// The callback which can be used to get measurement recording of numeric type T. public void SetMeasurementEventCallback(MeasurementCallback? measurementCallback) where T : struct { + if (!Meter.IsSupported) + { + return; + } + measurementCallback ??= (instrument, measurement, tags, state) => { /* no-op */}; if (typeof(T) == typeof(byte)) @@ -155,6 +170,11 @@ public void SetMeasurementEventCallback(MeasurementCallback? measurementCa /// public void Start() { + if (!Meter.IsSupported) + { + return; + } + List? publishedInstruments = null; lock (Instrument.SyncObject) { @@ -184,6 +204,11 @@ public void Start() /// public void RecordObservableInstruments() { + if (!Meter.IsSupported) + { + return; + } + List? exceptionsList = null; DiagNode? current = _enabledMeasurementInstruments.First; while (current is not null) @@ -215,6 +240,11 @@ public void RecordObservableInstruments() /// public void Dispose() { + if (!Meter.IsSupported) + { + return; + } + Dictionary? callbacksArguments = null; Action? measurementsCompleted = MeasurementsCompleted; diff --git a/src/libraries/System.Diagnostics.DiagnosticSource/tests/System.Diagnostics.DiagnosticSource.Tests.csproj b/src/libraries/System.Diagnostics.DiagnosticSource/tests/System.Diagnostics.DiagnosticSource.Tests.csproj index 7e8136e0a1a6b9..29a8e55dc2fc02 100644 --- a/src/libraries/System.Diagnostics.DiagnosticSource/tests/System.Diagnostics.DiagnosticSource.Tests.csproj +++ b/src/libraries/System.Diagnostics.DiagnosticSource/tests/System.Diagnostics.DiagnosticSource.Tests.csproj @@ -11,6 +11,7 @@ + diff --git a/src/libraries/System.Diagnostics.DiagnosticSource/tests/TestNotSupported.cs b/src/libraries/System.Diagnostics.DiagnosticSource/tests/TestNotSupported.cs new file mode 100644 index 00000000000000..43fae749957204 --- /dev/null +++ b/src/libraries/System.Diagnostics.DiagnosticSource/tests/TestNotSupported.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.DotNet.RemoteExecutor; +using System.Diagnostics.Metrics; +using Xunit; + +namespace System.Diagnostics.Metrics.Tests +{ + public class MetricsNotSupportedTest + { + /// + /// Tests using Metrics when the System.Diagnostics.Metrics.Meter.IsSupported + /// feature switch is set to disable all metrics operations. + /// + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(false)] + [InlineData(true)] + public void IsSupportedSwitch(bool value) + { + RemoteInvokeOptions options = new RemoteInvokeOptions(); + options.RuntimeConfigurationOptions.Add("System.Diagnostics.Metrics.Meter.IsSupported", value); + + RemoteExecutor.Invoke((val) => + { + bool isSupported = bool.Parse(val); + + Meter meter = new Meter("IsSupportedTest"); + Counter counter = meter.CreateCounter("counter"); + bool instrumentsPublished = false; + bool instrumentCompleted = false; + long counterValue = 100; + + using (MeterListener listener = new MeterListener + { + InstrumentPublished = (instruments, theListener) => instrumentsPublished = true, + MeasurementsCompleted = (instruments, state) => instrumentCompleted = true + }) + { + listener.EnableMeasurementEvents(counter, null); + listener.SetMeasurementEventCallback((inst, measurement, tags, state) => counterValue = measurement); + listener.Start(); + + Assert.Equal(isSupported, counter.Enabled); + + counter.Add(20); + } + meter.Dispose(); + + Assert.Equal(isSupported, instrumentsPublished); + Assert.Equal(isSupported, instrumentCompleted); + Assert.Equal(isSupported ? 20 : 100, counterValue); + }, value.ToString(), options).Dispose(); + } + } +} diff --git a/src/libraries/System.Diagnostics.EventLog/src/PACKAGE.md b/src/libraries/System.Diagnostics.EventLog/src/PACKAGE.md new file mode 100644 index 00000000000000..56e111a68ac802 --- /dev/null +++ b/src/libraries/System.Diagnostics.EventLog/src/PACKAGE.md @@ -0,0 +1,87 @@ +## About + + + +This package provides types that allow applications to interact with the Windows Event Log service. + +When an error occurs in a Windows machine, the system administrator or support representative must determine what caused the error, attempt to recover any lost data, and prevent the error from recurring. It is helpful if applications, the operating system, and other system services record important events, such as low-memory conditions or excessive attempts to access a disk. The system administrator can then use the Windows Event Log to help determine what conditions caused the error and identify the context in which it occurred. + +## Key Features + + + +* Allows reading from existing logs. +* Allows writing entries to logs. +* Can create or delete event sources. +* Can delete logs. +* Can respond to log entries. +* Can create new logs when creating an event source. + +## How to Use + + + +```cs +if(!EventLog.SourceExists("MySource")) +{ + // An event log source should not be created and immediately used. + // There is a latency time to enable the source, it should be created + // prior to executing the application that uses the source. + // Execute this sample a second time to use the new source. + EventLog.CreateEventSource("MySource", "MyNewLog"); + Console.WriteLine("Event source created. Exiting, execute the application a second time to use the source."); + // The source is created. Exit the application to allow it to be registered. + return; +} + +EventLog myLog = new(); +myLog.Source = "MySource"; +myLog.WriteEntry("Writing an informational entry to the event log."); +``` + +Notes: + +- This assembly is only supported on Windows operating systems. +- Starting with Windows Vista, you must run the application as an administrator to interact with the Windows Event Log service using the `System.Diagnostics.EventLog` class. + +## Main Types + + + +The main types provided by this library are: + +Under the [`System.Diagnostics`](https://learn.microsoft.com/dotnet/api/System.Diagnostics) namespace, the main types are: + +- [`System.Diagnostics.EventLog`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.EventLog) +- [`System.Diagnostics.EventLogEntry`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.EventLogEntry) +- [`System.Diagnostics.EventLogEntryCollection`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.EventLogEntryCollection) +- [`System.Diagnostics.EventLogEntryType`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.EventLogEntryType) + +Under the [`System.Diagnostics.Eventing.Reader`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.Eventing.Reader) namespace, the main types are: + +- [`System.Diagnostics.Eventing.Reader.EventLogQuery`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.Eventing.Reader.EventLogQuery) +- [`System.Diagnostics.Eventing.Reader.EventLogReader`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.Eventing.Reader.EventLogReader) +- [`System.Diagnostics.Eventing.Reader.EventLogRecord`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.Eventing.Reader.EventLogRecord) +- [`System.Diagnostics.Eventing.Reader.EventLogSession`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.Eventing.Reader.EventLogSession) +- [`System.Diagnostics.Eventing.Reader.EventLogType`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.Eventing.Reader.EventLogType) +- [`System.Diagnostics.Eventing.Reader.EventRecord`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.Eventing.Reader.EventRecord) + +## Additional Documentation + + + +- [Microsoft Learn - System.Diagnostics.EventLog API reference](https://learn.microsoft.com/dotnet/api/System.Diagnostics.EventLog) +- [Windows App Development - Event logging](https://learn.microsoft.com/windows/win32/eventlog/event-logging) +- [GitHub - Source code](https://github.com/dotnet/runtime/tree/main/src/libraries/System.Diagnostics.EventLog) + +## Related Packages + + + +- [System.Diagnostics.PerformanceCounter](https://www.nuget.org/packages/System.Diagnostics.PerformanceCounter) + +## Feedback & Contributing + + + +System.Diagnostics.EventLog is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Diagnostics.PerformanceCounter/src/PACKAGE.md b/src/libraries/System.Diagnostics.PerformanceCounter/src/PACKAGE.md new file mode 100644 index 00000000000000..475291274ff289 --- /dev/null +++ b/src/libraries/System.Diagnostics.PerformanceCounter/src/PACKAGE.md @@ -0,0 +1,223 @@ +## About + + + +This package provides types that allow applications to interact with the Windows performance counters. + +Windows allows you to examine how programs you run affect your computer's performance, both in real time and by collecting log data for later analysis. You can do this via the Windows Performance Monitor tool, which uses performance counters, among other features. + +Windows performance counters provide a high-level abstraction layer that provides a consistent interface for collecting various kinds of system data such as CPU, memory, and disk usage. They can be included in the operating system or can be part of individual applications. Windows Performance Monitor requests the current value of performance counters at specifiedtime intervals. + +System administrators often use performance counters to monitor systems for performance or behavior problems. Software developers often use performance counters to examine the resource usage of their programs. + +## Key Features + + + +* Can be used to read existing predefined or custom counters. +* Can be used for publishing (writing) data to custom counters. +* Can collect performance counters from the local machine or from a remote machine. + +## How to Use + + + +```cs +using System; +using System.Collections.Generic; +using System.Diagnostics; + +public class App +{ + public static void Main() + { + List samples = []; + + // If the category does not exist, create the category and exit. + // Performance counters should not be created and immediately used. + // There is a latency time to enable the counters, they should be created + // prior to executing the application that uses the counters. + // Execute this sample a second time to use the category. + if (SetupCategory()) + { + return; + } + + CollectSamples(samples); + CalculateResults(samples); + } + + private static bool SetupCategory() + { + if (PerformanceCounterCategory.Exists("AverageCounter64SampleCategory")) + { + Console.WriteLine("Category exists - AverageCounter64SampleCategory"); + return false; + } + + CounterCreationDataCollection counterDataCollection = []; + + // Add the counter. + CounterCreationData averageCount64 = new() + { + CounterType = PerformanceCounterType.AverageCount64, + CounterName = "AverageCounter64Sample" + }; + counterDataCollection.Add(averageCount64); + + // Add the base counter. + CounterCreationData averageCount64Base = new() + { + CounterType = PerformanceCounterType.AverageBase, + CounterName = "AverageCounter64SampleBase" + }; + counterDataCollection.Add(averageCount64Base); + + // Create the category. + PerformanceCounterCategory.Create("AverageCounter64SampleCategory", + "Demonstrates usage of the AverageCounter64 performance counter type.", + PerformanceCounterCategoryType.SingleInstance, counterDataCollection); + + return true; + } + + private static void CollectSamples(List samples) + { + // Create the counters + + PerformanceCounter avgCounter64Sample = new PerformanceCounter("AverageCounter64SampleCategory", + "AverageCounter64Sample", + false) + { + RawValue = 0 + }; + + PerformanceCounter avgCounter64SampleBase = new PerformanceCounter("AverageCounter64SampleCategory", + "AverageCounter64SampleBase", + false) + { + RawValue = 0 + }; + + Random r = new(DateTime.Now.Millisecond); + + for (int j = 0; j < 100; j++) + { + int value = r.Next(1, 10); + Console.Write(j + " = " + value); + + avgCounter64Sample.IncrementBy(value); + + avgCounter64SampleBase.Increment(); + + if ((j % 10) == 9) + { + OutputSample(avgCounter64Sample.NextSample()); + samples.Add(avgCounter64Sample.NextSample()); + } + else + { + Console.WriteLine(); + } + + System.Threading.Thread.Sleep(50); + } + } + + private static void CalculateResults(List samples) + { + for (int i = 0; i < (samples.Count - 1); i++) + { + // Output the sample. + OutputSample(samples[i]); + OutputSample(samples[i + 1]); + + // Use .NET to calculate the counter value. + Console.WriteLine($".NET computed counter value = {CounterSampleCalculator.ComputeCounterValue(samples[i], samples[i + 1])}"); + + // Calculate the counter value manually. + Console.WriteLine($"My computed counter value = {MyComputeCounterValue(samples[i], samples[i + 1])}"); + } + } + + // Description - This counter type shows how many items are processed, on average, + // during an operation. Counters of this type display a ratio of the items + // processed (such as bytes sent) to the number of operations completed. The + // ratio is calculated by comparing the number of items processed during the + // last interval to the number of operations completed during the last interval. + // Generic type - Average + // Formula - (N1 - N0) / (D1 - D0), where the numerator (N) represents the number + // of items processed during the last sample interval and the denominator (D) + // represents the number of operations completed during the last two sample + // intervals. + // Average (Nx - N0) / (Dx - D0) + // Example PhysicalDisk\ Avg. Disk Bytes/Transfer + private static float MyComputeCounterValue(CounterSample s0, CounterSample s1) + { + float numerator = (float)s1.RawValue - s0.RawValue; + float denomenator = (float)s1.BaseValue - s0.BaseValue; + float counterValue = numerator / denomenator; + return counterValue; + } + + private static void OutputSample(CounterSample s) + { + Console.WriteLine("\r\n+++++++++++"); + Console.WriteLine("Sample values - \r\n"); + Console.WriteLine($" BaseValue = {s.BaseValue}"); + Console.WriteLine($" CounterFrequency = {s.CounterFrequency}"); + Console.WriteLine($" CounterTimeStamp = {s.CounterTimeStamp}"); + Console.WriteLine($" CounterType = {s.CounterType}"); + Console.WriteLine($" RawValue = {s.RawValue}"); + Console.WriteLine($" SystemFrequency = {s.SystemFrequency}"); + Console.WriteLine($" TimeStamp = {s.TimeStamp}"); + Console.WriteLine($" TimeStamp100nSec = {s.TimeStamp100nSec}"); + Console.WriteLine("++++++++++++++++++++++"); + } +} +``` + +Notes: + +* This assembly is only supported on Windows operating systems. +* Only the administrator of the computer or users in the Performance Logs User Group can log counter data. Users in the Administrator group can log counter data only if the tool they use to log counter data is started from a Command Prompt window that is opened with `Run as administrator...`. Any users in interactive logon sessions can view counter data. However, users in non-interactive logon sessions must be in the Performance Monitoring Users group to view counter data. + +## Main Types + + + +The main types provided by this library are: + +Under the [`System.Diagnostics`](https://learn.microsoft.com/dotnet/api/System.Diagnostics) namespace, the main types are: + +* [`System.Diagnostics.CounterCreationData`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.CounterCreationData) +* [`System.Diagnostics.CounterCreationDataCollection`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.CounterCreationDataCollection) +* [`System.Diagnostics.PerformanceCounter`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.PerformanceCounter) + +Under the [`System.Diagnostics.PerformanceData`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.PerformanceData) namespace, the main types are: + +* [`System.Diagnostics.PerformanceData.CounterData`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.PerformanceData.CounterData) +* [`System.Diagnostics.PerformanceData.CounterSet`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.PerformanceData.CounterSet) +* [`System.Diagnostics.PerformanceData.CounterType`](https://learn.microsoft.com/dotnet/api/System.Diagnostics.PerformanceData.CounterType) + +## Additional Documentation + + + +* [Microsoft Learn - System.Diagnostics.PerformanceCounter API reference](https://learn.microsoft.com/dotnet/api/system.diagnostics.performancecounter?view=dotnet-plat-ext-7.0) +* [Windows App Development - Performance Counters](https://learn.microsoft.com/windows/win32/perfctrs/performance-counters-portal) +* [Windows Performance and Reliability - Windows Performance Monitor](https://learn.microsoft.com/previous-versions/windows/it-pro/windows-server-2008-R2-and-2008/cc749249(v=ws.11)) +* [Windows Server - perfmon](https://learn.microsoft.com/windows-server/administration/windows-commands/perfmon) +* [GitHub - Source code](https://github.com/dotnet/runtime/tree/main/src/libraries/System.Diagnostics.PerformanceCounter) + +## Related Packages + + + +* [System.Diagnostics.EventLog](https://www.nuget.org/packages/System.Diagnostics.EventLog) + +## Feedback & Contributing + + + +System.Diagnostics.PerformanceCounter is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.DirectoryServices.AccountManagement/src/PACKAGE.md b/src/libraries/System.DirectoryServices.AccountManagement/src/PACKAGE.md new file mode 100644 index 00000000000000..bee50af73d3c45 --- /dev/null +++ b/src/libraries/System.DirectoryServices.AccountManagement/src/PACKAGE.md @@ -0,0 +1,74 @@ +## About + + +Provides uniform access and manipulation of security principals across multiple principal stores. The principal objects in the Account Management API include computer, group and user objects. The principal stores includes: + * Active Directory Domain Services (AD DS) + * Active Directory Lightweight Directory Services (AD LDS) + * Machine SAM (MSAM). + +## Key Features + + + +* Basic directory operations such as creating and updating security principals. The application requires less knowledge of the underlying stores to perform these operations. +* Applications can extend the object model to include new types of directory objects. +* Account management tasks, such as enabling and disabling a user account. +* Cross-store support allows group objects in the Active Directory Domain Services (AD DS), Active Directory Lightweight Directory Services (AD LDS), and Machine SAM (MSAM) databases to contain members from different types of stores. +* Query by example searching, available on the PrincipalSearcher class, enables applications to set properties on a principal object and search the selected store for other objects that contain matching property values. +* Enhanced search on computer, user and group principal objects enables applications to search the selected store for matching principal objects. +* Recursive search, available on the group principal object, enables applications to search a group recursively and return only principal objects that are leaf nodes. +* Credential validation against the Machine SAM, AD DS, and AD LS stores. +* Connections speeds are increased by using the Fast Concurrent Bind (FSB) feature when available. Connection caching decreases the number of ports used. + +## How to Use + + + +```cs +// Create the principal context for the usr object. +PrincipalContext ctx = new PrincipalContext(ContextType.Domain, "fabrikam.com", "CN=Users,DC=fabrikam,DC=com", "administrator", "securelyStoredPassword"); + +// Create the principal user object from the context. +UserPrincipal usr = new UserPrincipal(ctx); +usr.AdvancedSearchFilter.LastLogonTime(DateTime.Now, MatchType.LessThan); +usr.AdvancedSearchFilter.LastLogonTime(DateTime.Yesterday, MatchType.GreaterThan); + +// Create a PrincipalSearcher object. +PrincipalSearcher ps = new PrincipalSearcher(usr); +PrincipalSearchResult fr = ps.FindAll(); +foreach (UserPrincipal u in results) +{ + Console.WriteLine(u.Name); +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.DirectoryServices.AccountManagement.PrincipalContext` +* `System.DirectoryServices.AccountManagement.PrincipalSearcher` +* `System.DirectoryServices.AccountManagement.Principal` and its subclasses: `System.DirectoryServices.AccountManagement.UserPrincipal`, `System.DirectoryServices.AccountManagement.GroupPrincipal` and `System.DirectoryServices.AccountManagement.ComputerPrincipal` + +## Additional Documentation + + + +* Conceptual documentations: + - [System.DirectoryServices.AccountManagement Namespace Overview](https://learn.microsoft.com/previous-versions/bb384379(v=vs.90)) + - [About System.DirectoryServices.AccountManagement](https://learn.microsoft.com//previous-versions/bb384375(v=vs.90)) + - [Using System.DirectoryServices.AccountManagement](https://learn.microsoft.com/previous-versions/bb384384(v=vs.90)) +* API documentation + - [System.DirectoryServices.AccountManagement namespace](https://learn.microsoft.com/dotnet/api/system.directoryservices.accountmanagement) + +## Related Packages + +[System.DirectoryServices](https://learn.microsoft.com/dotnet/api/system.directoryservices) + +## Feedback & Contributing + + + +System.DirectoryServices.AccountManagement is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.DirectoryServices.Protocols/src/PACKAGE.md b/src/libraries/System.DirectoryServices.Protocols/src/PACKAGE.md new file mode 100644 index 00000000000000..2938eee70fa8b0 --- /dev/null +++ b/src/libraries/System.DirectoryServices.Protocols/src/PACKAGE.md @@ -0,0 +1,69 @@ +## About + + + +System.DirectoryServices.Protocols provides a managed implementation of Lightweight Directory Access Protocol (LDAP) version 3 and Directory Services Markup Language (DSML) version 2.0 (V2) standards. + +It primarily uses the `LdapConnection` type for interacting with LDAP servers, using system native libraries to establish TCP/IP or UDP LDAP connections. +Supports both Windows and Unix, but certain features, such as setting client or server certificate options, are not available on Unix. + +## Key Features + + + +* Managed implementation of LDAP v3 and DSML V2 standards. + +## How to Use + + + +Using the `LdapConnection` type, you can establish connections to LDAP servers and issue requests. + +Here is a simple example: + +```csharp +using System.DirectoryServices.Protocols; + +// Create a new LdapConnection instance using the server URL. +using (LdapConnection connection = new LdapConnection("ldap.example.com")) { + + // Some credentials + connection.Credential = new NetworkCredential(dn, password); + + // Connect to the server + connection.Bind(); + + // Perform LDAP operations +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.DirectoryServices.Protocols.LdapConnection` +* `System.DirectoryServices.Protocols.DirectoryAttribute` +* `System.DirectoryServices.Protocols.DirectoryOperation` +* `System.DirectoryServices.Protocols.DirectoryRequest` +* `System.DirectoryServices.Protocols.DirectoryResponse` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.directoryservices.protocols) +* [Active Directory Domain Services](https://learn.microsoft.com/windows/win32/ad/active-directory-domain-services) + +## Related Packages + + + +* [System.DirectoryServices](https://www.nuget.org/packages/System.DirectoryServices/) + +## Feedback & Contributing + + + +System.DirectoryServices.Protocols is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.DirectoryServices/src/PACKAGE.md b/src/libraries/System.DirectoryServices/src/PACKAGE.md new file mode 100644 index 00000000000000..981c865baf9210 --- /dev/null +++ b/src/libraries/System.DirectoryServices/src/PACKAGE.md @@ -0,0 +1,81 @@ +## About + + + +Provides easy access to [Active Directory Domain Services](https://learn.microsoft.com/windows/win32/ad/active-directory-domain-services) from managed code. `Microsoft Active Directory Domain Services` are the foundation for distributed networks built on Windows 2000 Server, Windows Server 2003 and Microsoft Windows Server 2008 operating systems that use domain controllers. The namespace contains two component classes, [DirectoryEntry](https://learn.microsoft.com/dotnet/api/system.directoryservices.directoryentry) and [DirectorySearcher](https://learn.microsoft.com/dotnet/api/system.directoryservices.directorysearcher), which use the Active Directory Services Interfaces (ADSI) technology. ADSI is the set of interfaces that Microsoft provides as a flexible tool for working with a variety of network providers. ADSI gives the administrator the ability to locate and manage resources on a network with relative ease, regardless of the size of the network. + +## Key Features + + + +Active Directory Domain Services use a tree structure. Each node in the tree contains a set of properties. Use this library to traverse, search, and modify the tree, and read and write to the properties of a node. + +* The [DirectoryEntry](https://learn.microsoft.com/dotnet/api/system.directoryservices.directoryentry) class encapsulates a node or object in the Active Directory Domain Services hierarchy. Use this class for binding to objects, reading properties, and updating attributes. Together with helper classes, DirectoryEntry provides support for life-cycle management and navigation methods, including creating, deleting, renaming, moving a child node, and enumerating children. +* Use the [DirectorySearcher](https://learn.microsoft.com/dotnet/api/system.directoryservices.directorysearcher) class to perform queries against the Active Directory Domain Services hierarchy. LDAP is the only system-supplied Active Directory Service Interfaces (ADSI) provider that supports searching. A search of the Active Directory Domain Services hierarchy through [DirectorySearcher](https://learn.microsoft.com/dotnet/api/system.directoryservices.directorysearcher) returns instances of [SearchResult](https://learn.microsoft.com/dotnet/api/system.directoryservices.searchresult), which are contained in an instance of the [SearchResultCollection](https://learn.microsoft.com/dotnet/api/system.directoryservices.searchresultcollection) class. +* Network administrators write scripts and applications that access Active Directory Domain Services to automate common administrative tasks, such as adding users and groups, managing printers, and setting permissions for network resources. + +## How to Use + + + +Install the `System.DirectoryServices` library from nuget + +```dotnetcli +dotnet add package System.DirectoryServices --version 7.0.1 +``` + +The sample needs a real path to an Active Directory server to work properly: + +```cs +using System.DirectoryServices; + +namespace TestDirectoryServices +{ + internal class Program + { + static void Main(string[] args) + { + DirectoryEntry rootDse = new DirectoryEntry("LDAP://RootDSE"); + string configNamingContext = rootDse.Properties["configurationNamingContext"].Value.ToString(); + + DirectoryEntry certTemplates = new DirectoryEntry("LDAP://CN=Certificate Templates,CN=Public Key Services,CN=Services," + configNamingContext); + DirectorySearcher templatesSearch = new DirectorySearcher(certTemplates, "(objectClass=pKICertificateTemplate)", null, SearchScope.OneLevel); + + SearchResultCollection templates = templatesSearch.FindAll(); + + foreach (SearchResult template in templates) + { + Console.WriteLine($"Name: {template.Properties["name"][0]} ({template.Properties["displayName"][0]})"); + Console.WriteLine($"Flags: {template.Properties["msPKI-Enrollment-Flag"][0]}"); + } + } + } +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.DirectoryServices.DirectoryEntry` +* `System.DirectoryServices.DirectorySearcher` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.directoryservices) +* [Active Directory Domain Services](https://learn.microsoft.com/windows/win32/ad/active-directory-domain-services) +* [Active Directory Service Interfaces](https://learn.microsoft.com/windows/win32/adsi/active-directory-service-interfaces-adsi) +* [Lightweight Directory Access Protocol (LDAP)](https://learn.microsoft.com/previous-versions/windows/desktop/ldap/lightweight-directory-access-protocol-ldap-api) + +## Related Packages + +* [System.DirectoryServices.AccountManagement](https://learn.microsoft.com/dotnet/api/system.directoryservices.accountmanagement) +* [System.DirectoryServices.Protocols](https://learn.microsoft.com/dotnet/api/system.directoryservices.protocols) + +## Feedback & Contributing + + + +System.DirectoryServices is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Formats.Cbor/src/PACKAGE.md b/src/libraries/System.Formats.Cbor/src/PACKAGE.md new file mode 100644 index 00000000000000..b6549d0941e8c2 --- /dev/null +++ b/src/libraries/System.Formats.Cbor/src/PACKAGE.md @@ -0,0 +1,95 @@ +## About + + + +Provides support for reading and writing values in Concise Binary Object Representation (CBOR) format, as originally defined in [IETF RFC 7049](https://www.ietf.org/rfc/rfc7049.html). + + +## Key Features + + + +* Reader and writer types for the CBOR format. +* Built-in support for different CBOR conformance modes. + +## How to Use + + + +Write and read primitives: + +```csharp +using System.Formats.Cbor; + +var cborWriter = new CborWriter(CborConformanceMode.Lax); +cborWriter.WriteTextString("Hello World"); + +var cborReader = new CborReader(cborWriter.Encode(), CborConformanceMode.Lax); +Console.WriteLine(cborReader.ReadTextString()); +// Hello World +``` + +Write and read an array: + +```csharp +var cborWriter = new CborWriter(CborConformanceMode.Lax); +cborWriter.WriteStartArray(5); +for (var index = 0; index < 5; index++) +{ + cborWriter.WriteInt32(index); +} +cborWriter.WriteEndArray(); + +var cborReader = new CborReader(cborWriter.Encode(), CborConformanceMode.Lax); +var arrayLength = cborReader.ReadStartArray(); +for (var index = 0; index < arrayLength; index++) +{ + Console.Write(cborReader.ReadInt32()); +} +// 01234 +cborReader.ReadEndArray(); +``` + +Inspect writer and reader state: + +```csharp +var cborWriter = new CborWriter(CborConformanceMode.Lax); +cborWriter.WriteTextString("SomeArray"); +Console.WriteLine(cborWriter.BytesWritten); +// 10 +Console.WriteLine(cborWriter.IsWriteCompleted); +// True + +var cborReader = new CborReader(cborWriter.Encode(), CborConformanceMode.Lax); +Console.WriteLine(cborReader.BytesRemaining); +// 10 +Console.WriteLine(cborReader.ReadTextString()); +// SomeArray +Console.WriteLine(cborReader.BytesRemaining); +// 0 +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Formats.Cbor.CborReader` +* `System.Formats.Cbor.CborWriter` +* `System.Formats.Cbor.CborReaderState` +* `System.Formats.Cbor.CborConformanceMode` +* `System.Formats.Cbor.CborContentException` +* `System.Formats.Cbor.CborTag` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.formats.cbor) + +## Feedback & Contributing + + + +System.Formats.Cbor is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs index 238217856c1905..2dd4c4d1d9abe1 100644 --- a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs +++ b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs @@ -72,7 +72,7 @@ public DateTimeOffset ReadDateTimeOffset() string dateString = ReadTextString(); // TODO determine if conformance modes should allow inexact date sting parsing - if (!DateTimeOffset.TryParseExact(dateString, CborWriter.Rfc3339FormatString, null, DateTimeStyles.RoundtripKind, out DateTimeOffset result)) + if (!DateTimeOffset.TryParseExact(dateString, CborWriter.Rfc3339FormatString, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind, out DateTimeOffset result)) { throw new CborContentException(SR.Cbor_Reader_InvalidDateTimeEncoding); } diff --git a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs index e5e772dcfd1be5..3ca04f085e04c0 100644 --- a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs +++ b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Globalization; using System.Numerics; namespace System.Formats.Cbor @@ -42,8 +43,8 @@ public void WriteDateTimeOffset(DateTimeOffset value) #else value.Offset == TimeSpan.Zero ? #endif // NET8_0_OR_GREATER - value.UtcDateTime.ToString(Rfc3339FormatString) : // prefer 'Z' over '+00:00' - value.ToString(Rfc3339FormatString); + value.UtcDateTime.ToString(Rfc3339FormatString, CultureInfo.InvariantCulture) : // prefer 'Z' over '+00:00' + value.ToString(Rfc3339FormatString, CultureInfo.InvariantCulture); WriteTag(CborTag.DateTimeString); WriteTextString(dateString); diff --git a/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs b/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs index af9adfbe67b500..c6fba6d2a3981c 100644 --- a/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs +++ b/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs @@ -2,8 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Numerics; +using System.Threading; +using Microsoft.DotNet.RemoteExecutor; using Test.Cryptography; using Xunit; @@ -192,6 +195,31 @@ public static void ReadDateTimeOffset_SingleValue_HappyPath(string expectedValue Assert.Equal(expectedValue.Offset, result.Offset); } + [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework)] + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public static void ReadDateTimeOffset_IsCultureInvariant() + { + // Regression test for https://github.com/dotnet/runtime/pull/92539 + RemoteExecutor.Invoke(static () => + { + DateTimeOffset expectedValue = DateTimeOffset.Parse("2020-04-09T14:31:21.3535941+01:00", CultureInfo.InvariantCulture); + byte[] data = "c07821323032302d30342d30395431343a33313a32312e333533353934312b30313a3030".HexToByteArray(); + + // Install a non-Gregorian calendar + var culture = new CultureInfo("he-IL"); + culture.DateTimeFormat.Calendar = new HebrewCalendar(); + Thread.CurrentThread.CurrentCulture = culture; + + var reader = new CborReader(data); + + DateTimeOffset result = reader.ReadDateTimeOffset(); + + Assert.Equal(CborReaderState.Finished, reader.PeekState()); + Assert.Equal(expectedValue, result); + Assert.Equal(expectedValue.Offset, result.Offset); + }).Dispose(); + } + [Theory] [InlineData("c01a514b67b0")] // string datetime tag with unix time payload public static void ReadDateTimeOffset_InvalidTagPayload_ShouldThrowCborContentException(string hexEncoding) @@ -206,6 +234,7 @@ public static void ReadDateTimeOffset_InvalidTagPayload_ShouldThrowCborContentEx [Theory] [InlineData("c07330392f30342f323032302031393a35313a3530")] // 0("09/04/2020 19:51:50") [InlineData("c06e4c617374204368726973746d6173")] // 0("Last Christmas") + [InlineData("c07828d7aad7a922d7a42dd796272dd79822d7955431343a33313a32312e333533353934312b30313a3030")] // Non-Gregorian calendar date. public static void ReadDateTimeOffset_InvalidDateString_ShouldThrowCborContentException(string hexEncoding) { byte[] encoding = hexEncoding.HexToByteArray(); diff --git a/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj b/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj index 2ade4c628c7fbc..bf7b2f2b4aac54 100644 --- a/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj +++ b/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj @@ -1,6 +1,7 @@ - + $(NetCoreAppCurrent);$(NetFrameworkCurrent) + true enable $(NoWarn);CS8002 diff --git a/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs b/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs index 3413eadc84cc31..ff480bca39e119 100644 --- a/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs +++ b/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs @@ -2,8 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Numerics; +using System.Threading; +using Microsoft.DotNet.RemoteExecutor; using Test.Cryptography; using Xunit; @@ -88,6 +91,30 @@ public static void WriteDateTimeOffset_SingleValue_HappyPath(string valueString, AssertHelper.HexEqual(expectedHexEncoding.HexToByteArray(), encoding); } + [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework)] + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public static void WriteDateTimeOffset_IsCultureInvariant() + { + // Regression test for https://github.com/dotnet/runtime/pull/92539 + RemoteExecutor.Invoke(static () => + { + DateTimeOffset value = DateTimeOffset.Parse("2020-04-09T14:31:21.3535941+01:00", CultureInfo.InvariantCulture); + string expectedHexEncoding = "c07821323032302d30342d30395431343a33313a32312e333533353934312b30313a3030"; + + // Install a non-Gregorian calendar + var culture = new CultureInfo("he-IL"); + culture.DateTimeFormat.Calendar = new HebrewCalendar(); + Thread.CurrentThread.CurrentCulture = culture; + + var writer = new CborWriter(); + + writer.WriteDateTimeOffset(value); + + byte[] encoding = writer.Encode(); + AssertHelper.HexEqual(expectedHexEncoding.HexToByteArray(), encoding); + }).Dispose(); + } + [Theory] [InlineData(1363896240, "c11a514b67b0")] [InlineData(1586439081, "c11a5e8f23a9")] diff --git a/src/libraries/System.Globalization/tests/CultureInfo/CultureInfoGetCultures.cs b/src/libraries/System.Globalization/tests/CultureInfo/CultureInfoGetCultures.cs new file mode 100644 index 00000000000000..5dcdfb54c1f835 --- /dev/null +++ b/src/libraries/System.Globalization/tests/CultureInfo/CultureInfoGetCultures.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace System.Globalization.Tests +{ + public class CultureInfoGetCultures + { + [Fact] + public void GetSpecificCultures() + { + var specificCultures = CultureInfo.GetCultures(CultureTypes.SpecificCultures); + Assert.True(specificCultures.Length > 0); + Assert.All(specificCultures, c => Assert.True(c.IsNeutralCulture == false)); + } + + [Fact] + public void GetAllCultures() + { + var allCultures = CultureInfo.GetCultures(CultureTypes.AllCultures); + Assert.True(allCultures.Length > 0); + } + } +} diff --git a/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj b/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj index eac50b0d2f3e0f..bb38f00b88f9d8 100644 --- a/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj +++ b/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj @@ -6,6 +6,7 @@ + diff --git a/src/libraries/System.IO.Hashing/src/PACKAGE.md b/src/libraries/System.IO.Hashing/src/PACKAGE.md new file mode 100644 index 00000000000000..41b90205eac84e --- /dev/null +++ b/src/libraries/System.IO.Hashing/src/PACKAGE.md @@ -0,0 +1,91 @@ +## About + + + +System.IO.Hashing offers a variety of hash code algorithms. + +Hash code algorithms are pivotal for generating unique values for objects based on their content, facilitating object comparisons, and detecting content alterations. +The namespace encompasses algorithms like CRC-32, CRC-64, xxHash3, xxHash32, xxHash64, and xxHash128, all engineered for swift and efficient hash code generation, with xxHash being an "Extremely fast hash algorithm". + +**Warning**: The hash functions provided by System.IO.Hashing are not suitable for security purposes such as handling passwords or verifying untrusted content. +For such security-critical applications, consider using cryptographic hash functions provided by the [System.Security.Cryptography](https://learn.microsoft.com/dotnet/api/system.security.cryptography) namespace. + +## Key Features + + + +* Variety of hash code algorithms including CRC-32, CRC-64, xxHash3, xxHash32, xxHash64, and xxHash128. +* Implementations of CRC-32 and CRC-64 algorithms, as used in IEEE 802.3, and described in ECMA-182, Annex B respectively. +* Implementations of XxHash32 for generating 32-bit hashes, XxHash3 and XxHash64 for generating 64-bit hashes, and xxHash128 for generating 128-bit hashes. + +## How to Use + + + +Creating hash codes is straightforward. +Call the `Hash` method with the content to be hashed. + +Here is a practical example: + +```csharp +using System; +using System.IO.Hashing; + +byte[] data = new byte[] { 1, 2, 3, 4 }; + +byte[] crc32Value = Crc32.Hash(data); +Console.WriteLine($"CRC-32 Hash: {BitConverter.ToString(crc32Value)}"); +// CRC-32 Hash: CD-FB-3C-B6 + +byte[] crc64Value = Crc64.Hash(data); +Console.WriteLine($"CRC-64 Hash: {BitConverter.ToString(crc64Value)}"); +// CRC-64 Hash: 58-8D-5A-D4-2A-70-1D-B2 + +byte[] xxHash3Value = XxHash3.Hash(data); +Console.WriteLine($"XxHash3 Hash: {BitConverter.ToString(xxHash3Value)}"); +// XxHash3 Hash: 98-8B-7B-90-33-AC-46-22 + +byte[] xxHash32Value = XxHash32.Hash(data); +Console.WriteLine($"XxHash32 Hash: {BitConverter.ToString(xxHash32Value)}"); +// XxHash32 Hash: FE-96-D1-9C + +byte[] xxHash64Value = XxHash64.Hash(data); +Console.WriteLine($"XxHash64 Hash: {BitConverter.ToString(xxHash64Value)}"); +// XxHash64 Hash: 54-26-20-E3-A2-A9-2E-D1 + +byte[] xxHash128Value = XxHash128.Hash(data); +Console.WriteLine($"XxHash128 Hash: {BitConverter.ToString(xxHash128Value)}"); +// XxHash128 Hash: 49-A0-48-99-59-7A-35-67-53-76-53-A0-D9-95-5B-86 +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.IO.Hashing.Crc32` +* `System.IO.Hashing.Crc64` +* `System.IO.Hashing.XxHash3` +* `System.IO.Hashing.XxHash32` +* `System.IO.Hashing.XxHash64` +* `System.IO.Hashing.XxHash128` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.io.hashing) +* [xxHash - Extremely fast hash algorithm](https://github.com/Cyan4973/xxHash/blob/release/doc/xxhash_spec.md) + +## Related Packages + + + +Cryptographic services, including secure encryption and decryption of data: [System.Security.Cryptography](https://learn.microsoft.com/dotnet/api/system.security.cryptography) + +## Feedback & Contributing + + + +System.IO.Hashing is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.IO.Ports/src/PACKAGE.md b/src/libraries/System.IO.Ports/src/PACKAGE.md new file mode 100644 index 00000000000000..218bf288eb0454 --- /dev/null +++ b/src/libraries/System.IO.Ports/src/PACKAGE.md @@ -0,0 +1,59 @@ +## About + + + +[System.IO.Ports](https://www.nuget.org/packages/System.IO.Ports) package provides synchronous serial port file resource. Additionally, the functionality of this class can be wrapped in an internal `Stream` object, accessible through the `BaseStream` property, and passed to classes that wrap or use streams. + +## Key Features + + + +* synchronous and event-driven I/O +* access to pin and break states +* access to serial driver properties +* access to `Stream` object through the `BaseStream` property + +## How to Use + + + +```C# +using System.IO.Ports; + +// Provides list of available serial ports +string[] portNames = SerialPort.GetPortNames(); + +// First available port +string myPortName = portNames[0]; +int baudRate = 9600; + +SerialPort sp = new SerialPort(myPortName, baudRate); +sp.Open(); +sp.WriteLine("Hello World!"); +``` + +## Main Types + + + +The main type provided by this library is: + +* `SerialPort` + +## Additional Documentation + + + +* [SerialPort class documentation](https://learn.microsoft.com/dotnet/api/system.io.ports.serialport?view=dotnet-plat-ext-7.0) +* [API documentation](https://learn.microsoft.com/dotnet/api/System.IO.Ports) + +## Related Packages + + +- [System.IO.Ports](https://www.nuget.org/packages/System.IO.Ports) + +## Feedback & Contributing + + + +System.IO.Ports is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs b/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs index 43e5441ea0676d..4b8a9577e0ee59 100644 --- a/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs +++ b/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs @@ -190,7 +190,7 @@ public void WriteString_NotUtf8(int stringLengthInChars) private static bool IsUsingFastUtf8(BinaryWriter writer) { - return (bool)writer.GetType().GetField("_useFastUtf8", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(writer); + return (bool)typeof(BinaryWriter).GetField("_useFastUtf8", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(writer); } private static string GenerateLargeUnicodeString(int charCount) diff --git a/src/libraries/System.Linq.Expressions/ref/System.Linq.Expressions.cs b/src/libraries/System.Linq.Expressions/ref/System.Linq.Expressions.cs index 98bfd460ddb960..015cfb079d360a 100644 --- a/src/libraries/System.Linq.Expressions/ref/System.Linq.Expressions.cs +++ b/src/libraries/System.Linq.Expressions/ref/System.Linq.Expressions.cs @@ -607,9 +607,13 @@ protected Expression(System.Linq.Expressions.ExpressionType nodeType, System.Typ public static System.Linq.Expressions.NewExpression New(System.Reflection.ConstructorInfo constructor, System.Collections.Generic.IEnumerable? arguments, params System.Reflection.MemberInfo[]? members) { throw null; } public static System.Linq.Expressions.NewExpression New(System.Reflection.ConstructorInfo constructor, params System.Linq.Expressions.Expression[]? arguments) { throw null; } public static System.Linq.Expressions.NewExpression New([System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicConstructors | System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] System.Type type) { throw null; } + [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Creating arrays at runtime requires dynamic code generation.")] public static System.Linq.Expressions.NewArrayExpression NewArrayBounds(System.Type type, System.Collections.Generic.IEnumerable bounds) { throw null; } + [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Creating arrays at runtime requires dynamic code generation.")] public static System.Linq.Expressions.NewArrayExpression NewArrayBounds(System.Type type, params System.Linq.Expressions.Expression[] bounds) { throw null; } + [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Creating arrays at runtime requires dynamic code generation.")] public static System.Linq.Expressions.NewArrayExpression NewArrayInit(System.Type type, System.Collections.Generic.IEnumerable initializers) { throw null; } + [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Creating arrays at runtime requires dynamic code generation.")] public static System.Linq.Expressions.NewArrayExpression NewArrayInit(System.Type type, params System.Linq.Expressions.Expression[] initializers) { throw null; } public static System.Linq.Expressions.UnaryExpression Not(System.Linq.Expressions.Expression expression) { throw null; } public static System.Linq.Expressions.UnaryExpression Not(System.Linq.Expressions.Expression expression, System.Reflection.MethodInfo? method) { throw null; } @@ -1028,6 +1032,7 @@ internal MethodCallExpression() { } System.Linq.Expressions.Expression System.Linq.Expressions.IArgumentProvider.GetArgument(int index) { throw null; } public System.Linq.Expressions.MethodCallExpression Update(System.Linq.Expressions.Expression? @object, System.Collections.Generic.IEnumerable? arguments) { throw null; } } + [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Creating arrays at runtime requires dynamic code generation.")] public partial class NewArrayExpression : System.Linq.Expressions.Expression { internal NewArrayExpression() { } diff --git a/src/libraries/System.Linq.Expressions/src/CompatibilitySuppressions.xml b/src/libraries/System.Linq.Expressions/src/CompatibilitySuppressions.xml index 8185aa8209fa11..bf51da8ccbe294 100644 --- a/src/libraries/System.Linq.Expressions/src/CompatibilitySuppressions.xml +++ b/src/libraries/System.Linq.Expressions/src/CompatibilitySuppressions.xml @@ -97,6 +97,48 @@ ref/net8.0/System.Linq.Expressions.dll lib/net8.0/System.Linq.Expressions.dll + + CP0016 + M:System.Linq.Expressions.Expression.Call(System.Linq.Expressions.Expression,System.String,System.Type[],System.Linq.Expressions.Expression[]):[T:System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute] + ref/net8.0/System.Linq.Expressions.dll + lib/net8.0/System.Linq.Expressions.dll + + + CP0016 + M:System.Linq.Expressions.Expression.Call(System.Type,System.String,System.Type[],System.Linq.Expressions.Expression[]):[T:System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute] + ref/net8.0/System.Linq.Expressions.dll + lib/net8.0/System.Linq.Expressions.dll + + + CP0016 + M:System.Linq.Expressions.Expression.ListInit(System.Linq.Expressions.NewExpression,System.Collections.Generic.IEnumerable{System.Linq.Expressions.Expression}):[T:System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute] + ref/net8.0/System.Linq.Expressions.dll + lib/net8.0/System.Linq.Expressions.dll + + + CP0016 + M:System.Linq.Expressions.Expression.ListInit(System.Linq.Expressions.NewExpression,System.Linq.Expressions.Expression[]):[T:System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute] + ref/net8.0/System.Linq.Expressions.dll + lib/net8.0/System.Linq.Expressions.dll + + + CP0016 + M:System.Linq.Expressions.Expression.ListInit(System.Linq.Expressions.NewExpression,System.Reflection.MethodInfo,System.Collections.Generic.IEnumerable{System.Linq.Expressions.Expression}):[T:System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute] + ref/net8.0/System.Linq.Expressions.dll + lib/net8.0/System.Linq.Expressions.dll + + + CP0016 + M:System.Linq.Expressions.Expression.ListInit(System.Linq.Expressions.NewExpression,System.Reflection.MethodInfo,System.Linq.Expressions.Expression[]):[T:System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute] + ref/net8.0/System.Linq.Expressions.dll + lib/net8.0/System.Linq.Expressions.dll + + + CP0016 + M:System.Runtime.CompilerServices.CallSite`1.Create(System.Runtime.CompilerServices.CallSiteBinder):[T:System.Diagnostics.CodeAnalysis.RequiresDynamicCodeAttribute] + ref/net8.0/System.Linq.Expressions.dll + lib/net8.0/System.Linq.Expressions.dll + CP0020 M:System.Linq.Expressions.DynamicExpressionVisitor.#ctor diff --git a/src/libraries/System.Linq.Expressions/src/Resources/Strings.resx b/src/libraries/System.Linq.Expressions/src/Resources/Strings.resx index 01f6ec7be8ddc5..cfbf09d82e02e9 100644 --- a/src/libraries/System.Linq.Expressions/src/Resources/Strings.resx +++ b/src/libraries/System.Linq.Expressions/src/Resources/Strings.resx @@ -564,4 +564,7 @@ The given key '{0}' was not present in the dictionary. + + Nullable lifting on non-primitive type '{0}' is only supported in expression trees when dynamic code generation is available. + diff --git a/src/libraries/System.Linq.Expressions/src/System.Linq.Expressions.csproj b/src/libraries/System.Linq.Expressions/src/System.Linq.Expressions.csproj index 47375e1d827797..08f93036a937da 100644 --- a/src/libraries/System.Linq.Expressions/src/System.Linq.Expressions.csproj +++ b/src/libraries/System.Linq.Expressions/src/System.Linq.Expressions.csproj @@ -3,7 +3,6 @@ $(NetCoreAppCurrent) $(DefineConstants);FEATURE_FAST_CREATE $(NoWarn);CA1859 - false + +Provides access to a rich set of management information and management events about the system, devices, and applications instrumented to the Windows Management Instrumentation (WMI) infrastructure. Not supported on other platforms. + +## Key Features + + + +* Consume Windows Management Instrumentation (WMI) data and events +* High performance extensible event mechanism + +## How to Use + + + +### Retrieve management information +```C# +using System.Management; + +// Get the WMI class +ManagementClass managementClass = new("Win32_Processor"); + +// Loop through the WMI class instances and print the processor information found +foreach (ManagementObject managementObject in managementClass.GetInstances()) +{ + Console.WriteLine("--- Processor information ---"); + Console.WriteLine($"Name: {managementObject["Name"]}"); + Console.WriteLine($"Architecture: {managementObject["Architecture"]}"); +} +``` + +### Query management information via the SelectQuery type +```C# +using System.Management; + +// Search for win32 services with a stopped state +SelectQuery selectQuery = new("Win32_Service", "State = 'Stopped'"); +ManagementObjectSearcher managementObjectSearcher = new(selectQuery); + +foreach (ManagementObject service in managementObjectSearcher.Get()) +{ + Console.WriteLine(service.ToString()); +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Management.ManagementClass` +* `System.Management.ManagementObject` +* `System.Management.SelectQuery` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/windows/win32/wmisdk/wmi-start-page) +* [System.Management API documentation](https://learn.microsoft.com/dotnet/api/system.management?view=dotnet-plat-ext-7.0) +* [System.Management.ManagementClass documentation](https://learn.microsoft.com/dotnet/api/system.management.managementclass.-ctor?view=dotnet-plat-ext-7.0) + +## Feedback & Contributing + + + +System.Management is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Memory.Data/src/PACKAGE.md b/src/libraries/System.Memory.Data/src/PACKAGE.md new file mode 100644 index 00000000000000..33996a16b08252 --- /dev/null +++ b/src/libraries/System.Memory.Data/src/PACKAGE.md @@ -0,0 +1,102 @@ +## About + + + +System.Memory.Data introduces the `BinaryData` type, a lightweight abstraction for a byte payload. +It makes it easy to convert between string, bytes, and stream. + +This abstraction can simplify the API surface by exposing a single type instead of numerous overloads or properties. +The `BinaryData` type handles data ownership efficiently, wrapping passed-in bytes when using `byte[]` or `ReadOnlyMemory` constructors or methods, and managing data as bytes when dealing with streams, strings, or rich model types serialized as JSON. + + +## Key Features + + + +* Lightweight abstraction for byte payload via `BinaryData` type. +* Convenient helper methods for common conversions among string, bytes, and stream. +* Efficient data ownership handling. + +## How to Use + + + +To/From String: + +```csharp +var data = new BinaryData("some data"); + +// ToString will decode the bytes using UTF-8 +Console.WriteLine(data.ToString()); // prints "some data" +``` + +To/From Bytes: + +```csharp +byte[] bytes = Encoding.UTF8.GetBytes("some data"); + +// Create BinaryData using a constructor ... +BinaryData data = new BinaryData(bytes); + +// Or using a static factory method. +data = BinaryData.FromBytes(bytes); + +// There is an implicit cast defined for ReadOnlyMemory +ReadOnlyMemory rom = data; + +// There is also an implicit cast defined for ReadOnlySpan +ReadOnlySpan ros = data; + +// there is also a ToMemory method that gives access to the ReadOnlyMemory. +rom = data.ToMemory(); + +// and a ToArray method that converts into a byte array. +byte[] array = data.ToArray(); +``` + +To/From stream: + +```csharp +var bytes = Encoding.UTF8.GetBytes("some data"); +Stream stream = new MemoryStream(bytes); +var data = BinaryData.FromStream(stream); + +// Calling ToStream will give back a stream that is backed by ReadOnlyMemory, so it is not writable. +stream = data.ToStream(); +Console.WriteLine(stream.CanWrite); // prints false +``` + +`BinaryData` also can be used to integrate with `ObjectSerializer`. +By default, the `JsonObjectSerializer` will be used, but any serializer deriving from `ObjectSerializer` can be used. + +```csharp +var model = new CustomModel +{ + A = "some text", + B = 5, + C = true +}; + +var data = BinaryData.FromObjectAsJson(model); +model = data.ToObjectFromJson(); +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.BinaryData` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.binarydata) + +## Feedback & Contributing + + + +System.Memory.Data is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Net.Http.Json/src/PACKAGE.md b/src/libraries/System.Net.Http.Json/src/PACKAGE.md new file mode 100644 index 00000000000000..23fa95df54eccf --- /dev/null +++ b/src/libraries/System.Net.Http.Json/src/PACKAGE.md @@ -0,0 +1,50 @@ +## About + +Provides extension methods for `System.Net.Http.HttpClient` and `System.Net.Http.HttpContent` that facilitate serialization and deserialization of HTTP requests using System.Text.Json. + +## Key Features + +* Extension methods for deserializing HTTP response JSON bodies. +* Extension methods for serializing HTTP request JSON bodies. +* Extension methods for deserializing JSON from `HttpContent` instances. + +## How to Use + +```C# +using System.Net.Http.Json; + +using var client = new HttpClient(); + +// Get the list of all books +Book[] books = await client.GetFromJsonAsync("https://api.contoso.com/books"); + +// Send a POST request to add a new book +var book = new Book(id: 42, "Title", "Author", publishedYear: 2023); +HttpResponseMessage response = await client.PostAsJsonAsync($"https://api.contoso.com/books/{book.id}", book); + +if (response.IsSuccessStatusCode) + Console.WriteLine("Book added successfully."); +else + Console.WriteLine($"HTTP request failed with status code: {response.StatusCode}"); + +public record Book(int id, string title, string author, int publishedYear); +``` + +## Main Types + +The main types provided by this library are: + +* `HttpClientJsonExtensions` +* `HttpContentJsonExtensions` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.net.http.json) + +## Related Packages + +* [System.Text.Json](https://www.nuget.org/packages/System.Text.Json) + +## Feedback & Contributing + +System.Net.Http.Json is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Net.Http.Json/src/System/Net/Http/Json/HttpClientJsonExtensions.Get.AsyncEnumerable.cs b/src/libraries/System.Net.Http.Json/src/System/Net/Http/Json/HttpClientJsonExtensions.Get.AsyncEnumerable.cs index 057bf872692c3c..cf32fc760836a1 100644 --- a/src/libraries/System.Net.Http.Json/src/System/Net/Http/Json/HttpClientJsonExtensions.Get.AsyncEnumerable.cs +++ b/src/libraries/System.Net.Http.Json/src/System/Net/Http/Json/HttpClientJsonExtensions.Get.AsyncEnumerable.cs @@ -135,10 +135,7 @@ public static partial class HttpClientJsonExtensions JsonSerializerOptions? options, CancellationToken cancellationToken) { - options ??= JsonSerializerOptions.Default; - options.MakeReadOnly(); - - var jsonTypeInfo = (JsonTypeInfo)options.GetTypeInfo(typeof(TValue)); + var jsonTypeInfo = (JsonTypeInfo)JsonHelpers.GetJsonTypeInfo(typeof(TValue), options); return FromJsonStreamAsyncCore(client, requestUri, jsonTypeInfo, cancellationToken); } diff --git a/src/libraries/System.Net.Http.Json/tests/FunctionalTests/HttpClientJsonExtensionsTests.cs b/src/libraries/System.Net.Http.Json/tests/FunctionalTests/HttpClientJsonExtensionsTests.cs index 328f28f1bbf1e3..fee02cc8822950 100644 --- a/src/libraries/System.Net.Http.Json/tests/FunctionalTests/HttpClientJsonExtensionsTests.cs +++ b/src/libraries/System.Net.Http.Json/tests/FunctionalTests/HttpClientJsonExtensionsTests.cs @@ -580,6 +580,43 @@ public async Task GetFromJsonAsAsyncEnumerable_EnforcesTimeoutOnInitialRequest() await Task.Delay(TimeSpan.FromMilliseconds(10)); } } + + [Fact] + public async Task GetFromJsonAsAsyncEnumerable_SerializerUsesCamelCase() + { + using var client = new HttpClient(new CustomResponseHandler((r, c) => + { + string json = """[{"value":1},{"value":2}]"""; + HttpResponseMessage response = new() + { + Content = new StringContent(json) + }; + return Task.FromResult(response); + })); + + await foreach (var m in client.GetFromJsonAsAsyncEnumerable("http://dummyUrl")) + { + Assert.True(m.Value > 0); + } + } + + [Fact] + public async Task GetFromJsonAsAsyncEnumerable_CustomSerializerOptions() + { + using var client = new HttpClient(new CustomResponseHandler((r, c) => + { + string json = """[{"Value":1},{"Value":2}]"""; + HttpResponseMessage response = new() + { + Content = new StringContent(json) + }; + return Task.FromResult(response); + })); + await foreach (var m in client.GetFromJsonAsAsyncEnumerable("http://dummyUrl", JsonSerializerOptions.Default)) + { + Assert.True(m.Value > 0); + } + } } } @@ -593,3 +630,8 @@ public CustomResponseHandler( protected override Task SendAsync( HttpRequestMessage request, CancellationToken cancellationToken) => _func(request, cancellationToken); } + +file sealed class TestModel +{ + public int Value { get; set; } +} \ No newline at end of file diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/PACKAGE.md b/src/libraries/System.Net.Http.WinHttpHandler/src/PACKAGE.md new file mode 100644 index 00000000000000..dcc800e598728a --- /dev/null +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/PACKAGE.md @@ -0,0 +1,48 @@ +## About + +This package provides an [`HttpMessageHandler`](https://learn.microsoft.com/dotnet/api/system.net.http.httpmessagehandler) implementation backed by [Windows HTTP Services (WinHTTP)](https://learn.microsoft.com/windows/win32/winhttp/winhttp-start-page). +While the use of the default `HttpClientHandler` is highly recommended for applications targeting modern .NET, `WinHttpHandler` might help migration scenarios by providing an alternative HTTP backend for Windows that works consistently accross .NET Framework and modern .NET. + +## Key Features + +* Enables sending *asynchronous* HTTP requests with `HttpClient` on Windows. +* Handles authentication and credentials. +* Exposes a subset of WinHTTP options as C# properties on `WinHttpHandler`. +* Use custom proxy. +* Handle cookies. + +## How to Use + +```C# +using System.Net; + +using WinHttpHandler handler = new() +{ + ServerCredentials = new NetworkCredential("usr", "pwd") +}; + +using HttpClient client = new(handler); +using HttpRequestMessage request = new(HttpMethod.Get, "https://httpbin.org/basic-auth/usr/pwd"); +using HttpResponseMessage response = await client.SendAsync(request); + +Console.WriteLine($"Status: {response.StatusCode}"); +if (response.IsSuccessStatusCode) +{ + string content = await response.Content.ReadAsStringAsync(); + Console.WriteLine(content); +} +``` + +## Main Types + +The main types provided by this library are: + +* `System.Net.Http.WinHttpHandler` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.net.http.winhttphandler) + +## Feedback & Contributing + +System.Net.Http.WinHttpHandler is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/BidirectionStreamingTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/BidirectionStreamingTest.cs index d18de54bc46cba..ce79f20ad4efec 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/BidirectionStreamingTest.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/BidirectionStreamingTest.cs @@ -143,7 +143,17 @@ public async Task AfterReadResponseServerError_ClientWrite() // Server sends RST_STREAM. await connection.WriteFrameAsync(new RstStreamFrame(FrameFlags.EndStream, 0, streamId)); - await Assert.ThrowsAsync(() => requestStream.WriteAsync(new byte[50]).AsTask()); + await Assert.ThrowsAsync(async () => + { + for (int i = 0; i < 10; i++) + { + await requestStream.WriteAsync(new byte[50]); + + // WriteAsync succeeded because handler hasn't processed RST_STREAM yet. + // Small wait before trying again. + await Task.Delay(50); + } + }); } } diff --git a/src/libraries/System.Net.Http/src/ILLink/ILLink.Suppressions.Mobile.LibraryBuild.xml b/src/libraries/System.Net.Http/src/ILLink/ILLink.Suppressions.Mobile.LibraryBuild.xml deleted file mode 100644 index 7fd4d6a15a34fb..00000000000000 --- a/src/libraries/System.Net.Http/src/ILLink/ILLink.Suppressions.Mobile.LibraryBuild.xml +++ /dev/null @@ -1,269 +0,0 @@ - - - - - ILLink - IL2075 - member - M:System.Net.Http.HttpClientHandler.InvokeNativeHandlerMethod(System.String,System.Object[]) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetUseCookies() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetUseCookies(System.Boolean) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetCookieContainer() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetCookieContainer(System.Net.CookieContainer) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetAllowAutoRedirect() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetAllowAutoRedirect(System.Boolean) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetCredentials() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetCredentials(System.Net.ICredentials) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetAutomaticDecompression() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetAutomaticDecompression(System.Net.DecompressionMethods) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetUseProxy() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetUseProxy(System.Boolean) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetProxy() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetProxy(System.Net.IWebProxy) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetPreAuthenticate() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetPreAuthenticate(System.Boolean) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetMaxAutomaticRedirections() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetMaxAutomaticRedirections(System.Int32) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetServerCertificateCustomValidationCallback - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetServerCertificateCustomValidationCallback(System.Func{System.Net.Http.HttpRequestMessage,System.Security.Cryptography.X509Certificates.X509Certificate2,System.Security.Cryptography.X509Certificates.X509Chain,System.Net.Security.SslPolicyErrors,System.Boolean}) - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetCheckCertificateRevocationList() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetCheckCertificateRevocationList(System.Boolean) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetClientCertificateOptions() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetClientCertificateOptions(System.Net.Http.ClientCertificateOption) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetClientCertificates() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetDefaultProxyCredentials() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetDefaultProxyCredentials(System.Net.ICredentials) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetMaxConnectionsPerServer() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetMaxConnectionsPerServer(System.Int32) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetMaxResponseHeadersLength() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetMaxResponseHeadersLength(System.Int32) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetProperties() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetSupportsAutomaticDecompression() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetSupportsProxy() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetSupportsRedirectConfiguration() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.GetSslProtocols() - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - ILLink - IL2035 - member - M:System.Net.Http.HttpClientHandler.SetSslProtocols(System.Security.Authentication.SslProtocols) - The Xamarin.iOS and Mono.Android libraries are not present when running the trimmer analysis during our build. A consuming application will get a warning if these libraries aren't present when trimming the full app. - - - diff --git a/src/libraries/System.Net.Http/src/System.Net.Http.csproj b/src/libraries/System.Net.Http/src/System.Net.Http.csproj index f9c229575e1214..67909e99ba391b 100644 --- a/src/libraries/System.Net.Http/src/System.Net.Http.csproj +++ b/src/libraries/System.Net.Http/src/System.Net.Http.csproj @@ -25,7 +25,6 @@ - @@ -316,11 +315,11 @@ - + - + - \ No newline at end of file + diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.InvokeNativeHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.InvokeNativeHandler.cs index c11aa3a0c323f5..8f8c101b9e4a52 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.InvokeNativeHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.InvokeNativeHandler.cs @@ -6,6 +6,8 @@ using System.Diagnostics.CodeAnalysis; using System.Net.Security; using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; @@ -16,135 +18,130 @@ public partial class HttpClientHandler : HttpMessageHandler private static MethodInfo? _nativeHandlerMethod; #if TARGET_ANDROID - private const string NativeHandlerType = "Xamarin.Android.Net.AndroidMessageHandler"; - private const string AssemblyName = "Mono.Android"; + private const string NativeHandlerType = "Xamarin.Android.Net.AndroidMessageHandler, Mono.Android"; private const string GetHttpMessageHandlerType = "Android.Runtime.AndroidEnvironment, Mono.Android"; #elif TARGET_IOS - private const string NativeHandlerType = "System.Net.Http.NSUrlSessionHandler"; - private const string AssemblyName = "Microsoft.iOS"; + private const string NativeHandlerType = "System.Net.Http.NSUrlSessionHandler, Microsoft.iOS"; private const string GetHttpMessageHandlerType = "ObjCRuntime.RuntimeOptions, Microsoft.iOS"; #elif TARGET_MACCATALYST - private const string NativeHandlerType = "System.Net.Http.NSUrlSessionHandler"; - private const string AssemblyName = "Microsoft.MacCatalyst"; + private const string NativeHandlerType = "System.Net.Http.NSUrlSessionHandler, Microsoft.MacCatalyst"; private const string GetHttpMessageHandlerType = "ObjCRuntime.RuntimeOptions, Microsoft.MacCatalyst"; #elif TARGET_TVOS - private const string NativeHandlerType = "System.Net.Http.NSUrlSessionHandler"; - private const string AssemblyName = "Microsoft.tvOS"; + private const string NativeHandlerType = "System.Net.Http.NSUrlSessionHandler, Microsoft.tvOS"; private const string GetHttpMessageHandlerType = "ObjCRuntime.RuntimeOptions, Microsoft.tvOS"; #else #error Unknown target #endif - [DynamicDependency("get_DefaultProxyCredentials", NativeHandlerType, AssemblyName)] - private ICredentials? GetDefaultProxyCredentials() => (ICredentials?)InvokeNativeHandlerMethod("get_DefaultProxyCredentials"); + private ICredentials? GetDefaultProxyCredentials() + => (ICredentials?)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_DefaultProxyCredentials")!); - [DynamicDependency("set_DefaultProxyCredentials", NativeHandlerType, AssemblyName)] - private void SetDefaultProxyCredentials(ICredentials? value) => InvokeNativeHandlerMethod("set_DefaultProxyCredentials", value); + private void SetDefaultProxyCredentials(ICredentials? value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_DefaultProxyCredentials")!, value); - [DynamicDependency("get_MaxConnectionsPerServer", NativeHandlerType, AssemblyName)] - private int GetMaxConnectionsPerServer() => (int)InvokeNativeHandlerMethod("get_MaxConnectionsPerServer"); + private int GetMaxConnectionsPerServer() + => (int)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_MaxConnectionsPerServer")!); - [DynamicDependency("set_MaxConnectionsPerServer", NativeHandlerType, AssemblyName)] - private void SetMaxConnectionsPerServer(int value) => InvokeNativeHandlerMethod("set_MaxConnectionsPerServer", value); + private void SetMaxConnectionsPerServer(int value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_MaxConnectionsPerServer")!, value); - [DynamicDependency("get_MaxResponseHeadersLength", NativeHandlerType, AssemblyName)] - private int GetMaxResponseHeadersLength() => (int)InvokeNativeHandlerMethod("get_MaxResponseHeadersLength"); + private int GetMaxResponseHeadersLength() + => (int)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_MaxResponseHeadersLength")!); - [DynamicDependency("set_MaxResponseHeadersLength", NativeHandlerType, AssemblyName)] - private void SetMaxResponseHeadersLength(int value) => InvokeNativeHandlerMethod("set_MaxResponseHeadersLength", value); + private void SetMaxResponseHeadersLength(int value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_MaxResponseHeadersLength")!, value); - [DynamicDependency("get_ClientCertificateOptions", NativeHandlerType, AssemblyName)] - private ClientCertificateOption GetClientCertificateOptions() => (ClientCertificateOption)InvokeNativeHandlerMethod("get_ClientCertificateOptions"); + private ClientCertificateOption GetClientCertificateOptions() + => (ClientCertificateOption)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_ClientCertificateOptions")!); - [DynamicDependency("set_ClientCertificateOptions", NativeHandlerType, AssemblyName)] - private void SetClientCertificateOptions(ClientCertificateOption value) => InvokeNativeHandlerMethod("set_ClientCertificateOptions", value); + private void SetClientCertificateOptions(ClientCertificateOption value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_ClientCertificateOptions")!, value); - [DynamicDependency("get_ClientCertificates", NativeHandlerType, AssemblyName)] - private X509CertificateCollection GetClientCertificates() => (X509CertificateCollection)InvokeNativeHandlerMethod("get_ClientCertificates"); + private X509CertificateCollection GetClientCertificates() + => (X509CertificateCollection)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_ClientCertificates")!); - [DynamicDependency("get_ServerCertificateCustomValidationCallback", NativeHandlerType, AssemblyName)] - private Func GetServerCertificateCustomValidationCallback() => (Func)InvokeNativeHandlerMethod("get_ServerCertificateCustomValidationCallback"); + private Func GetServerCertificateCustomValidationCallback() + => (Func)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_ServerCertificateCustomValidationCallback")!); - [DynamicDependency("set_ServerCertificateCustomValidationCallback", NativeHandlerType, AssemblyName)] - private void SetServerCertificateCustomValidationCallback(Func? value) => InvokeNativeHandlerMethod("set_ServerCertificateCustomValidationCallback", value); + private void SetServerCertificateCustomValidationCallback(Func? value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_ServerCertificateCustomValidationCallback")!, value); - [DynamicDependency("get_CheckCertificateRevocationList", NativeHandlerType, AssemblyName)] - private bool GetCheckCertificateRevocationList() => (bool)InvokeNativeHandlerMethod("get_CheckCertificateRevocationList"); + private bool GetCheckCertificateRevocationList() + => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_CheckCertificateRevocationList")!); - [DynamicDependency("set_CheckCertificateRevocationList", NativeHandlerType, AssemblyName)] - private void SetCheckCertificateRevocationList(bool value) => InvokeNativeHandlerMethod("set_CheckCertificateRevocationList", value); + private void SetCheckCertificateRevocationList(bool value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_CheckCertificateRevocationList")!, value); - [DynamicDependency("get_SslProtocols", NativeHandlerType, AssemblyName)] - private SslProtocols GetSslProtocols() => (SslProtocols)InvokeNativeHandlerMethod("get_SslProtocols"); + private SslProtocols GetSslProtocols() + => (SslProtocols)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_SslProtocols")!); - [DynamicDependency("set_SslProtocols", NativeHandlerType, AssemblyName)] - private void SetSslProtocols(SslProtocols value) => InvokeNativeHandlerMethod("set_SslProtocols", value); + private void SetSslProtocols(SslProtocols value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_SslProtocols")!, value); - [DynamicDependency("get_Properties", NativeHandlerType, AssemblyName)] - private IDictionary GetProperties() => (IDictionary)InvokeNativeHandlerMethod("get_Properties"); + private IDictionary GetProperties() + => (IDictionary)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_Properties")!); - [DynamicDependency("get_SupportsAutomaticDecompression", NativeHandlerType, AssemblyName)] - private bool GetSupportsAutomaticDecompression() => (bool)InvokeNativeHandlerMethod("get_SupportsAutomaticDecompression"); + private bool GetSupportsAutomaticDecompression() + => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_SupportsAutomaticDecompression")!); - [DynamicDependency("get_SupportsProxy", NativeHandlerType, AssemblyName)] - private bool GetSupportsProxy() => (bool)InvokeNativeHandlerMethod("get_SupportsProxy"); + private bool GetSupportsProxy() + => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_SupportsProxy")!); - [DynamicDependency("get_SupportsRedirectConfiguration", NativeHandlerType, AssemblyName)] - private bool GetSupportsRedirectConfiguration() => (bool)InvokeNativeHandlerMethod("get_SupportsRedirectConfiguration"); + private bool GetSupportsRedirectConfiguration() + => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_SupportsRedirectConfiguration")!); - [DynamicDependency("get_AutomaticDecompression", NativeHandlerType, AssemblyName)] - private DecompressionMethods GetAutomaticDecompression() => (DecompressionMethods)InvokeNativeHandlerMethod("get_AutomaticDecompression"); + private DecompressionMethods GetAutomaticDecompression() + => (DecompressionMethods)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_AutomaticDecompression")!); - [DynamicDependency("set_AutomaticDecompression", NativeHandlerType, AssemblyName)] - private void SetAutomaticDecompression(DecompressionMethods value) => InvokeNativeHandlerMethod("set_AutomaticDecompression", value); + private void SetAutomaticDecompression(DecompressionMethods value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_AutomaticDecompression")!, value); - [DynamicDependency("get_UseProxy", NativeHandlerType, AssemblyName)] - private bool GetUseProxy() => (bool)InvokeNativeHandlerMethod("get_UseProxy"); + private bool GetUseProxy() + => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_UseProxy")!); - [DynamicDependency("set_UseProxy", NativeHandlerType, AssemblyName)] - private void SetUseProxy(bool value) => InvokeNativeHandlerMethod("set_UseProxy", value); + private void SetUseProxy(bool value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_UseProxy")!, value); - [DynamicDependency("get_Proxy", NativeHandlerType, AssemblyName)] - private IWebProxy GetProxy() => (IWebProxy)InvokeNativeHandlerMethod("get_Proxy"); + private IWebProxy GetProxy() + => (IWebProxy)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_Proxy")!); - [DynamicDependency("set_Proxy", NativeHandlerType, AssemblyName)] - private void SetProxy(IWebProxy value) => InvokeNativeHandlerMethod("set_Proxy", value); + private void SetProxy(IWebProxy value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_Proxy")!, value); - [DynamicDependency("get_PreAuthenticate", NativeHandlerType, AssemblyName)] - private bool GetPreAuthenticate() => (bool)InvokeNativeHandlerMethod("get_PreAuthenticate"); + private bool GetPreAuthenticate() + => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_PreAuthenticate")!); - [DynamicDependency("set_PreAuthenticate", NativeHandlerType, AssemblyName)] - private void SetPreAuthenticate(bool value) => InvokeNativeHandlerMethod("set_PreAuthenticate", value); + private void SetPreAuthenticate(bool value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_PreAuthenticate")!, value); - [DynamicDependency("get_MaxAutomaticRedirections", NativeHandlerType, AssemblyName)] - private int GetMaxAutomaticRedirections() => (int)InvokeNativeHandlerMethod("get_MaxAutomaticRedirections"); + private int GetMaxAutomaticRedirections() + => (int)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_MaxAutomaticRedirections")!); - [DynamicDependency("set_MaxAutomaticRedirections", NativeHandlerType, AssemblyName)] - private void SetMaxAutomaticRedirections(int value) => InvokeNativeHandlerMethod("set_MaxAutomaticRedirections", value); + private void SetMaxAutomaticRedirections(int value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_MaxAutomaticRedirections")!, value); - [DynamicDependency("get_UseCookies", NativeHandlerType, AssemblyName)] - private bool GetUseCookies() => (bool)InvokeNativeHandlerMethod("get_UseCookies"); + private bool GetUseCookies() => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_UseCookies")!); - [DynamicDependency("set_UseCookies", NativeHandlerType, AssemblyName)] - private void SetUseCookies(bool value) => InvokeNativeHandlerMethod("set_UseCookies", value); + private void SetUseCookies(bool value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_UseCookies")!, value); - [DynamicDependency("get_CookieContainer", NativeHandlerType, AssemblyName)] - private CookieContainer GetCookieContainer() => (CookieContainer)InvokeNativeHandlerMethod("get_CookieContainer"); + private CookieContainer GetCookieContainer() + => (CookieContainer)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_CookieContainer")!); - [DynamicDependency("set_CookieContainer", NativeHandlerType, AssemblyName)] - private void SetCookieContainer(CookieContainer value) => InvokeNativeHandlerMethod("set_CookieContainer", value); + private void SetCookieContainer(CookieContainer value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_CookieContainer")!, value); - [DynamicDependency("get_AllowAutoRedirect", NativeHandlerType, AssemblyName)] - private bool GetAllowAutoRedirect() => (bool)InvokeNativeHandlerMethod("get_AllowAutoRedirect"); + private bool GetAllowAutoRedirect() + => (bool)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_AllowAutoRedirect")!); - [DynamicDependency("set_AllowAutoRedirect", NativeHandlerType, AssemblyName)] - private void SetAllowAutoRedirect(bool value) => InvokeNativeHandlerMethod("set_AllowAutoRedirect", value); + private void SetAllowAutoRedirect(bool value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_AllowAutoRedirect")!, value); - [DynamicDependency("get_Credentials", NativeHandlerType, AssemblyName)] - private ICredentials GetCredentials() => (ICredentials)InvokeNativeHandlerMethod("get_Credentials"); + private ICredentials GetCredentials() + => (ICredentials)InvokeNativeHandlerGetter(() => Type.GetType(NativeHandlerType)!.GetMethod("get_Credentials")!); - [DynamicDependency("set_Credentials", NativeHandlerType, AssemblyName)] - private void SetCredentials(ICredentials? value) => InvokeNativeHandlerMethod("set_Credentials", value); + private void SetCredentials(ICredentials? value) + => InvokeNativeHandlerSetter(() => Type.GetType(NativeHandlerType)!.GetMethod("set_Credentials")!, value); private static HttpMessageHandler CreateNativeHandler() { @@ -156,5 +153,36 @@ private static HttpMessageHandler CreateNativeHandler() return (HttpMessageHandler)_nativeHandlerMethod!.Invoke(null, null)!; } + + private object InvokeNativeHandlerGetter(Func getMethod, [CallerMemberName] string? cachingKey = null) + { + return InvokeNativeHandlerMethod(getMethod, parameters: null, cachingKey!); + } + + private void InvokeNativeHandlerSetter(Func getMethod, object? value, [CallerMemberName] string? cachingKey = null) + { + InvokeNativeHandlerMethod(getMethod, parameters: new object?[] { value }, cachingKey!); + } + + private object InvokeNativeHandlerMethod(Func getMethod, object?[]? parameters, string cachingKey) + { + MethodInfo? method; + + if (!s_cachedMethods.TryGetValue(cachingKey, out method)) + { + method = getMethod(); + s_cachedMethods[cachingKey] = method; + } + + try + { + return method!.Invoke(_nativeHandler, parameters)!; + } + catch (TargetInvocationException e) + { + ExceptionDispatchInfo.Capture(e.InnerException!).Throw(); + throw; + } + } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs index d1b3db12b8ed1c..4b04787d3acdce 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs @@ -4,12 +4,12 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Diagnostics.Metrics; using System.Globalization; using System.Net.Http.Metrics; using System.Net.Security; using System.Reflection; -using System.Runtime.ExceptionServices; using System.Runtime.Versioning; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; @@ -51,10 +51,7 @@ private HttpMessageHandler Handler MetricsHandler metricsHandler = new MetricsHandler(handler, _nativeMeterFactory, out _); // Ensure a single handler is used for all requests. - if (Interlocked.CompareExchange(ref _nativeMetricsHandler, metricsHandler, null) != null) - { - handler.Dispose(); - } + Interlocked.CompareExchange(ref _nativeMetricsHandler, metricsHandler, null); } return _nativeMetricsHandler; @@ -87,7 +84,7 @@ protected override void Dispose(bool disposing) if (IsNativeHandlerEnabled) { - _nativeHandler!.Dispose(); + Handler.Dispose(); } else { @@ -796,27 +793,6 @@ private void ThrowForModifiedManagedSslOptionsIfStarted() _socketHandler!.SslOptions = _socketHandler!.SslOptions; } - private object InvokeNativeHandlerMethod(string name, params object?[] parameters) - { - MethodInfo? method; - - if (!s_cachedMethods.TryGetValue(name, out method)) - { - method = _nativeHandler!.GetType()!.GetMethod(name); - s_cachedMethods[name] = method; - } - - try - { - return method!.Invoke(_nativeHandler, parameters)!; - } - catch (TargetInvocationException e) - { - ExceptionDispatchInfo.Capture(e.InnerException!).Throw(); - throw; - } - } - private static bool IsNativeHandlerEnabled => RuntimeSettingParser.QueryRuntimeSettingSwitch( "System.Net.Http.UseNativeHttpHandler", false); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs index 3c34e43011938e..4de6df347972fe 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs @@ -112,11 +112,12 @@ private void RequestStop(HttpRequestMessage request, HttpResponseMessage? respon tags.Add("http.response.status_code", GetBoxedStatusCode((int)response.StatusCode)); tags.Add("network.protocol.version", GetProtocolVersionString(response.Version)); } - else + + if (TryGetErrorType(response, exception, out string? errorType)) { - Debug.Assert(exception is not null); - tags.Add("http.error.reason", GetErrorReason(exception)); + tags.Add("error.type", errorType); } + TimeSpan durationTime = Stopwatch.GetElapsedTime(startTimestamp, Stopwatch.GetTimestamp()); HttpMetricsEnrichmentContext? enrichmentContext = HttpMetricsEnrichmentContext.GetEnrichmentContextForRequest(request); @@ -130,37 +131,47 @@ private void RequestStop(HttpRequestMessage request, HttpResponseMessage? respon } } - private static string GetErrorReason(Exception exception) + private static bool TryGetErrorType(HttpResponseMessage? response, Exception? exception, out string? errorType) { - if (exception is HttpRequestException e) + if (response is not null) { - Debug.Assert(Enum.GetValues().Length == 12, "We need to extend the mapping in case new values are added to HttpRequestError."); + int statusCode = (int)response.StatusCode; - string? errorReason = e.HttpRequestError switch + // In case the status code indicates a client or a server error, return the string representation of the status code. + // See the paragraph Status and the definition of 'error.type' in + // https://github.com/open-telemetry/semantic-conventions/blob/2bad9afad58fbd6b33cc683d1ad1f006e35e4a5d/docs/http/http-spans.md + if (statusCode >= 400 && statusCode <= 599) { - HttpRequestError.NameResolutionError => "name_resolution_error", - HttpRequestError.ConnectionError => "connection_error", - HttpRequestError.SecureConnectionError => "secure_connection_error", - HttpRequestError.HttpProtocolError => "http_protocol_error", - HttpRequestError.ExtendedConnectNotSupported => "extended_connect_not_supported", - HttpRequestError.VersionNegotiationError => "version_negotiation_error", - HttpRequestError.UserAuthenticationError => "user_authentication_error", - HttpRequestError.ProxyTunnelError => "proxy_tunnel_error", - HttpRequestError.InvalidResponse => "invalid_response", - HttpRequestError.ResponseEnded => "response_ended", - HttpRequestError.ConfigurationLimitExceeded => "configuration_limit_exceeded", - - // Fall back to the exception type name (including for HttpRequestError.Unknown). - _ => null - }; - - if (errorReason is not null) - { - return errorReason; + errorType = GetErrorStatusCodeString(statusCode); + return true; } } - return exception.GetType().Name; + if (exception is null) + { + errorType = null; + return false; + } + + Debug.Assert(Enum.GetValues().Length == 12, "We need to extend the mapping in case new values are added to HttpRequestError."); + errorType = (exception as HttpRequestException)?.HttpRequestError switch + { + HttpRequestError.NameResolutionError => "name_resolution_error", + HttpRequestError.ConnectionError => "connection_error", + HttpRequestError.SecureConnectionError => "secure_connection_error", + HttpRequestError.HttpProtocolError => "http_protocol_error", + HttpRequestError.ExtendedConnectNotSupported => "extended_connect_not_supported", + HttpRequestError.VersionNegotiationError => "version_negotiation_error", + HttpRequestError.UserAuthenticationError => "user_authentication_error", + HttpRequestError.ProxyTunnelError => "proxy_tunnel_error", + HttpRequestError.InvalidResponse => "invalid_response", + HttpRequestError.ResponseEnded => "response_ended", + HttpRequestError.ConfigurationLimitExceeded => "configuration_limit_exceeded", + + // Fall back to the exception type name in case of HttpRequestError.Unknown or when exception is not an HttpRequestException. + _ => exception.GetType().FullName! + }; + return true; } private static string GetProtocolVersionString(Version httpVersion) => (httpVersion.Major, httpVersion.Minor) switch @@ -199,6 +210,7 @@ private static TagList InitializeCommonTags(HttpRequestMessage request) } private static object[]? s_boxedStatusCodes; + private static string[]? s_statusCodeStrings; private static object GetBoxedStatusCode(int statusCode) { @@ -209,6 +221,17 @@ private static object GetBoxedStatusCode(int statusCode) : statusCode; } + private static string GetErrorStatusCodeString(int statusCode) + { + Debug.Assert(statusCode >= 400 && statusCode <= 599); + + string[] strings = LazyInitializer.EnsureInitialized(ref s_statusCodeStrings, static () => new string[200]); + int index = statusCode - 400; + return (uint)index < (uint)strings.Length + ? strings[index] ??= statusCode.ToString() + : statusCode.ToString(); + } + private sealed class SharedMeter : Meter { public static Meter Instance { get; } = new SharedMeter(); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs index a9498cdc948dfb..12dbdface7398f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs @@ -14,10 +14,10 @@ internal sealed class ConnectionMetrics private readonly object _schemeTag; private readonly object _hostTag; private readonly object? _portTag; - private readonly object? _socketAddressTag; + private readonly object? _peerAddressTag; private bool _currentlyIdle; - public ConnectionMetrics(SocketsHttpHandlerMetrics metrics, string protocolVersion, string scheme, string host, int? port, string? socketAddress) + public ConnectionMetrics(SocketsHttpHandlerMetrics metrics, string protocolVersion, string scheme, string host, int? port, string? peerAddress) { _metrics = metrics; _openConnectionsEnabled = _metrics.OpenConnections.Enabled; @@ -25,7 +25,7 @@ public ConnectionMetrics(SocketsHttpHandlerMetrics metrics, string protocolVersi _schemeTag = scheme; _hostTag = host; _portTag = port; - _socketAddressTag = socketAddress; + _peerAddressTag = peerAddress; } // TagList is a huge struct, so we avoid storing it in a field to reduce the amount we allocate on the heap. @@ -42,9 +42,9 @@ private TagList GetTags() tags.Add("server.port", _portTag); } - if (_socketAddressTag is not null) + if (_peerAddressTag is not null) { - tags.Add("server.socket.address", _socketAddressTag); + tags.Add("network.peer.address", _peerAddressTag); } return tags; diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index 92c4ab0d6097f6..844d2866bde61e 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -1626,6 +1626,7 @@ public async Task DuplexStreaming_AbortByServer_StreamingCancelled(bool graceful public async Task ServerSendsTrailingHeaders_Success() { using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + SemaphoreSlim clientFinishedSemaphore = new SemaphoreSlim(0); Task serverTask = Task.Run(async () => { @@ -1636,6 +1637,7 @@ public async Task ServerSendsTrailingHeaders_Success() await requestStream.ReadRequestDataAsync(); await requestStream.SendResponseAsync(isFinal: false); await requestStream.SendResponseHeadersAsync(null, new[] { new HttpHeaderData("MyHeader", "MyValue") }); + await clientFinishedSemaphore.WaitAsync(TimeSpan.FromSeconds(20)); }); Task clientTask = Task.Run(async () => @@ -1655,6 +1657,7 @@ public async Task ServerSendsTrailingHeaders_Success() (string key, IEnumerable value) = Assert.Single(response.TrailingHeaders); Assert.Equal("MyHeader", key); Assert.Equal("MyValue", Assert.Single(value)); + clientFinishedSemaphore.Release(); }); await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(200_000); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs index 602d177f5f1be1..888b38b813127e 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs @@ -3,14 +3,31 @@ using System.IO; using System.Net.Quic; +using System.Net.Sockets; using System.Net.Test.Common; using System.Reflection; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Http.Functional.Tests { public abstract partial class HttpClientHandlerTestBase : FileCleanupTestBase { + protected static async Task DefaultConnectCallback(EndPoint endPoint, CancellationToken cancellationToken) + { + Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; + try + { + await socket.ConnectAsync(endPoint, cancellationToken); + return new NetworkStream(socket, ownsSocket: true); + } + catch + { + socket.Dispose(); + throw; + } + } + protected static bool IsWinHttpHandler => false; public static bool IsQuicSupported diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs index f3a97d2f15d530..4bf2638f9d3515 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs @@ -30,7 +30,7 @@ protected static class InstrumentNames public const string ConnectionDuration = "http.client.connection.duration"; public const string TimeInQueue = "http.client.request.time_in_queue"; } - + protected HttpMetricsTestBase(ITestOutputHelper output) : base(output) { } @@ -47,9 +47,9 @@ protected static void VerifyTag(KeyValuePair[] tags, string } } - private static void VerifySocketAddress(KeyValuePair[] tags) + private static void VerifyPeerAddress(KeyValuePair[] tags) { - string ipString = (string)tags.Single(t => t.Key == "server.socket.address").Value; + string ipString = (string)tags.Single(t => t.Key == "network.peer.address").Value; IPAddress ip = IPAddress.Parse(ipString); Assert.True(ip.Equals(IPAddress.Loopback.MapToIPv6()) || ip.Equals(IPAddress.Loopback) || @@ -75,8 +75,8 @@ protected static void VerifyRequestDuration(Measurement measurement, Version? protocolVersion = null, int? statusCode = null, string method = "GET", - string[] acceptedErrorReasons = null) => - VerifyRequestDuration(InstrumentNames.RequestDuration, measurement.Value, measurement.Tags.ToArray(), uri, protocolVersion, statusCode, method, acceptedErrorReasons); + string[] acceptedErrorTypes = null) => + VerifyRequestDuration(InstrumentNames.RequestDuration, measurement.Value, measurement.Tags.ToArray(), uri, protocolVersion, statusCode, method, acceptedErrorTypes); protected static void VerifyRequestDuration(string instrumentName, double measurement, @@ -85,7 +85,7 @@ protected static void VerifyRequestDuration(string instrumentName, Version? protocolVersion, int? statusCode, string method = "GET", - string[] acceptedErrorReasons = null) + string[] acceptedErrorTypes = null) { Assert.Equal(InstrumentNames.RequestDuration, instrumentName); Assert.InRange(measurement, double.Epsilon, 60); @@ -93,14 +93,14 @@ protected static void VerifyRequestDuration(string instrumentName, VerifyTag(tags, "http.request.method", method); VerifyTag(tags, "network.protocol.version", GetVersionString(protocolVersion)); VerifyTag(tags, "http.response.status_code", statusCode); - if (acceptedErrorReasons == null) + if (acceptedErrorTypes == null) { - Assert.DoesNotContain(tags, t => t.Key == "http.error.reason"); + Assert.DoesNotContain(tags, t => t.Key == "error.type"); } else { - string errorReason = (string)tags.Single(t => t.Key == "http.error.reason").Value; - Assert.Contains(errorReason, acceptedErrorReasons); + string errorReason = (string)tags.Single(t => t.Key == "error.type").Value; + Assert.Contains(errorReason, acceptedErrorTypes); } } @@ -122,7 +122,7 @@ protected static void VerifyOpenConnections(string actualName, object measuremen VerifySchemeHostPortTags(tags, uri); VerifyTag(tags, "network.protocol.version", GetVersionString(protocolVersion)); VerifyTag(tags, "http.connection.state", state); - VerifySocketAddress(tags); + VerifyPeerAddress(tags); } protected static void VerifyConnectionDuration(string instrumentName, object measurement, KeyValuePair[] tags, Uri uri, Version? protocolVersion) @@ -132,7 +132,7 @@ protected static void VerifyConnectionDuration(string instrumentName, object mea Assert.InRange(value, double.Epsilon, 60); VerifySchemeHostPortTags(tags, uri); VerifyTag(tags, "network.protocol.version", GetVersionString(protocolVersion)); - VerifySocketAddress(tags); + VerifyPeerAddress(tags); } protected static void VerifyTimeInQueue(string instrumentName, object measurement, KeyValuePair[] tags, Uri uri, Version? protocolVersion, string method = "GET") @@ -292,17 +292,8 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => GetUnderlyingSocketsHttpHandler(Handler).ConnectCallback = async (ctx, cancellationToken) => { connectionStarted.SetResult(); - Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - try - { - await socket.ConnectAsync(ctx.DnsEndPoint, cancellationToken); - return new NetworkStream(socket, ownsSocket: true); - } - catch - { - socket.Dispose(); - throw; - } + + return await DefaultConnectCallback(ctx.DnsEndPoint, cancellationToken); }; // Enable recording request-duration to test the path with metrics enabled. @@ -356,7 +347,7 @@ public Task RequestDuration_CustomTags_Recorded() { ctx.AddCustomTag("route", "/test"); }); - + using HttpResponseMessage response = await SendAsync(client, request); Measurement m = Assert.Single(recorder.GetMeasurements()); @@ -464,6 +455,21 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => using InstrumentRecorder recorder = SetupInstrumentRecorder(InstrumentNames.RequestDuration); using HttpRequestMessage request = new(HttpMethod.Get, uri) { Version = UseVersion }; using HttpResponseMessage response = await client.SendAsync(TestAsync, request, completionOption); + string responseContent = await response.Content.ReadAsStringAsync(); + + if (responseContentType == ResponseContentType.ContentLength) + { + Assert.NotNull(response.Content.Headers.ContentLength); + } + else if (responseContentType == ResponseContentType.TransferEncodingChunked) + { + Assert.NotNull(response.Headers.TransferEncodingChunked); + } + else + { + // Empty + Assert.Empty(responseContent); + } Measurement m = Assert.Single(recorder.GetMeasurements()); VerifyRequestDuration(m, uri, UseVersion, 200); ; @@ -664,11 +670,11 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => _output.WriteLine($"Client exception: {clientException}"); string[] expectedExceptionTypes = TestAsync - ? [nameof(TaskCanceledException)] - : [nameof(TaskCanceledException), nameof(OperationCanceledException)]; + ? [typeof(TaskCanceledException).FullName] + : [typeof(TaskCanceledException).FullName, typeof(OperationCanceledException).FullName]; Measurement m = Assert.Single(recorder.GetMeasurements()); - VerifyRequestDuration(m, uri, acceptedErrorReasons: expectedExceptionTypes); + VerifyRequestDuration(m, uri, acceptedErrorTypes: expectedExceptionTypes); clientCompleted.SetResult(); }, @@ -712,7 +718,7 @@ public async Task RequestDuration_ConnectionError_LogsExpectedErrorReason() _output.WriteLine($"Client exception: {ex}"); Measurement m = Assert.Single(recorder.GetMeasurements()); - VerifyRequestDuration(m, uri, acceptedErrorReasons: ["connection_error"]); + VerifyRequestDuration(m, uri, acceptedErrorTypes: ["connection_error"]); } protected override void Dispose(bool disposing) @@ -792,7 +798,7 @@ await Assert.ThrowsAsync(async () => using HttpResponseMessage response = await SendAsync(client, request); }); } - + Measurement m = Assert.Single(recorder.GetMeasurements()); VerifyRequestDuration(m, uri, UseVersion, 200); Assert.Equal("before!", m.Tags.ToArray().Single(t => t.Key == "before").Value); @@ -802,6 +808,29 @@ await Assert.ThrowsAsync(async () => }, content: "x")); } + [Theory] + [InlineData(400)] + [InlineData(404)] + [InlineData(599)] + public Task RequestDuration_ErrorStatus_ErrorTypeRecorded(int statusCode) + { + return LoopbackServerFactory.CreateClientAndServerAsync(async uri => + { + using HttpMessageInvoker client = CreateHttpMessageInvoker(); + using InstrumentRecorder recorder = SetupInstrumentRecorder(InstrumentNames.RequestDuration); + using HttpRequestMessage request = new(HttpMethod.Get, uri) { Version = UseVersion }; + + using HttpResponseMessage response = await SendAsync(client, request); + + Measurement m = Assert.Single(recorder.GetMeasurements()); + VerifyRequestDuration(m, uri, UseVersion, statusCode, "GET", acceptedErrorTypes: new[] { $"{statusCode}" }); + + }, async server => + { + await server.AcceptConnectionSendResponseAndCloseAsync(statusCode: (HttpStatusCode)statusCode); + }); + } + [Fact] [SkipOnPlatform(TestPlatforms.Browser, "Browser is relaxed about validating HTTP headers")] public async Task RequestDuration_ConnectionClosedWhileReceivingHeaders_Recorded() @@ -823,7 +852,7 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => Assert.True(ex is HttpRequestException or TaskCanceledException); Measurement m = Assert.Single(recorder.GetMeasurements()); - VerifyRequestDuration(m, uri, acceptedErrorReasons: [nameof(TaskCanceledException), "response_ended"]); + VerifyRequestDuration(m, uri, acceptedErrorTypes: [typeof(TaskCanceledException).FullName, "response_ended"]); }, async server => { try @@ -878,7 +907,7 @@ await server.AcceptConnectionAsync(async connection => { await Assert.ThrowsAsync(() => clientTask); Measurement m = Assert.Single(recorder.GetMeasurements()); - VerifyRequestDuration(m, server.Address, acceptedErrorReasons: ["response_ended"]); + VerifyRequestDuration(m, server.Address, acceptedErrorTypes: ["response_ended"]); } else { @@ -976,7 +1005,7 @@ await Assert.ThrowsAsync(async () => }); Measurement m = Assert.Single(recorder.GetMeasurements()); - VerifyRequestDuration(m, server.Address, acceptedErrorReasons: ["http_protocol_error"]); + VerifyRequestDuration(m, server.Address, acceptedErrorTypes: ["http_protocol_error"]); } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/NtAuthTests.FakeServer.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/NtAuthTests.FakeServer.cs index 780db637ba335e..c29beca16f4015 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/NtAuthTests.FakeServer.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/NtAuthTests.FakeServer.cs @@ -139,5 +139,32 @@ await server.AcceptConnectionAsync(async connection => }).ConfigureAwait(false); }); } + + [Fact] + [SkipOnPlatform(TestPlatforms.Browser | TestPlatforms.Windows, "DefaultCredentials are unsupported for NTLM on Unix / Managed implementation")] + public async Task DefaultHandler_FakeServer_DefaultCredentials() + { + await LoopbackServer.CreateClientAndServerAsync( + async uri => + { + HttpRequestMessage requestMessage = new HttpRequestMessage(HttpMethod.Get, uri); + requestMessage.Version = new Version(1, 1); + HttpMessageHandler handler = new HttpClientHandler() { Credentials = CredentialCache.DefaultCredentials }; + using (var client = new HttpClient(handler)) + { + HttpResponseMessage response = await client.SendAsync(requestMessage); + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + } + }, + async server => + { + await server.AcceptConnectionAsync(async connection => + { + var authHeader = "WWW-Authenticate: NTLM\r\n"; + await connection.SendResponseAsync(HttpStatusCode.Unauthorized, authHeader).ConfigureAwait(false); + connection.CompleteRequestProcessing(); + }).ConfigureAwait(false); + }); + } } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs index c793a1d55d6e76..76d7086c37c174 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs @@ -165,9 +165,7 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => else { // Succeed the second connection attempt - Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - await socket.ConnectAsync(context.DnsEndPoint, token); - return new NetworkStream(socket, ownsSocket: true); + return await DefaultConnectCallback(context.DnsEndPoint, token); } }; diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index e25f69529b0e60..2613b451be4556 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -1369,17 +1369,7 @@ await RetryHelper.ExecuteAsync(async () => { Assert.Equal("foo", context.DnsEndPoint.Host); - Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - try - { - await socket.ConnectAsync(lastServerUri.IdnHost, lastServerUri.Port); - return new NetworkStream(socket, ownsSocket: true); - } - catch - { - socket.Dispose(); - throw; - } + return await DefaultConnectCallback(new DnsEndPoint(lastServerUri.IdnHost, lastServerUri.Port), ct); }; TaskCompletionSource waitingForLastRequest = new(TaskCreationOptions.RunContinuationsAsynchronously); @@ -2659,30 +2649,18 @@ public async Task Http2_MultipleConnectionsEnabled_ManyRequestsEnqueuedSimultane AcquireAllStreamSlots(server, client, sendTasks, RequestCount); - List<(Http2LoopbackConnection connection, int streamId)> acceptedRequests = new(); - await using Http2LoopbackConnection c1 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); - for (int i = 0; i < MaxConcurrentStreams; i++) - { - (int streamId, _) = await c1.ReadAndParseRequestHeaderAsync(); - acceptedRequests.Add((c1, streamId)); - } + int[] streamIds1 = await AcceptRequests(c1, MaxConcurrentStreams); await using Http2LoopbackConnection c2 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); - for (int i = 0; i < MaxConcurrentStreams; i++) - { - (int streamId, _) = await c2.ReadAndParseRequestHeaderAsync(); - acceptedRequests.Add((c2, streamId)); - } + int[] streamIds2 = await AcceptRequests(c2, MaxConcurrentStreams); await using Http2LoopbackConnection c3 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); (int finalStreamId, _) = await c3.ReadAndParseRequestHeaderAsync(); - acceptedRequests.Add((c3, finalStreamId)); - foreach ((Http2LoopbackConnection connection, int streamId) request in acceptedRequests) - { - await request.connection.SendDefaultResponseAsync(request.streamId); - } + await SendResponses(c1, streamIds1); + await SendResponses(c2, streamIds2); + await c3.SendDefaultResponseAsync(finalStreamId); await VerifySendTasks(sendTasks); } @@ -2702,19 +2680,17 @@ public async Task Http2_MultipleConnectionsEnabled_InfiniteRequestsCompletelyBlo Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); - // Block the first connection on infinite requests. + // Accept requests but don't send responses on connection 0 int[] blockedStreamIds = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); - Assert.Equal(MaxConcurrentStreams, blockedStreamIds.Length); Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); - await HandleAllPendingRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); + // Send responses on connection 1 + await SendResponses(connection1, await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false)); - // Complete infinite requests. - int handledRequestCount = await SendResponses(connection0, blockedStreamIds); - - Assert.Equal(MaxConcurrentStreams, handledRequestCount); + // Send responses on connection 0 + await SendResponses(connection0, blockedStreamIds); await VerifySendTasks(sendTasks).ConfigureAwait(false); } @@ -2729,44 +2705,62 @@ public async Task Http2_MultipleConnectionsEnabled_OpenAndCloseMultipleConnectio const int MaxConcurrentStreams = 2; using Http2LoopbackServer server = Http2LoopbackServer.CreateServer(); + server.AllowMultipleConnections = true; + + // Allow 5 connections through the ConnectCallback. + SemaphoreSlim connectCallbackSemaphore = new(initialCount: 5); + using SocketsHttpHandler handler = CreateHandler(); + + handler.ConnectCallback = async (context, ct) => + { + await connectCallbackSemaphore.WaitAsync(ct); + + return await DefaultConnectCallback(context.DnsEndPoint, ct); + }; + using (HttpClient client = CreateHttpClient(handler)) { - server.AllowMultipleConnections = true; - List> sendTasks = new List>(); + List> sendTasks = new(); + Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds0 = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds1 = await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds2 = await AcceptRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false); - Task[] handleRequestTasks = new[] { - HandleAllPendingRequests(connection0, MaxConcurrentStreams), - HandleAllPendingRequests(connection1, MaxConcurrentStreams), - HandleAllPendingRequests(connection2, MaxConcurrentStreams) - }; - - await TestHelper.WhenAllCompletedOrAnyFailed(handleRequestTasks).ConfigureAwait(false); + await TestHelper.WhenAllCompletedOrAnyFailed( + SendResponses(connection0, streamIds0), + SendResponses(connection1, streamIds1), + SendResponses(connection2, streamIds2)) + .ConfigureAwait(false); - await connection0.ShutdownIgnoringErrorsAsync(await handleRequestTasks[0]).ConfigureAwait(false); - await connection2.ShutdownIgnoringErrorsAsync(await handleRequestTasks[2]).ConfigureAwait(false); + await connection0.ShutdownIgnoringErrorsAsync(streamIds0[^1]).ConfigureAwait(false); + await connection2.ShutdownIgnoringErrorsAsync(streamIds2[^1]).ConfigureAwait(false); - //Fill all connection1's stream slots + // Fill all connection1's stream slots AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + streamIds1 = await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); Http2LoopbackConnection connection3 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds3 = await AcceptRequests(connection3, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection4 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds4 = await AcceptRequests(connection4, MaxConcurrentStreams).ConfigureAwait(false); - Task[] finalHandleTasks = new[] { - HandleAllPendingRequests(connection1, MaxConcurrentStreams), - HandleAllPendingRequests(connection3, MaxConcurrentStreams), - HandleAllPendingRequests(connection4, MaxConcurrentStreams) - }; - - await TestHelper.WhenAllCompletedOrAnyFailed(finalHandleTasks).ConfigureAwait(false); + await TestHelper.WhenAllCompletedOrAnyFailed( + SendResponses(connection1, streamIds1), + SendResponses(connection3, streamIds3), + SendResponses(connection4, streamIds4)) + .ConfigureAwait(false); await VerifySendTasks(sendTasks).ConfigureAwait(false); } @@ -2778,24 +2772,36 @@ public async Task Http2_MultipleConnectionsEnabled_IdleConnectionTimeoutExpired_ { const int MaxConcurrentStreams = 2; using Http2LoopbackServer server = Http2LoopbackServer.CreateServer(); + server.AllowMultipleConnections = true; + + SemaphoreSlim connectCallbackSemaphore = new(initialCount: 2); + using SocketsHttpHandler handler = CreateHandler(); handler.PooledConnectionIdleTimeout = TimeSpan.FromSeconds(20); + + handler.ConnectCallback = async (context, ct) => + { + await connectCallbackSemaphore.WaitAsync(ct); + + return await DefaultConnectCallback(context.DnsEndPoint, ct); + }; + using (HttpClient client = CreateHttpClient(handler)) { - server.AllowMultipleConnections = true; - List> sendTasks = new List>(); + List> sendTasks0 = new(); + List> sendTasks1 = new(); + List> sendTasks2 = new(); + Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); - AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); - int[] acceptedStreamIds = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); - Assert.Equal(MaxConcurrentStreams, acceptedStreamIds.Length); + AcquireAllStreamSlots(server, client, sendTasks0, MaxConcurrentStreams); + int[] streamIds0 = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); - List> connection1SendTasks = new List>(); Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); - AcquireAllStreamSlots(server, client, connection1SendTasks, MaxConcurrentStreams); - await HandleAllPendingRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); + AcquireAllStreamSlots(server, client, sendTasks1, MaxConcurrentStreams); + await SendResponses(connection1, await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false)); - // Complete all the requests. - await VerifySendTasks(connection1SendTasks).ConfigureAwait(false); + // Complete all the requests on connection1. + await VerifySendTasks(sendTasks1).ConfigureAwait(false); // Wait until the idle connection timeout expires. await connection1.WaitForClientDisconnectAsync(false).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); @@ -2803,17 +2809,20 @@ public async Task Http2_MultipleConnectionsEnabled_IdleConnectionTimeoutExpired_ Assert.True(connection1.IsInvalid); Assert.False(connection0.IsInvalid); - Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); - - AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + // Due to a race condition in how a new Http2 connection is returned to the pool, we may have started a third connection attempt in the background. + // We were blocking such attempts from going through to the Socket layer until now to avoid having to deal with the extra connect when accepting connection2 below. + // Allow the third connection through the ConnectCallback now. + connectCallbackSemaphore.Release(); - await HandleAllPendingRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); + AcquireAllStreamSlots(server, client, sendTasks2, MaxConcurrentStreams); + await SendResponses(connection2, await AcceptRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false)); - //Make sure connection0 is still alive. - int handledRequests0 = await SendResponses(connection0, acceptedStreamIds).ConfigureAwait(false); - Assert.Equal(MaxConcurrentStreams, handledRequests0); + // Make sure connection0 is still alive. + await SendResponses(connection0, streamIds0).ConfigureAwait(false); - await VerifySendTasks(sendTasks).ConfigureAwait(false); + await VerifySendTasks(sendTasks0).ConfigureAwait(false); + await VerifySendTasks(sendTasks2).ConfigureAwait(false); } } @@ -2842,7 +2851,10 @@ private async Task PrepareConnection(Http2LoopbackServe Task warmUpTask = client.GetAsync(server.Address); - Http2LoopbackConnection connection = await GetConnection(server, maxConcurrentStreams).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); + var concurrentStreamsSetting = new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = maxConcurrentStreams }; + + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(timeout: null, ackTimeout: TimeSpan.FromSeconds(10), concurrentStreamsSetting) + .WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); await connection.SendDefaultResponseAsync(streamId).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); @@ -2862,49 +2874,25 @@ private static void AcquireAllStreamSlots(Http2LoopbackServer server, HttpClient } } - private static async Task GetConnection(Http2LoopbackServer server, uint maxConcurrentStreams) - { - var concurrentStreamsSetting = new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = maxConcurrentStreams }; - - return await server.EstablishConnectionAsync(timeout: null, ackTimeout: TimeSpan.FromSeconds(10), concurrentStreamsSetting).ConfigureAwait(false); - } - - private async Task HandleAllPendingRequests(Http2LoopbackConnection connection, int totalRequestCount) - { - int lastStreamId = -1; - for (int i = 0; i < totalRequestCount; i++) - { - (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().ConfigureAwait(false); - await connection.SendDefaultResponseAsync(streamId).ConfigureAwait(false); - lastStreamId = streamId; - } - - return lastStreamId; - } - private async Task AcceptRequests(Http2LoopbackConnection connection, int requestCount) { int[] streamIds = new int[requestCount]; for (int i = 0; i < streamIds.Length; i++) { - (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().ConfigureAwait(false); + (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); streamIds[i] = streamId; } return streamIds; } - private async Task SendResponses(Http2LoopbackConnection connection, IEnumerable streamIds) + private async Task SendResponses(Http2LoopbackConnection connection, IEnumerable streamIds) { - int count = 0; foreach (int streamId in streamIds) { - count++; - await connection.SendDefaultResponseAsync(streamId).ConfigureAwait(false); + await connection.SendDefaultResponseAsync(streamId).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); } - - return count; } } @@ -3108,10 +3096,7 @@ public async Task ConnectCallback_ConnectionPrefix_Success(bool useSsl) var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); socketsHandler.ConnectCallback = async (context, token) => { - Socket clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await clientSocket.ConnectAsync(listenSocket.LocalEndPoint); - - Stream clientStream = new NetworkStream(clientSocket, ownsSocket: true); + Stream clientStream = await DefaultConnectCallback(listenSocket.LocalEndPoint, token); await clientStream.WriteAsync(RequestPrefix); diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs index f4b09f346b3d9a..8227d58ebb61b6 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs @@ -23,13 +23,13 @@ public static string GetHostName() { name = NameResolutionPal.GetHostName(); } - catch when (LogFailure(string.Empty, startingTimestamp)) + catch (Exception ex) when (LogFailure(string.Empty, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(string.Empty, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(string.Empty, startingTimestamp); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, name); return name; @@ -394,13 +394,13 @@ private static object GetHostEntryOrAddressesCore(string hostName, bool justAddr Aliases = aliases }; } - catch when (LogFailure(hostName, startingTimestamp)) + catch (Exception ex) when (LogFailure(hostName, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp); return result; } @@ -434,13 +434,13 @@ private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAd } Debug.Assert(name != null); } - catch when (LogFailure(address, startingTimestamp)) + catch (Exception ex) when (LogFailure(address, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(address, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(address, startingTimestamp); // Do the forward lookup to get the IPs for that host name startingTimestamp = NameResolutionTelemetry.Log.BeforeResolution(name); @@ -464,13 +464,13 @@ private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAd AddressList = addresses }; } - catch when (LogFailure(name, startingTimestamp)) + catch (Exception ex) when (LogFailure(name, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(name, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(name, startingTimestamp); // One of three things happened: // 1. Success. @@ -577,7 +577,7 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR } private static Task? GetAddrInfoWithTelemetryAsync(string hostName, bool justAddresses, AddressFamily addressFamily, CancellationToken cancellationToken) - where T : class + where T : class { long startingTimestamp = Stopwatch.GetTimestamp(); Task? task = NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses, addressFamily, cancellationToken); @@ -594,15 +594,19 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR static async Task CompleteAsync(Task task, string hostName, long startingTimestamp) { _ = NameResolutionTelemetry.Log.BeforeResolution(hostName); - T? result = null; + Exception? exception = null; try { - result = await ((Task)task).ConfigureAwait(false); - return result; + return await ((Task)task).ConfigureAwait(false); + } + catch (Exception ex) + { + exception = ex; + throw; } finally { - NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp, successful: result is not null); + NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp, exception); } } } @@ -627,9 +631,9 @@ private static void ValidateHostName(string hostName) } } - private static bool LogFailure(object hostNameOrAddress, long? startingTimestamp) + private static bool LogFailure(object hostNameOrAddress, long? startingTimestamp, Exception exception) { - NameResolutionTelemetry.Log.AfterResolution(hostNameOrAddress, startingTimestamp, successful: false); + NameResolutionTelemetry.Log.AfterResolution(hostNameOrAddress, startingTimestamp, exception); return false; } diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs index 180f492b3408e2..fe1048e90b22de 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs @@ -13,15 +13,35 @@ internal static class NameResolutionMetrics private static readonly Meter s_meter = new("System.Net.NameResolution"); private static readonly Histogram s_lookupDuration = s_meter.CreateHistogram( - name: "dns.lookups.duration", + name: "dns.lookup.duration", unit: "s", description: "Measures the time taken to perform a DNS lookup."); public static bool IsEnabled() => s_lookupDuration.Enabled; - public static void AfterResolution(TimeSpan duration, string hostName) + public static void AfterResolution(TimeSpan duration, string hostName, Exception? exception) { - s_lookupDuration.Record(duration.TotalSeconds, KeyValuePair.Create("dns.question.name", (object?)hostName)); + var hostNameTag = KeyValuePair.Create("dns.question.name", (object?)hostName); + + if (exception is null) + { + s_lookupDuration.Record(duration.TotalSeconds, hostNameTag); + } + else + { + var errorTypeTag = KeyValuePair.Create("error.type", (object?)GetErrorType(exception)); + s_lookupDuration.Record(duration.TotalSeconds, hostNameTag, errorTypeTag); + } } + + private static string GetErrorType(Exception exception) => (exception as SocketException)?.SocketErrorCode switch + { + SocketError.HostNotFound => "host_not_found", + SocketError.TryAgain => "try_again", + SocketError.AddressFamilyNotSupported => "address_family_not_supported", + SocketError.NoRecovery => "no_recovery", + + _ => exception.GetType().FullName! + }; } } diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs index ef43b59d15a139..73ed325712ac52 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs @@ -81,7 +81,7 @@ public long BeforeResolution(object hostNameOrAddress) } [NonEvent] - public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, bool successful) + public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, Exception? exception = null) { Debug.Assert(startingTimestamp.HasValue); if (startingTimestamp == 0) @@ -99,7 +99,7 @@ public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, b if (IsEnabled(EventLevel.Informational, EventKeywords.None)) { - if (!successful) + if (exception is not null) { ResolutionFailed(); } @@ -110,7 +110,7 @@ public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, b if (NameResolutionMetrics.IsEnabled()) { - NameResolutionMetrics.AfterResolution(duration, GetHostnameFromStateObject(hostNameOrAddress)); + NameResolutionMetrics.AfterResolution(duration, GetHostnameFromStateObject(hostNameOrAddress), exception); } } diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs index d3c0990dbf9c4c..a19d9edc476abc 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs @@ -14,7 +14,7 @@ namespace System.Net.NameResolution.Tests { public class MetricsTest { - private const string DnsLookupDuration = "dns.lookups.duration"; + private const string DnsLookupDuration = "dns.lookup.duration"; [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] public static void ResolveValidHostName_MetricsRecorded() @@ -57,17 +57,26 @@ public static async Task ResolveInvalidHostName_MetricsRecorded() Assert.ThrowsAny(() => Dns.EndGetHostEntry(Dns.BeginGetHostEntry(InvalidHostName, null, null))); Assert.ThrowsAny(() => Dns.EndGetHostAddresses(Dns.BeginGetHostAddresses(InvalidHostName, null, null))); - double[] measurements = GetMeasurementsForHostname(recorder, InvalidHostName); + double[] measurements = GetMeasurementsForHostname(recorder, InvalidHostName, "host_not_found"); Assert.Equal(6, measurements.Length); Assert.All(measurements, m => Assert.True(m > double.Epsilon)); } - private static double[] GetMeasurementsForHostname(InstrumentRecorder recorder, string hostname) + private static double[] GetMeasurementsForHostname(InstrumentRecorder recorder, string hostname, string? expectedErrorType = null) { return recorder .GetMeasurements() - .Where(m => m.Tags.ToArray().Any(t => t.Key == "dns.question.name" && t.Value is string hostnameTag && hostnameTag == hostname)) + .Where(m => + { + KeyValuePair[] tags = m.Tags.ToArray(); + if (!tags.Any(t => t.Key == "dns.question.name" && t.Value is string hostnameTag && hostnameTag == hostname)) + { + return false; + } + string? actualErrorType = tags.FirstOrDefault(t => t.Key == "error.type").Value as string; + return expectedErrorType == actualErrorType; + }) .Select(m => m.Value) .ToArray(); } diff --git a/src/libraries/System.Net.Primitives/src/Resources/Strings.resx b/src/libraries/System.Net.Primitives/src/Resources/Strings.resx index 958a0e2e269f99..65d4809398b3b2 100644 --- a/src/libraries/System.Net.Primitives/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Primitives/src/Resources/Strings.resx @@ -64,7 +64,7 @@ This property is not implemented by this class. - The AddressFamily {0} is not valid for the {1} end point, use {2} instead. + The AddressFamily {0} is not valid for the {1} end point. The supplied {0} is an invalid size for the {1} end point. diff --git a/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs b/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs index 3531f266e6c504..ff47d2fbc515ef 100644 --- a/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs +++ b/src/libraries/System.Net.Primitives/src/System/Net/IPEndPoint.cs @@ -155,9 +155,9 @@ public override EndPoint Create(SocketAddress socketAddress) { ArgumentNullException.ThrowIfNull(socketAddress); - if (socketAddress.Family != AddressFamily) - { - throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, socketAddress.Family.ToString(), GetType().FullName, AddressFamily.ToString()), nameof(socketAddress)); + if (socketAddress.Family is not (AddressFamily.InterNetwork or AddressFamily.InterNetworkV6)) + { + throw new ArgumentException(SR.Format(SR.net_InvalidAddressFamily, socketAddress.Family.ToString(), GetType().FullName), nameof(socketAddress)); } int minSize = AddressFamily == AddressFamily.InterNetworkV6 ? SocketAddress.IPv6AddressSize : SocketAddress.IPv4AddressSize; diff --git a/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs b/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs index bb9b95d438e99f..c233dee628dfeb 100644 --- a/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs +++ b/src/libraries/System.Net.Primitives/tests/FunctionalTests/IPEndPointTest.cs @@ -143,6 +143,19 @@ public static void ToString_Invoke_ReturnsExpected(IPEndPoint endPoint, string e Assert.Equal(expected, endPoint.ToString()); } + [Fact] + public static void Create_DifferentAF_Success() + { + SocketAddress sa = new SocketAddress(AddressFamily.InterNetwork, SocketAddress.GetMaximumAddressSize(AddressFamily.InterNetworkV6)); + var ep = new IPEndPoint(IPAddress.IPv6Any, 0); + Assert.NotNull(ep.Create(sa)); + + sa = new SocketAddress(AddressFamily.InterNetworkV6); + ep = new IPEndPoint(IPAddress.Any, 0); + + Assert.NotNull(ep.Create(sa)); + } + public static IEnumerable Serialize_TestData() { yield return new object[] { new IPAddress(2), 16 }; @@ -195,8 +208,7 @@ public static void Create_NullSocketAddress_ThrowsArgumentNullException() public static IEnumerable Create_InvalidAddressFamily_TestData() { - yield return new object[] { new IPEndPoint(2, 500), new SocketAddress(Sockets.AddressFamily.InterNetworkV6) }; - yield return new object[] { new IPEndPoint(IPAddress.Parse("192.169.0.9"), 500), new SocketAddress(Sockets.AddressFamily.InterNetworkV6) }; + yield return new object[] { new IPEndPoint(2, 500), new SocketAddress(Sockets.AddressFamily.Unknown) }; yield return new object[] { new IPEndPoint(IPAddress.Parse("0:0:0:0:0:0:0:1"), 500), new SocketAddress(Sockets.AddressFamily.InterNetwork) }; } diff --git a/src/libraries/System.Net.Requests/src/Resources/Strings.resx b/src/libraries/System.Net.Requests/src/Resources/Strings.resx index f8ac54eb739cdd..05b94042590c8c 100644 --- a/src/libraries/System.Net.Requests/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Requests/src/Resources/Strings.resx @@ -195,6 +195,9 @@ The underlying connection was closed: An unexpected error occurred on a receive + + CRLF character pair is not allowed in FtpWebRequest inputs. + The remote name could not be resolved diff --git a/src/libraries/System.Net.Requests/src/System/Net/FtpControlStream.cs b/src/libraries/System.Net.Requests/src/System/Net/FtpControlStream.cs index 35eefb9c38cbba..ffba72c9315d28 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/FtpControlStream.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/FtpControlStream.cs @@ -1118,6 +1118,11 @@ private string GetPortCommandLine() /// private static string FormatFtpCommand(string command, string? parameter) { + if (parameter is not null && parameter.Contains("\r\n", StringComparison.Ordinal)) + { + throw new FormatException(SR.net_ftp_no_newlines); + } + return string.IsNullOrEmpty(parameter) ? command + "\r\n" : command + " " + parameter + "\r\n"; diff --git a/src/libraries/System.Net.Requests/src/System/Net/FtpWebRequest.cs b/src/libraries/System.Net.Requests/src/System/Net/FtpWebRequest.cs index 5a6009240fd921..b27873e9edb233 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/FtpWebRequest.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/FtpWebRequest.cs @@ -486,6 +486,9 @@ internal FtpWebRequest(Uri uri) if ((object)uri.Scheme != (object)Uri.UriSchemeFtp) throw new ArgumentOutOfRangeException(nameof(uri)); + if (uri.OriginalString.Contains("\r\n", StringComparison.Ordinal)) + throw new FormatException(SR.net_ftp_no_newlines); + _timerCallback = new TimerThread.Callback(TimerCallback); _syncObject = new object(); diff --git a/src/libraries/System.Net.Requests/tests/FtpWebRequestTest.cs b/src/libraries/System.Net.Requests/tests/FtpWebRequestTest.cs index d1f5e58b3492b5..c2b7ad12524003 100644 --- a/src/libraries/System.Net.Requests/tests/FtpWebRequestTest.cs +++ b/src/libraries/System.Net.Requests/tests/FtpWebRequestTest.cs @@ -203,6 +203,27 @@ public void Ftp_RenameFileSubDir_Success(FtpExecutionMode mode) Assert.False(DirExists(mode, dir)); } + [Fact] + public void Ftp_Ignore_NewLine_Constructor_Throws_FormatException() + { + string uri = absoluteUri + Guid.NewGuid().ToString(); + + Assert.Throws(() => WebRequest.Create($"{uri}\r\n{WebRequestMethods.Ftp.AppendFile} {Guid.NewGuid().ToString()}")); + } + + [ConditionalFact(nameof(LocalServerAvailable))] + public void Ftp_Ignore_NewLine_GetRequestStream_And_GetResponse_Throws_FormatException_As_InnerException() + { + FtpWebRequest ftpWebRequest = (FtpWebRequest)WebRequest.Create(absoluteUri + Guid.NewGuid().ToString()); + ftpWebRequest.Method = "APPE"; + ftpWebRequest.Credentials = new NetworkCredential("test\r\ntest2", "test\r\ntest2"); + var requestException = Assert.Throws(() => ftpWebRequest.GetRequestStream()); + Assert.True(requestException.InnerException is FormatException); + + var responseException = Assert.Throws(() => ftpWebRequest.GetResponse()); + Assert.True(responseException.InnerException is FormatException); + } + private static async Task DoAsync(FtpWebRequest request, MemoryStream requestBody) { if (requestBody != null) diff --git a/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Managed.cs b/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Managed.cs index 4e5e8906b795cc..fff331646b73c6 100644 --- a/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Managed.cs +++ b/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Managed.cs @@ -9,16 +9,23 @@ internal abstract partial class NegotiateAuthenticationPal { public static NegotiateAuthenticationPal Create(NegotiateAuthenticationClientOptions clientOptions) { - switch (clientOptions.Package) + try { - case NegotiationInfoClass.NTLM: - return new ManagedNtlmNegotiateAuthenticationPal(clientOptions); + switch (clientOptions.Package) + { + case NegotiationInfoClass.NTLM: + return new ManagedNtlmNegotiateAuthenticationPal(clientOptions); - case NegotiationInfoClass.Negotiate: - return new ManagedSpnegoNegotiateAuthenticationPal(clientOptions); + case NegotiationInfoClass.Negotiate: + return new ManagedSpnegoNegotiateAuthenticationPal(clientOptions); - default: - return new UnsupportedNegotiateAuthenticationPal(clientOptions); + default: + return new UnsupportedNegotiateAuthenticationPal(clientOptions); + } + } + catch (PlatformNotSupportedException) + { + return new UnsupportedNegotiateAuthenticationPal(clientOptions); } } diff --git a/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Unix.cs b/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Unix.cs index 900d66c05bfc7a..ed1fe4e2e91937 100644 --- a/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Unix.cs +++ b/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Unix.cs @@ -23,20 +23,20 @@ internal partial class NegotiateAuthenticationPal public static NegotiateAuthenticationPal Create(NegotiateAuthenticationClientOptions clientOptions) { - if (UseManagedNtlm) + try { - switch (clientOptions.Package) + if (UseManagedNtlm) { - case NegotiationInfoClass.NTLM: - return new ManagedNtlmNegotiateAuthenticationPal(clientOptions); + switch (clientOptions.Package) + { + case NegotiationInfoClass.NTLM: + return new ManagedNtlmNegotiateAuthenticationPal(clientOptions); - case NegotiationInfoClass.Negotiate: - return new ManagedSpnegoNegotiateAuthenticationPal(clientOptions, supportKerberos: true); + case NegotiationInfoClass.Negotiate: + return new ManagedSpnegoNegotiateAuthenticationPal(clientOptions, supportKerberos: true); + } } - } - try - { return new UnixNegotiateAuthenticationPal(clientOptions); } catch (Win32Exception) diff --git a/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Windows.cs b/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Windows.cs index 3dcb03bfd08f74..07e8dea22baa9c 100644 --- a/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Windows.cs +++ b/src/libraries/System.Net.Security/src/System/Net/NegotiateAuthenticationPal.Windows.cs @@ -421,28 +421,32 @@ public override unsafe NegotiateAuthenticationStatusCode Wrap(ReadOnlySpan Debug.Assert(success); // alloc new output buffer if not supplied or too small - int resultSize = input.Length + sizes.cbMaxSignature; + int resultSize = input.Length + sizes.cbSecurityTrailer + sizes.cbBlockSize; Span outputBuffer = outputWriter.GetSpan(resultSize); // make a copy of user data for in-place encryption - input.CopyTo(outputBuffer.Slice(sizes.cbMaxSignature, input.Length)); + input.CopyTo(outputBuffer.Slice(sizes.cbSecurityTrailer, input.Length)); isEncrypted = requestEncryption; fixed (byte* outputPtr = outputBuffer) { // Prepare buffers TOKEN(signature), DATA and Padding. - Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[2]; + Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[3]; Interop.SspiCli.SecBuffer* tokenBuffer = &unmanagedBuffer[0]; Interop.SspiCli.SecBuffer* dataBuffer = &unmanagedBuffer[1]; + Interop.SspiCli.SecBuffer* paddingBuffer = &unmanagedBuffer[2]; tokenBuffer->BufferType = SecurityBufferType.SECBUFFER_TOKEN; tokenBuffer->pvBuffer = (IntPtr)(outputPtr); - tokenBuffer->cbBuffer = sizes.cbMaxSignature; + tokenBuffer->cbBuffer = sizes.cbSecurityTrailer; dataBuffer->BufferType = SecurityBufferType.SECBUFFER_DATA; - dataBuffer->pvBuffer = (IntPtr)(outputPtr + sizes.cbMaxSignature); + dataBuffer->pvBuffer = (IntPtr)(outputPtr + sizes.cbSecurityTrailer); dataBuffer->cbBuffer = input.Length; + paddingBuffer->BufferType = SecurityBufferType.SECBUFFER_PADDING; + paddingBuffer->pvBuffer = (IntPtr)(outputPtr + sizes.cbSecurityTrailer + input.Length); + paddingBuffer->cbBuffer = sizes.cbBlockSize; - Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(2) + Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(3) { pBuffers = unmanagedBuffer }; @@ -460,7 +464,20 @@ public override unsafe NegotiateAuthenticationStatusCode Wrap(ReadOnlySpan }; } - outputWriter.Advance(tokenBuffer->cbBuffer + dataBuffer->cbBuffer); + // Compact the result + if (tokenBuffer->cbBuffer != sizes.cbSecurityTrailer) + { + outputBuffer.Slice(sizes.cbSecurityTrailer, dataBuffer->cbBuffer).CopyTo( + outputBuffer.Slice(tokenBuffer->cbBuffer, dataBuffer->cbBuffer)); + } + if (tokenBuffer->cbBuffer != sizes.cbSecurityTrailer || + paddingBuffer->cbBuffer != sizes.cbBlockSize) + { + outputBuffer.Slice(sizes.cbSecurityTrailer + input.Length, paddingBuffer->cbBuffer).CopyTo( + outputBuffer.Slice(tokenBuffer->cbBuffer + dataBuffer->cbBuffer, paddingBuffer->cbBuffer)); + } + + outputWriter.Advance(tokenBuffer->cbBuffer + dataBuffer->cbBuffer + paddingBuffer->cbBuffer); return NegotiateAuthenticationStatusCode.Completed; } } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index e04739d5fe7a6e..e94d862571a0f8 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -927,13 +927,13 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags { try { - if (_remoteEndPoint!.AddressFamily == _socketAddress!.Family) + if (_remoteEndPoint!.AddressFamily == AddressFamily.InterNetworkV6 && _socketAddress!.Family == AddressFamily.InterNetwork) { - _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress); + _remoteEndPoint = new IPEndPoint(_socketAddress.GetIPAddress().MapToIPv6(), _socketAddress.GetPort()); } - else if (_remoteEndPoint!.AddressFamily == AddressFamily.InterNetworkV6 && _socketAddress.Family == AddressFamily.InterNetwork) + else { - _remoteEndPoint = new IPEndPoint(_socketAddress.GetIPAddress().MapToIPv6(), _socketAddress.GetPort()); + _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress!); } } catch @@ -949,7 +949,14 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags { try { - _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress!); + if (_remoteEndPoint!.AddressFamily == AddressFamily.InterNetworkV6 && _socketAddress!.Family == AddressFamily.InterNetwork) + { + _remoteEndPoint = new IPEndPoint(_socketAddress.GetIPAddress().MapToIPv6(), _socketAddress.GetPort()); + } + else + { + _remoteEndPoint = _remoteEndPoint!.Create(_socketAddress!); + } } catch { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs index 1a5ec7d05d28e9..1ec2adeadcf517 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveFrom.cs @@ -168,6 +168,52 @@ public async Task ReceiveSent_UDP_Success(bool ipv4) } } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ReceiveSent_DualMode_Success(bool ipv4) + { + const int Offset = 10; + const int DatagramSize = 256; + const int DatagramsToSend = 16; + + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + using Socket receiver = new Socket(SocketType.Dgram, ProtocolType.Udp); + using Socket sender = new Socket(SocketType.Dgram, ProtocolType.Udp); + if (receiver.DualMode != true || sender.DualMode != true) + { + throw new SkipException("DualMode not available"); + } + + ConfigureNonBlocking(sender); + ConfigureNonBlocking(receiver); + + receiver.BindToAnonymousPort(address); + sender.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[DatagramSize]; + var receiveInternalBuffer = new byte[DatagramSize + Offset]; + var emptyBuffer = new byte[Offset]; + ArraySegment receiveBuffer = new ArraySegment(receiveInternalBuffer, Offset, DatagramSize); + + Random rnd = new Random(0); + + for (int i = 0; i < DatagramsToSend; i++) + { + rnd.NextBytes(sendBuffer); + sender.SendTo(sendBuffer, receiver.LocalEndPoint); + + IPEndPoint remoteEp = new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0); + + SocketReceiveFromResult result = await ReceiveFromAsync(receiver, receiveBuffer, remoteEp); + + Assert.Equal(DatagramSize, result.ReceivedBytes); + AssertExtensions.SequenceEqual(emptyBuffer, new ReadOnlySpan(receiveInternalBuffer, 0, Offset)); + AssertExtensions.SequenceEqual(sendBuffer, new ReadOnlySpan(receiveInternalBuffer, Offset, DatagramSize)); + Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint); + } + } + [Theory] [InlineData(false)] [InlineData(true)] diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs index d5fbf0b1c83981..32a7bdcfbb6acf 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs @@ -1019,11 +1019,12 @@ public async Task TcpReceiveSendGetsCanceledByDispose(bool receiveOrSend, bool i return; } - // RHEL7 kernel has a bug preventing close(AF_UNKNOWN) to succeed with IPv6 sockets. - // In this case Dispose will trigger a graceful shutdown, which means that receive will succeed on socket2. - // This bug is fixed in kernel 3.10.0-1160.25+. - // TODO: Remove this, once CI machines are updated to a newer kernel. - bool mayShutdownGraceful = UsesSync && PlatformDetection.IsRedHatFamily7 && receiveOrSend && (ipv6Server || dualModeClient); + // .NET uses connect(AF_UNSPEC) to abort on-going operations on Linux. + // Linux 6.4+ introduced a change (4faeee0cf8a5d88d63cdbc3bab124fb0e6aed08c) which disallows + // this operation while operations are on-going. + // When the connect fails, .NET falls back to use shutdown(SHUT_RDWR). + // This causes the receive on socket2 to succeed instead of failing with ConnectionReset. + bool mayShutdownGraceful = UsesSync && PlatformDetection.IsLinux && receiveOrSend; // We try this a couple of times to deal with a timing race: if the Dispose happens // before the operation is started, the peer won't see a ConnectionReset SocketException and we won't diff --git a/src/libraries/System.Numerics.Tensors/Directory.Build.props b/src/libraries/System.Numerics.Tensors/Directory.Build.props deleted file mode 100644 index 36078bccbf7aa8..00000000000000 --- a/src/libraries/System.Numerics.Tensors/Directory.Build.props +++ /dev/null @@ -1,8 +0,0 @@ - - - - - true - false - - \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/README.md b/src/libraries/System.Numerics.Tensors/README.md index 026563de81306d..6190da60e77c51 100644 --- a/src/libraries/System.Numerics.Tensors/README.md +++ b/src/libraries/System.Numerics.Tensors/README.md @@ -1,2 +1,3 @@ # System.Numerics.Tensors -This library has not been shipped publicly and is not accepting contributions at this time. \ No newline at end of file + +Provides APIs for performing primitive operations over tensors represented by spans of memory. diff --git a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln index cc3000d60ef88f..015b65250931ac 100644 --- a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln +++ b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln @@ -1,18 +1,34 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.8.34205.153 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{9F20CEA1-2216-4432-BBBD-F01E05D17F23}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\ref\Microsoft.Bcl.Numerics.csproj", "{D311ABE4-10A9-4BB1-89CE-6358C55501A8}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\src\Microsoft.Bcl.Numerics.csproj", "{1578185F-C4FA-4866-936B-E62AAEDD03B7}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors", "ref\System.Numerics.Tensors.csproj", "{21CB448A-3882-4337-B416-D1A3E0BCFFC5}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors", "src\System.Numerics.Tensors.csproj", "{848DD000-3D22-4A25-A9D9-05AFF857A116}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors.Tests", "tests\System.Numerics.Tensors.Tests.csproj", "{4AF6A02D-82C8-4898-9EDF-01F107C25061}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ComInterfaceGenerator", "..\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj", "{8CA7C982-3EE4-4BCE-9493-7A63556736D3}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LibraryImportGenerator", "..\System.Runtime.InteropServices\gen\LibraryImportGenerator\LibraryImportGenerator.csproj", "{4588351F-4233-4957-B84C-7F8E22B8888A}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Interop.SourceGeneration", "..\System.Runtime.InteropServices\gen\Microsoft.Interop.SourceGeneration\Microsoft.Interop.SourceGeneration.csproj", "{DB954E01-898A-4FE2-A3AA-180D041AB08F}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.CodeFixProvider", "..\..\tools\illink\src\ILLink.CodeFix\ILLink.CodeFixProvider.csproj", "{04FC0651-B9D0-448A-A28B-11B1D4A897F4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.RoslynAnalyzer", "..\..\tools\illink\src\ILLink.RoslynAnalyzer\ILLink.RoslynAnalyzer.csproj", "{683A7D28-CC55-4375-848D-E659075ECEE4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.Tasks", "..\..\tools\illink\src\ILLink.Tasks\ILLink.Tasks.csproj", "{1CBEAEA8-2CA1-4B07-9930-35A785205852}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\Mono.Linker.csproj", "{BA7828B1-7953-47A0-AE5A-E22B501C4BD0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\ref\Mono.Linker.csproj", "{57E57290-3A6A-43F8-8764-D4DC8151F89C}" +EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{DE94CA7D-BB10-4865-85A6-6B694631247F}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{6BC42E6D-848C-4533-B715-F116E7DB3610}" @@ -21,6 +37,14 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AB415F5A-75E EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{083161E5-6049-4D84-9739-9D7797D7117D}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -31,6 +55,14 @@ Global {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Debug|Any CPU.Build.0 = Debug|Any CPU {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Release|Any CPU.ActiveCfg = Release|Any CPU {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Release|Any CPU.Build.0 = Release|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Release|Any CPU.Build.0 = Release|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Release|Any CPU.Build.0 = Release|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Debug|Any CPU.Build.0 = Debug|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -43,10 +75,6 @@ Global {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Debug|Any CPU.Build.0 = Debug|Any CPU {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Release|Any CPU.ActiveCfg = Release|Any CPU {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Release|Any CPU.Build.0 = Release|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Debug|Any CPU.Build.0 = Debug|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Release|Any CPU.Build.0 = Release|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Debug|Any CPU.Build.0 = Debug|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -55,20 +83,53 @@ Global {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Debug|Any CPU.Build.0 = Debug|Any CPU {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Release|Any CPU.ActiveCfg = Release|Any CPU {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Release|Any CPU.Build.0 = Release|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Release|Any CPU.Build.0 = Release|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Release|Any CPU.Build.0 = Release|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Release|Any CPU.Build.0 = Release|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Release|Any CPU.Build.0 = Release|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution {9F20CEA1-2216-4432-BBBD-F01E05D17F23} = {DE94CA7D-BB10-4865-85A6-6B694631247F} - {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} + {D311ABE4-10A9-4BB1-89CE-6358C55501A8} = {6BC42E6D-848C-4533-B715-F116E7DB3610} + {1578185F-C4FA-4866-936B-E62AAEDD03B7} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {6BC42E6D-848C-4533-B715-F116E7DB3610} {848DD000-3D22-4A25-A9D9-05AFF857A116} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} - {8CA7C982-3EE4-4BCE-9493-7A63556736D3} = {083161E5-6049-4D84-9739-9D7797D7117D} + {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {4588351F-4233-4957-B84C-7F8E22B8888A} = {083161E5-6049-4D84-9739-9D7797D7117D} {DB954E01-898A-4FE2-A3AA-180D041AB08F} = {083161E5-6049-4D84-9739-9D7797D7117D} + {04FC0651-B9D0-448A-A28B-11B1D4A897F4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} + {683A7D28-CC55-4375-848D-E659075ECEE4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} + {1CBEAEA8-2CA1-4B07-9930-35A785205852} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} + {57E57290-3A6A-43F8-8764-D4DC8151F89C} = {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} + {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {10A5F2C3-5230-4916-9D4D-BBDB94851037} EndGlobalSection -EndGlobal + GlobalSection(SharedMSBuildProjectFiles) = preSolution + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{683a7d28-cc55-4375-848d-e659075ecee4}*SharedItemsImports = 5 + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{ba7828b1-7953-47a0-ae5a-e22b501c4bd0}*SharedItemsImports = 5 + EndGlobalSection +EndGlobal \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs index 3161a4c7e780ce..99bd4703574e55 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs @@ -6,151 +6,53 @@ namespace System.Numerics.Tensors { - public static partial class ArrayTensorExtensions + public static partial class TensorPrimitives { - public static System.Numerics.Tensors.CompressedSparseTensor ToCompressedSparseTensor(this System.Array array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.CompressedSparseTensor ToCompressedSparseTensor(this T[,,] array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.CompressedSparseTensor ToCompressedSparseTensor(this T[,] array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.CompressedSparseTensor ToCompressedSparseTensor(this T[] array) { throw null; } - public static System.Numerics.Tensors.SparseTensor ToSparseTensor(this System.Array array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.SparseTensor ToSparseTensor(this T[,,] array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.SparseTensor ToSparseTensor(this T[,] array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.SparseTensor ToSparseTensor(this T[] array) { throw null; } - public static System.Numerics.Tensors.DenseTensor ToTensor(this System.Array array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.DenseTensor ToTensor(this T[,,] array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.DenseTensor ToTensor(this T[,] array, bool reverseStride = false) { throw null; } - public static System.Numerics.Tensors.DenseTensor ToTensor(this T[] array) { throw null; } - } - public partial class CompressedSparseTensor : System.Numerics.Tensors.Tensor - { - public CompressedSparseTensor(System.Memory values, System.Memory compressedCounts, System.Memory indices, int nonZeroCount, System.ReadOnlySpan dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { } - public CompressedSparseTensor(System.ReadOnlySpan dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { } - public CompressedSparseTensor(System.ReadOnlySpan dimensions, int capacity, bool reverseStride = false) : base (default(System.Array), default(bool)) { } - public int Capacity { get { throw null; } } - public System.Memory CompressedCounts { get { throw null; } } - public System.Memory Indices { get { throw null; } } - public override T this[System.ReadOnlySpan indices] { get { throw null; } set { } } - public int NonZeroCount { get { throw null; } } - public System.Memory Values { get { throw null; } } - public override System.Numerics.Tensors.Tensor Clone() { throw null; } - public override System.Numerics.Tensors.Tensor CloneEmpty(System.ReadOnlySpan dimensions) { throw null; } - public override T GetValue(int index) { throw null; } - public override System.Numerics.Tensors.Tensor Reshape(System.ReadOnlySpan dimensions) { throw null; } - public override void SetValue(int index, T value) { } - public override System.Numerics.Tensors.CompressedSparseTensor ToCompressedSparseTensor() { throw null; } - public override System.Numerics.Tensors.DenseTensor ToDenseTensor() { throw null; } - public override System.Numerics.Tensors.SparseTensor ToSparseTensor() { throw null; } - } - public partial class DenseTensor : System.Numerics.Tensors.Tensor - { - public DenseTensor(int length) : base (default(System.Array), default(bool)) { } - public DenseTensor(System.Memory memory, System.ReadOnlySpan dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { } - public DenseTensor(System.ReadOnlySpan dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { } - public System.Memory Buffer { get { throw null; } } - public override System.Numerics.Tensors.Tensor Clone() { throw null; } - public override System.Numerics.Tensors.Tensor CloneEmpty(System.ReadOnlySpan dimensions) { throw null; } - protected override void CopyTo(T[] array, int arrayIndex) { } - public override T GetValue(int index) { throw null; } - protected override int IndexOf(T item) { throw null; } - public override System.Numerics.Tensors.Tensor Reshape(System.ReadOnlySpan dimensions) { throw null; } - public override void SetValue(int index, T value) { } - } - public partial class SparseTensor : System.Numerics.Tensors.Tensor - { - public SparseTensor(System.ReadOnlySpan dimensions, bool reverseStride = false, int capacity = 0) : base (default(System.Array), default(bool)) { } - public int NonZeroCount { get { throw null; } } - public override System.Numerics.Tensors.Tensor Clone() { throw null; } - public override System.Numerics.Tensors.Tensor CloneEmpty(System.ReadOnlySpan dimensions) { throw null; } - public override T GetValue(int index) { throw null; } - public override System.Numerics.Tensors.Tensor Reshape(System.ReadOnlySpan dimensions) { throw null; } - public override void SetValue(int index, T value) { } - public override System.Numerics.Tensors.CompressedSparseTensor ToCompressedSparseTensor() { throw null; } - public override System.Numerics.Tensors.DenseTensor ToDenseTensor() { throw null; } - public override System.Numerics.Tensors.SparseTensor ToSparseTensor() { throw null; } - } - public static partial class Tensor - { - public static System.Numerics.Tensors.Tensor CreateFromDiagonal(System.Numerics.Tensors.Tensor diagonal) { throw null; } - public static System.Numerics.Tensors.Tensor CreateFromDiagonal(System.Numerics.Tensors.Tensor diagonal, int offset) { throw null; } - public static System.Numerics.Tensors.Tensor CreateIdentity(int size) { throw null; } - public static System.Numerics.Tensors.Tensor CreateIdentity(int size, bool columMajor) { throw null; } - public static System.Numerics.Tensors.Tensor CreateIdentity(int size, bool columMajor, T oneValue) { throw null; } - } - public abstract partial class Tensor : System.Collections.Generic.ICollection, System.Collections.Generic.IEnumerable, System.Collections.Generic.IList, System.Collections.Generic.IReadOnlyCollection, System.Collections.Generic.IReadOnlyList, System.Collections.ICollection, System.Collections.IEnumerable, System.Collections.IList, System.Collections.IStructuralComparable, System.Collections.IStructuralEquatable - { - protected Tensor(System.Array fromArray, bool reverseStride) { } - protected Tensor(int length) { } - protected Tensor(System.ReadOnlySpan dimensions, bool reverseStride) { } - public System.ReadOnlySpan Dimensions { get { throw null; } } - public bool IsFixedSize { get { throw null; } } - public bool IsReadOnly { get { throw null; } } - public bool IsReversedStride { get { throw null; } } - public virtual T this[params int[] indices] { get { throw null; } set { } } - public virtual T this[System.ReadOnlySpan indices] { get { throw null; } set { } } - public long Length { get { throw null; } } - public int Rank { get { throw null; } } - public System.ReadOnlySpan Strides { get { throw null; } } - int System.Collections.Generic.ICollection.Count { get { throw null; } } - T System.Collections.Generic.IList.this[int index] { get { throw null; } set { } } - int System.Collections.Generic.IReadOnlyCollection.Count { get { throw null; } } - T System.Collections.Generic.IReadOnlyList.this[int index] { get { throw null; } } - int System.Collections.ICollection.Count { get { throw null; } } - bool System.Collections.ICollection.IsSynchronized { get { throw null; } } - object System.Collections.ICollection.SyncRoot { get { throw null; } } - object? System.Collections.IList.this[int index] { get { throw null; } set { } } - public abstract System.Numerics.Tensors.Tensor Clone(); - public virtual System.Numerics.Tensors.Tensor CloneEmpty() { throw null; } - public virtual System.Numerics.Tensors.Tensor CloneEmpty(System.ReadOnlySpan dimensions) { throw null; } - public virtual System.Numerics.Tensors.Tensor CloneEmpty() { throw null; } - public abstract System.Numerics.Tensors.Tensor CloneEmpty(System.ReadOnlySpan dimensions); - public static int Compare(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor right) { throw null; } - protected virtual bool Contains(T item) { throw null; } - protected virtual void CopyTo(T[] array, int arrayIndex) { } - public static bool Equals(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor right) { throw null; } - public virtual void Fill(T value) { } - public string GetArrayString(bool includeWhitespace = true) { throw null; } - public System.Numerics.Tensors.Tensor GetDiagonal() { throw null; } - public System.Numerics.Tensors.Tensor GetDiagonal(int offset) { throw null; } - public System.Numerics.Tensors.Tensor GetTriangle() { throw null; } - public System.Numerics.Tensors.Tensor GetTriangle(int offset) { throw null; } - public System.Numerics.Tensors.Tensor GetUpperTriangle() { throw null; } - public System.Numerics.Tensors.Tensor GetUpperTriangle(int offset) { throw null; } - public abstract T GetValue(int index); - protected virtual int IndexOf(T item) { throw null; } - public abstract System.Numerics.Tensors.Tensor Reshape(System.ReadOnlySpan dimensions); - public abstract void SetValue(int index, T value); - public struct Enumerator : System.Collections.Generic.IEnumerator - { - public T Current { get; private set; } - object? System.Collections.IEnumerator.Current => throw null; - public bool MoveNext() => throw null; - public void Reset() { } - public void Dispose() { } - } - public Enumerator GetEnumerator() => throw null; - void System.Collections.Generic.ICollection.Add(T item) { } - void System.Collections.Generic.ICollection.Clear() { } - bool System.Collections.Generic.ICollection.Contains(T item) { throw null; } - void System.Collections.Generic.ICollection.CopyTo(T[] array, int arrayIndex) { } - bool System.Collections.Generic.ICollection.Remove(T item) { throw null; } - System.Collections.Generic.IEnumerator System.Collections.Generic.IEnumerable.GetEnumerator() { throw null; } - int System.Collections.Generic.IList.IndexOf(T item) { throw null; } - void System.Collections.Generic.IList.Insert(int index, T item) { } - void System.Collections.Generic.IList.RemoveAt(int index) { } - void System.Collections.ICollection.CopyTo(System.Array array, int index) { } - System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; } - int System.Collections.IList.Add(object? value) { throw null; } - void System.Collections.IList.Clear() { } - bool System.Collections.IList.Contains(object? value) { throw null; } - int System.Collections.IList.IndexOf(object? value) { throw null; } - void System.Collections.IList.Insert(int index, object? value) { } - void System.Collections.IList.Remove(object? value) { } - void System.Collections.IList.RemoveAt(int index) { } - int System.Collections.IStructuralComparable.CompareTo(object? other, System.Collections.IComparer comparer) { throw null; } - bool System.Collections.IStructuralEquatable.Equals(object? other, System.Collections.IEqualityComparer comparer) { throw null; } - int System.Collections.IStructuralEquatable.GetHashCode(System.Collections.IEqualityComparer comparer) { throw null; } - public virtual System.Numerics.Tensors.CompressedSparseTensor ToCompressedSparseTensor() { throw null; } - public virtual System.Numerics.Tensors.DenseTensor ToDenseTensor() { throw null; } - public virtual System.Numerics.Tensors.SparseTensor ToSparseTensor() { throw null; } + public static void Abs(System.ReadOnlySpan x, System.Span destination) { } + public static void Add(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { } + public static void Add(System.ReadOnlySpan x, float y, System.Span destination) { } + public static void AddMultiply(System.ReadOnlySpan x, System.ReadOnlySpan y, System.ReadOnlySpan multiplier, System.Span destination) { } + public static void AddMultiply(System.ReadOnlySpan x, System.ReadOnlySpan y, float multiplier, System.Span destination) { } + public static void AddMultiply(System.ReadOnlySpan x, float y, System.ReadOnlySpan multiplier, System.Span destination) { } + public static void Cosh(System.ReadOnlySpan x, System.Span destination) { } + public static float CosineSimilarity(System.ReadOnlySpan x, System.ReadOnlySpan y) { throw null; } + public static float Distance(System.ReadOnlySpan x, System.ReadOnlySpan y) { throw null; } + public static void Divide(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { } + public static void Divide(System.ReadOnlySpan x, float y, System.Span destination) { } + public static float Dot(System.ReadOnlySpan x, System.ReadOnlySpan y) { throw null; } + public static void Exp(System.ReadOnlySpan x, System.Span destination) { } + public static int IndexOfMax(System.ReadOnlySpan x) { throw null; } + public static int IndexOfMaxMagnitude(System.ReadOnlySpan x) { throw null; } + public static int IndexOfMin(System.ReadOnlySpan x) { throw null; } + public static int IndexOfMinMagnitude(System.ReadOnlySpan x) { throw null; } + public static float Norm(System.ReadOnlySpan x) { throw null; } + public static void Log(System.ReadOnlySpan x, System.Span destination) { } + public static void Log2(System.ReadOnlySpan x, System.Span destination) { } + public static float Max(System.ReadOnlySpan x) { throw null; } + public static void Max(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } + public static float MaxMagnitude(System.ReadOnlySpan x) { throw null; } + public static void MaxMagnitude(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } + public static float Min(System.ReadOnlySpan x) { throw null; } + public static void Min(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } + public static float MinMagnitude(System.ReadOnlySpan x) { throw null; } + public static void MinMagnitude(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } + public static void Multiply(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { } + public static void Multiply(System.ReadOnlySpan x, float y, System.Span destination) { } + public static void MultiplyAdd(System.ReadOnlySpan x, System.ReadOnlySpan y, System.ReadOnlySpan addend, System.Span destination) { } + public static void MultiplyAdd(System.ReadOnlySpan x, System.ReadOnlySpan y, float addend, System.Span destination) { } + public static void MultiplyAdd(System.ReadOnlySpan x, float y, System.ReadOnlySpan addend, System.Span destination) { } + public static void Negate(System.ReadOnlySpan x, System.Span destination) { } + public static float Product(System.ReadOnlySpan x) { throw null; } + public static float ProductOfDifferences(System.ReadOnlySpan x, System.ReadOnlySpan y) { throw null; } + public static float ProductOfSums(System.ReadOnlySpan x, System.ReadOnlySpan y) { throw null; } + public static void Sigmoid(System.ReadOnlySpan x, System.Span destination) { } + public static void Sinh(System.ReadOnlySpan x, System.Span destination) { } + public static void SoftMax(System.ReadOnlySpan x, System.Span destination) { } + public static void Subtract(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { } + public static void Subtract(System.ReadOnlySpan x, float y, System.Span destination) { } + public static float Sum(System.ReadOnlySpan x) { throw null; } + public static float SumOfMagnitudes(System.ReadOnlySpan x) { throw null; } + public static float SumOfSquares(System.ReadOnlySpan x) { throw null; } + public static void Tanh(System.ReadOnlySpan x, System.Span destination) { } } } diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.csproj index cabfe50e267cfb..8d28b0e077d9cd 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.csproj @@ -1,4 +1,5 @@ + $(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum) @@ -7,7 +8,12 @@ + + + + + \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs new file mode 100644 index 00000000000000..1cde4351546b26 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// ------------------------------------------------------------------------------ +// Changes to this file must follow the https://aka.ms/api-review process. +// ------------------------------------------------------------------------------ + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + public static void ConvertToHalf(System.ReadOnlySpan source, System.Span destination) { throw null; } + public static void ConvertToSingle(System.ReadOnlySpan source, System.Span destination) { throw null; } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/PACKAGE.md b/src/libraries/System.Numerics.Tensors/src/PACKAGE.md new file mode 100644 index 00000000000000..c5670c1c0f9893 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/PACKAGE.md @@ -0,0 +1,53 @@ +## About + +Provides methods for performing mathematical operations over _tensors_ represented as spans. These methods are accelerated to use SIMD (Single instruction, multiple data) operations supported by the CPU where available. + +## Key Features + +* Numerical operations on tensors represented as `ReadOnlySpan` +* Element-wise arithmetic: Add, Subtract, Multiply, Divide, Exp, Log, Cosh, Tanh, etc. +* Tensor arithmetic: CosineSimilarity, Distance, Dot, Normalize, Softmax, Sigmoid, etc. + +## How to Use + +```C# +using System.Numerics.Tensors; + +var movies = new[] { + new { Title="The Lion King", Embedding= new [] { 0.10022575f, -0.23998135f } }, + new { Title="Inception", Embedding= new [] { 0.10327095f, 0.2563685f } }, + new { Title="Toy Story", Embedding= new [] { 0.095857024f, -0.201278f } }, + new { Title="Pulp Function", Embedding= new [] { 0.106827796f, 0.21676421f } }, + new { Title="Shrek", Embedding= new [] { 0.09568083f, -0.21177962f } } +}; +var queryEmbedding = new[] { 0.12217915f, -0.034832448f }; + +var top3MoviesTensorPrimitives = + movies + .Select(movie => + ( + movie.Title, + Similarity: TensorPrimitives.CosineSimilarity(queryEmbedding, movie.Embedding) + )) + .OrderByDescending(movies => movies.Similarity) + .Take(3); + +foreach (var movie in top3MoviesTensorPrimitives) +{ + Console.WriteLine(movie); +} +``` + +## Main Types + +The main types provided by this library are: + +* `System.Numerics.Tensors.TensorPrimitives` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/en-us/dotnet/api/system.numerics.tensors) + +## Feedback & Contributing + +System.Numerics.Tensors is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Numerics.Tensors/src/Properties/InternalsVisibleTo.cs b/src/libraries/System.Numerics.Tensors/src/Properties/InternalsVisibleTo.cs deleted file mode 100644 index 2b044e474d570a..00000000000000 --- a/src/libraries/System.Numerics.Tensors/src/Properties/InternalsVisibleTo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Runtime.CompilerServices; - -[assembly: InternalsVisibleTo("System.Numerics.Tensors.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] diff --git a/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt new file mode 100644 index 00000000000000..a8f2d0192cfec9 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt @@ -0,0 +1,2 @@ +M:System.Numerics.Tensors.TensorPrimitives.ConvertToHalf(System.ReadOnlySpan{System.Single},System.Span{System.Half}) +M:System.Numerics.Tensors.TensorPrimitives.ConvertToSingle(System.ReadOnlySpan{System.Half},System.Span{System.Single}) \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx index 57f792dde51dc2..86b9f4d82b1f61 100644 --- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx +++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx @@ -117,52 +117,16 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - - Array must contain elements. + + Destination is too short. - - Cannot compare {0} to {1}. + + Input span arguments must not be empty. - - Cannot compare {0} to {1} with different dimension {2}, {3} != {4}. + + Input span arguments must all have the same length. - - Cannot compare {0} with different dimension {1}, {2} != {3}. + + The destination span may only overlap with an input span if the two spans start at the same memory location. - - Cannot compare {0} with Rank {1} to {2} with Rank {3}. - - - Cannot compute diagonal of {0} with Rank less than 2. - - - Cannot compute diagonal with offset {0}. - - - Tensor {0} must have at least one dimension. - - - Cannot compute triangle of {0} with Rank less than 2. - - - Cannot reshape array due to mismatch in lengths, currently {0} would become {1}. - - - Dimensions must be positive and non-zero. - - - Dimensions must contain elements. - - - Length of {0} ({1}) must match product of {2} ({3}). - - - The number of elements in the Tensor is greater than the available space from index to the end of the destination array. - - - Only single dimensional arrays are supported for the requested action. - - - The value "{0}" is not of type "{1}" and cannot be used in this generic collection. - - \ No newline at end of file + diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index 521f8055a4fa61..52c6cb65811e68 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -1,30 +1,37 @@ + $(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum) true - - $(NoWarn);1591 true - Tensor class which represents and extends multi-dimensional arrays. - -Commonly Used Types: -System.Numerics.Tensors.Tensor<T> -System.Numerics.Tensors.CompressedSparseTensor<T> -System.Numerics.Tensors.DenseTensor<T> -System.Numerics.Tensors.SparseTensor<T> + Provides support for operating over tensors. + + true + ReferenceAssemblyExclusions.txt - - - - - - - + + + + + + + + + + + + + + + + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayTensorExtensions.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayTensorExtensions.cs deleted file mode 100644 index 05e12a62c990e4..00000000000000 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayTensorExtensions.cs +++ /dev/null @@ -1,149 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Numerics.Tensors -{ - public static class ArrayTensorExtensions - { - /// - /// Creates a copy of this single-dimensional array as a DenseTensor<T> - /// - /// Type contained in the array to copy to the DenseTensor<T>. - /// The array to create a DenseTensor<T> from. - /// A 1-dimensional DenseTensor<T> with the same length and content as . - public static DenseTensor ToTensor(this T[] array) - { - return new DenseTensor(array); - } - - /// - /// Creates a copy of this two-dimensional array as a DenseTensor<T> - /// - /// Type contained in the array to copy to the DenseTensor<T>. - /// The array to create a DenseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): row-major. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): column-major. - /// A 2-dimensional DenseTensor<T> with the same dimensions and content as . - public static DenseTensor ToTensor(this T[,] array, bool reverseStride = false) - { - return new DenseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this three-dimensional array as a DenseTensor<T> - /// - /// Type contained in the array to copy to the DenseTensor<T>. - /// The array to create a DenseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - /// A 3-dimensional DenseTensor<T> with the same dimensions and content as . - public static DenseTensor ToTensor(this T[,,] array, bool reverseStride = false) - { - return new DenseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this n-dimensional array as a DenseTensor<T> - /// - /// Type contained in the array to copy to the DenseTensor<T>. - /// The array to create a DenseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - /// A n-dimensional DenseTensor<T> with the same dimensions and content as . - public static DenseTensor ToTensor(this Array array, bool reverseStride = false) - { - return new DenseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this single-dimensional array as a SparseTensor<T> - /// - /// Type contained in the array to copy to the SparseTensor<T>. - /// The array to create a SparseTensor<T> from. - /// A 1-dimensional SparseTensor<T> with the same length and content as . - public static SparseTensor ToSparseTensor(this T[] array) - { - return new SparseTensor(array); - } - - /// - /// Creates a copy of this two-dimensional array as a SparseTensor<T> - /// - /// Type contained in the array to copy to the SparseTensor<T>. - /// The array to create a SparseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): row-major. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): column-major. - /// A 2-dimensional SparseTensor<T> with the same dimensions and content as . - public static SparseTensor ToSparseTensor(this T[,] array, bool reverseStride = false) - { - return new SparseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this three-dimensional array as a SparseTensor<T> - /// - /// Type contained in the array to copy to the SparseTensor<T>. - /// The array to create a SparseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - /// A 3-dimensional SparseTensor<T> with the same dimensions and content as . - public static SparseTensor ToSparseTensor(this T[,,] array, bool reverseStride = false) - { - return new SparseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this n-dimensional array as a SparseTensor<T> - /// - /// Type contained in the array to copy to the SparseTensor<T>. - /// The array to create a SparseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - /// A n-dimensional SparseTensor<T> with the same dimensions and content as . - public static SparseTensor ToSparseTensor(this Array array, bool reverseStride = false) - { - return new SparseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this single-dimensional array as a CompressedSparseTensor<T> - /// - /// Type contained in the array to copy to the CompressedSparseTensor<T>. - /// The array to create a CompressedSparseTensor<T> from. - /// A 1-dimensional CompressedSparseTensor<T> with the same length and content as . - public static CompressedSparseTensor ToCompressedSparseTensor(this T[] array) - { - return new CompressedSparseTensor(array); - } - - /// - /// Creates a copy of this two-dimensional array as a CompressedSparseTensor<T> - /// - /// Type contained in the array to copy to the CompressedSparseTensor<T>. - /// The array to create a CompressedSparseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): row-major. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): column-major. - /// A 2-dimensional CompressedSparseTensor<T> with the same dimensions and content as . - public static CompressedSparseTensor ToCompressedSparseTensor(this T[,] array, bool reverseStride = false) - { - return new CompressedSparseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this three-dimensional array as a CompressedSparseTensor<T> - /// - /// Type contained in the array to copy to the CompressedSparseTensor<T>. - /// The array to create a CompressedSparseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - /// A 3-dimensional CompressedSparseTensor<T> with the same dimensions and content as . - public static CompressedSparseTensor ToCompressedSparseTensor(this T[,,] array, bool reverseStride = false) - { - return new CompressedSparseTensor(array, reverseStride); - } - - /// - /// Creates a copy of this n-dimensional array as a CompressedSparseTensor<T> - /// - /// Type contained in the array to copy to the CompressedSparseTensor<T>. - /// The array to create a CompressedSparseTensor<T> from. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - /// A n-dimensional CompressedSparseTensor<T> with the same dimensions and content as . - public static CompressedSparseTensor ToCompressedSparseTensor(this Array array, bool reverseStride = false) - { - return new CompressedSparseTensor(array, reverseStride); - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayUtilities.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayUtilities.cs deleted file mode 100644 index 97152090dfea06..00000000000000 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayUtilities.cs +++ /dev/null @@ -1,216 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; - -namespace System.Numerics.Tensors -{ - internal static class ArrayUtilities - { - public const int StackallocMax = 16; - - public static long GetProduct(ReadOnlySpan dimensions, int startIndex = 0) - { - if (dimensions.Length == 0) - { - return 0; - } - - long product = 1; - for (int i = startIndex; i < dimensions.Length; i++) - { - if (dimensions[i] < 0) - { - throw new ArgumentOutOfRangeException($"{nameof(dimensions)}[{i}]"); - } - - // we use a long which should be much larger than is ever used here, - // but still force checked - checked - { - product *= dimensions[i]; - } - } - - return product; - } - - public static bool IsAscending(ReadOnlySpan values) - { - for (int i = 1; i < values.Length; i++) - { - if (values[i] < values[i - 1]) - { - return false; - } - } - - return true; - } - - public static bool IsDescending(ReadOnlySpan values) - { - for (int i = 1; i < values.Length; i++) - { - if (values[i] > values[i - 1]) - { - return false; - } - } - - return true; - } - - /// - /// Gets the set of strides that can be used to calculate the offset of n-dimensions in a 1-dimensional layout - /// - /// - /// - /// - public static int[] GetStrides(ReadOnlySpan dimensions, bool reverseStride = false) - { - int[] strides = new int[dimensions.Length]; - - int stride = 1; - if (reverseStride) - { - for (int i = 0; i < strides.Length; i++) - { - strides[i] = stride; - stride *= dimensions[i]; - } - } - else - { - for (int i = strides.Length - 1; i >= 0; i--) - { - strides[i] = stride; - stride *= dimensions[i]; - } - } - - return strides; - } - - public static void SplitStrides(int[] strides, int[] splitAxes, int[] newStrides, int stridesOffset, int[] splitStrides, int splitStridesOffset) - { - int newStrideIndex = 0; - for (int i = 0; i < strides.Length; i++) - { - int stride = strides[i]; - bool isSplit = false; - for (int j = 0; j < splitAxes.Length; j++) - { - if (splitAxes[j] == i) - { - splitStrides[splitStridesOffset + j] = stride; - isSplit = true; - break; - } - } - - if (!isSplit) - { - newStrides[stridesOffset + newStrideIndex++] = stride; - } - } - } - - /// - /// Calculates the 1-d index for n-d indices in layout specified by strides. - /// - /// - /// - /// - /// - public static int GetIndex(int[] strides, ReadOnlySpan indices, int startFromDimension = 0) - { - Debug.Assert(strides.Length == indices.Length); - - int index = 0; - for (int i = startFromDimension; i < indices.Length; i++) - { - index += strides[i] * indices[i]; - } - - return index; - } - - /// - /// Calculates the n-d indices from the 1-d index in a layout specified by strides - /// - /// - /// - /// - /// - /// - public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int index, int[] indices, int startFromDimension = 0) - { - Debug.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides"); - Debug.Assert(strides.Length == indices.Length); - - int remainder = index; - for (int i = startFromDimension; i < strides.Length; i++) - { - // reverse the index for reverseStride so that we divide by largest stride first - var nIndex = reverseStride ? strides.Length - 1 - i : i; - - var stride = strides[nIndex]; - indices[nIndex] = remainder / stride; - remainder %= stride; - } - } - - /// - /// Calculates the n-d indices from the 1-d index in a layout specified by strides - /// - /// - /// - /// - /// - /// - public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int index, Span indices, int startFromDimension = 0) - { - Debug.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides"); - Debug.Assert(strides.Length == indices.Length); - - int remainder = index; - for (int i = startFromDimension; i < strides.Length; i++) - { - // reverse the index for reverseStride so that we divide by largest stride first - var nIndex = reverseStride ? strides.Length - 1 - i : i; - - var stride = strides[nIndex]; - indices[nIndex] = remainder / stride; - remainder %= stride; - } - } - - /// - /// Takes an 1-d index over n-d sourceStrides and recalculates it assuming same n-d coordinates over a different n-d strides - /// - public static int TransformIndexByStrides(int index, int[] sourceStrides, bool sourceReverseStride, int[] transformStrides) - { - Debug.Assert(index >= 0); - Debug.Assert(sourceReverseStride ? IsAscending(sourceStrides) : IsDescending(sourceStrides), "Index decomposition requires ordered strides"); - Debug.Assert(sourceStrides.Length == transformStrides.Length); - - int transformIndex = 0; - int remainder = index; - - for (int i = 0; i < sourceStrides.Length; i++) - { - // reverse the index for reverseStride so that we divide by largest stride first - var nIndex = sourceReverseStride ? sourceStrides.Length - 1 - i : i; - - var sourceStride = sourceStrides[nIndex]; - var transformStride = transformStrides[nIndex]; - - transformIndex += transformStride * (remainder / sourceStride); - remainder %= sourceStride; - } - - return transformIndex; - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/CompressedSparseTensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/CompressedSparseTensor.cs deleted file mode 100644 index b41915acddd5ab..00000000000000 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/CompressedSparseTensor.cs +++ /dev/null @@ -1,517 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using System.Linq; - -namespace System.Numerics.Tensors -{ - /// - /// Represents a tensor using compressed sparse format - /// For a two dimensional tensor this is referred to as compressed sparse row (CSR, CRS, Yale), compressed sparse column (CSC, CCS) - /// - /// In this format, data that is in the same value for the compressed dimension has locality - /// - /// In standard layout of a dense tensor, data with the same value for first dimensions has locality. - /// As such we'll use reverseStride = false (default) to mean that the first dimension is compressed (CSR) - /// and reverseStride = true to mean that the last dimension is compressed (CSC) - /// - /// - /// - public class CompressedSparseTensor : Tensor - { - private Memory values; - private readonly Memory compressedCounts; - private Memory indices; - - private int nonZeroCount; - - private readonly int[] nonCompressedStrides; - private readonly int compressedDimension; - - private const int defaultCapacity = 64; - - /// - /// Constructs a new CompressedSparseTensor of the specified dimensions and stride ordering. - /// - /// An span of integers that represent the size of each dimension of the CompressedSparseTensor to create. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - public CompressedSparseTensor(ReadOnlySpan dimensions, bool reverseStride = false) : this(dimensions, defaultCapacity, reverseStride) - { } - - /// - /// Constructs a new CompressedSparseTensor of the specified dimensions, initial capacity, and stride ordering. - /// - /// An span of integers that represent the size of each dimension of the CompressedSparseTensor to create. - /// The number of non-zero values this tensor can store without resizing. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - public CompressedSparseTensor(ReadOnlySpan dimensions, int capacity, bool reverseStride = false) : base(dimensions, reverseStride) - { - nonZeroCount = 0; - compressedDimension = reverseStride ? Rank - 1 : 0; - nonCompressedStrides = (int[])strides.Clone(); - nonCompressedStrides[compressedDimension] = 0; - var compressedDimensionLength = dimensions[compressedDimension]; - compressedCounts = new int[compressedDimensionLength + 1]; - values = new T[capacity]; - indices = new int[capacity]; - } - - /// - /// Constructs a new CompressedSparseTensor of the specified dimensions, wrapping existing backing memory for the contents. - /// Growing this CompressedSparseTensor will re-allocate the backing memory. - /// - /// Memory storing non-zero values to construct this tensor with. - /// Memory storing the counts of non-zero elements at each index of the compressed dimension. - /// Memory storing the linearized index (excluding the compressed dimension) of non-zero elements. - /// The number of valid entries (eg: non-zero values) in and . - /// An span of integers that represent the size of each dimension of the CompressedSparseTensor to create. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - public CompressedSparseTensor(Memory values, Memory compressedCounts, Memory indices, int nonZeroCount, ReadOnlySpan dimensions, bool reverseStride = false) : base(dimensions, reverseStride) - { - compressedDimension = reverseStride ? Rank - 1 : 0; - nonCompressedStrides = (int[])strides.Clone(); - nonCompressedStrides[compressedDimension] = 0; - this.values = values; - this.compressedCounts = compressedCounts; - this.indices = indices; - this.nonZeroCount = nonZeroCount; - } - - internal CompressedSparseTensor(Array fromArray, bool reverseStride = false) : base(fromArray, reverseStride) - { - nonZeroCount = 0; - compressedDimension = reverseStride ? Rank - 1 : 0; - nonCompressedStrides = (int[])strides.Clone(); - nonCompressedStrides[compressedDimension] = 0; - var compressedDimensionLength = dimensions[compressedDimension]; - compressedCounts = new int[compressedDimensionLength + 1]; - - int index = 0; - if (reverseStride) - { - // Array is always row-major - var sourceStrides = ArrayUtilities.GetStrides(dimensions); - - foreach (T item in fromArray) - { - if (!item!.Equals(Zero)) - { - var destIndex = ArrayUtilities.TransformIndexByStrides(index, sourceStrides, false, strides); - var compressedIndex = destIndex / strides[compressedDimension]; - var nonCompressedIndex = destIndex % strides[compressedDimension]; - - SetAt(item, compressedIndex, nonCompressedIndex); - } - - index++; - } - } - else - { - foreach (T item in fromArray) - { - if (!item!.Equals(Zero)) - { - var compressedIndex = index / strides[compressedDimension]; - var nonCompressedIndex = index % strides[compressedDimension]; - - SetAt(item, compressedIndex, nonCompressedIndex); - } - - index++; - } - } - } - - /// - /// Obtains the value at the specified indices - /// - /// A span of integers that represent the indices specifying the position of the element to get. - /// The value at the specified position in this Tensor. - public override T this[ReadOnlySpan indices] - { - get - { - var compressedIndex = indices[compressedDimension]; - var nonCompressedIndex = ArrayUtilities.GetIndex(nonCompressedStrides, indices); - - - if (TryFindIndex(compressedIndex, nonCompressedIndex, out int valueIndex)) - { - return values.Span[valueIndex]; - } - - return Zero; - } - - set - { - var compressedIndex = indices[compressedDimension]; - var nonCompressedIndex = ArrayUtilities.GetIndex(nonCompressedStrides, indices); - - SetAt(value, compressedIndex, nonCompressedIndex); - } - } - - /// - /// Gets the value at the specified index, where index is lineraized as a dot product between indices and strides. - /// - /// An integer index computed as a dot-product of indices. - /// The value at the specified position in this Tensor. - public override T GetValue(int index) - { - var compressedDimensionStride = strides[compressedDimension]; - Debug.Assert(compressedDimensionStride == strides.Max()); - - var compressedIndex = index / compressedDimensionStride; - var nonCompressedIndex = index % compressedDimensionStride; - - - if (TryFindIndex(compressedIndex, nonCompressedIndex, out int valueIndex)) - { - return values.Span[valueIndex]; - } - - return Zero; - } - - /// - /// Sets the value at the specified index, where index is a linearized version of n-dimension indices using strides. - /// - /// An integer index computed as a dot-product of indices. - /// The new value to set at the specified position in this Tensor. - public override void SetValue(int index, T value) - { - var compressedDimensionStride = strides[compressedDimension]; - Debug.Assert(compressedDimensionStride == strides.Max()); - - var compressedIndex = index / compressedDimensionStride; - var nonCompressedIndex = index % compressedDimensionStride; - - SetAt(value, compressedIndex, nonCompressedIndex); - - } - - /// - /// Gets the number of non-zero values this tensor can store without resizing. - /// - public int Capacity => values.Length; - - /// - /// Get's the number on non-zero values currently being stored in this tensor. - /// - public int NonZeroCount => nonZeroCount; - - /// - /// Memory storing non-zero values. - /// - public Memory Values => values; - - /// - /// Memory storing the counts of non-zero elements at each index of the compressed dimension. - /// - public Memory CompressedCounts => compressedCounts; - - /// - /// Memory storing the linearized index (excluding the compressed dimension) of non-zero elements. - /// - public Memory Indices => indices; - - private void EnsureCapacity(int min, int allocateIndex = -1) - { - if (values.Length < min) - { - var newCapacity = values.Length == 0 ? defaultCapacity : values.Length * 2; - - if (newCapacity > Length) - { - newCapacity = (int)Length; - } - - if (newCapacity < min) - { - newCapacity = min; - } - - Memory newValues = new T[newCapacity]; - Memory newIndices = new int[newCapacity]; - - if (nonZeroCount > 0) - { - if (allocateIndex == -1) - { - var valuesSpan = values.Span.Slice(0, nonZeroCount); - var indicesSpan = indices.Span.Slice(0, nonZeroCount); - - valuesSpan.CopyTo(newValues.Span); - indicesSpan.CopyTo(newIndices.Span); - } - else - { - Debug.Assert(allocateIndex <= nonZeroCount); - // leave a gap at allocateIndex - - // copy range before allocateIndex - if (allocateIndex > 0) - { - var valuesSpan = values.Span.Slice(0, allocateIndex); - var indicesSpan = indices.Span.Slice(0, allocateIndex); - - valuesSpan.CopyTo(newValues.Span); - indicesSpan.CopyTo(newIndices.Span); - } - - if (allocateIndex < nonZeroCount) - { - var valuesSpan = values.Span.Slice(allocateIndex, nonZeroCount - allocateIndex); - var indicesSpan = indices.Span.Slice(allocateIndex, nonZeroCount - allocateIndex); - - var newValuesSpan = newValues.Span.Slice(allocateIndex + 1, nonZeroCount - allocateIndex); - var newIndicesSpan = newIndices.Span.Slice(allocateIndex + 1, nonZeroCount - allocateIndex); - - valuesSpan.CopyTo(newValuesSpan); - indicesSpan.CopyTo(newIndicesSpan); - } - } - } - - values = newValues; - indices = newIndices; - } - } - - private void InsertAt(int valueIndex, T value, int compressedIndex, int nonCompressedIndex) - { - Debug.Assert(valueIndex <= nonZeroCount); - Debug.Assert(compressedIndex < compressedCounts.Length - 1); - - if (values.Length <= valueIndex) - { - // allocate a new array, leaving a gap - EnsureCapacity(valueIndex + 1, valueIndex); - } - else if (nonZeroCount != valueIndex) - { - // shift values to make a gap - values.Span.Slice(valueIndex, nonZeroCount - valueIndex).CopyTo(values.Span.Slice(valueIndex + 1)); - indices.Span.Slice(valueIndex, nonZeroCount - valueIndex).CopyTo(indices.Span.Slice(valueIndex + 1)); - } - - values.Span[valueIndex] = value; - indices.Span[valueIndex] = nonCompressedIndex; - - var compressedCountsSpan = compressedCounts.Span.Slice(compressedIndex + 1); - for (int i = 0; i < compressedCountsSpan.Length; i++) - { - compressedCountsSpan[i]++; - } - nonZeroCount++; - } - - private void RemoveAt(int valueIndex, int compressedIndex) - { - Debug.Assert(valueIndex < nonZeroCount); - Debug.Assert(compressedIndex < compressedCounts.Length - 1); - - // shift values to close the gap - values.Span.Slice(valueIndex + 1, nonZeroCount - valueIndex - 1).CopyTo(values.Span.Slice(valueIndex)); - indices.Span.Slice(valueIndex + 1, nonZeroCount - valueIndex - 1).CopyTo(indices.Span.Slice(valueIndex)); - - var compressedCountsSpan = compressedCounts.Span.Slice(compressedIndex + 1); - for (int i = 0; i < compressedCountsSpan.Length; i++) - { - compressedCountsSpan[i]--; - } - nonZeroCount--; - } - - private void SetAt(T value, int compressedIndex, int nonCompressedIndex) - { - bool isZero = value!.Equals(Zero); - - if (TryFindIndex(compressedIndex, nonCompressedIndex, out int valueIndex)) - { - if (isZero) - { - RemoveAt(valueIndex, compressedIndex); - } - else - { - values.Span[valueIndex] = value; - indices.Span[valueIndex] = nonCompressedIndex; - } - } - else if (!isZero) - { - InsertAt(valueIndex, value, compressedIndex, nonCompressedIndex); - } - } - - /// - /// Trys to find the place to store a value - /// - /// - /// - /// - /// True if element is found at specific index, false if no specific index is found and insertion point is returned - private bool TryFindIndex(int compressedIndex, int nonCompressedIndex, out int valueIndex) - { - if (nonZeroCount == 0) - { - valueIndex = 0; - return false; - } - - Debug.Assert(compressedIndex < compressedCounts.Length - 1); - - var compressedCountsSpan = compressedCounts.Span; - var lowerValueIndex = compressedCountsSpan[compressedIndex]; - var upperValueIndex = compressedCountsSpan[compressedIndex + 1]; - var indicesSpan = indices.Span; - - // could be a faster search - for (valueIndex = lowerValueIndex; valueIndex < upperValueIndex; valueIndex++) - { - if (indicesSpan[valueIndex] == nonCompressedIndex) - { - return true; - } - } - - return false; - } - - /// - /// Creates a shallow copy of this tensor, with new backing storage. - /// - /// A shallow copy of this tensor. - public override Tensor Clone() - { - return new CompressedSparseTensor(values.ToArray(), compressedCounts.ToArray(), indices.ToArray(), nonZeroCount, dimensions, IsReversedStride); - } - - /// - /// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor with elements initialized to their default value. - /// - /// Type contained in the returned Tensor. - /// An span of integers that represent the size of each dimension of the CompressedSparseTensor to create. - /// A new tensor with the same layout as this tensor but different type and dimensions. - public override Tensor CloneEmpty(ReadOnlySpan dimensions) - { - return new CompressedSparseTensor(dimensions, IsReversedStride); - } - - /// - /// Reshapes the current tensor to new dimensions. Unlike other Tensor implementations, CompressedSparseTensor<T> must allocate new backing storage to represent a reshaped Tensor. - /// - /// An span of integers that represent the size of each dimension of the CompressedSparseTensor to create. - /// A new tensor that reinterprets the content of this tensor to new dimensions (assuming the same linear index for each element). - public override Tensor Reshape(ReadOnlySpan dimensions) - { - // reshape currently has shallow semantics which are not compatible with the backing storage for CompressedSparseTensor - // which bakes in information about dimensions (compressedCounts and indices) - - var newCompressedDimension = IsReversedStride ? dimensions.Length - 1 : 0; - var newCompressedDimensionLength = dimensions[newCompressedDimension]; - var newCompressedDimensionStride = (int)(Length / newCompressedDimensionLength); - - var newValues = (T[])values.ToArray(); - var newCompressedCounts = new int[newCompressedDimensionLength + 1]; - var newIndices = new int[indices.Length]; - - var compressedIndex = 0; - - var compressedCountsSpan = compressedCounts.Span; - var indicesSpan = indices.Span.Slice(0, nonZeroCount); - for (int valueIndex = 0; valueIndex < indicesSpan.Length; valueIndex++) - { - while (valueIndex >= compressedCountsSpan[compressedIndex + 1]) - { - compressedIndex++; - Debug.Assert(compressedIndex < compressedCounts.Length); - } - - var currentIndex = indicesSpan[valueIndex] + compressedIndex * strides[compressedDimension]; - - newIndices[valueIndex] = currentIndex % newCompressedDimensionStride; - - var newCompressedIndex = currentIndex / newCompressedDimensionStride; - newCompressedCounts[newCompressedIndex + 1] = valueIndex + 1; - } - - return new CompressedSparseTensor(newValues, newCompressedCounts, newIndices, nonZeroCount, dimensions, IsReversedStride); - } - - /// - /// Creates a copy of this tensor as a DenseTensor<T>. - /// - /// A copy of this tensor as a DenseTensor<T> - public override DenseTensor ToDenseTensor() - { - var denseTensor = new DenseTensor(Dimensions, reverseStride: IsReversedStride); - - var compressedIndex = 0; - - var compressedCountsSpan = compressedCounts.Span; - var indicesSpan = indices.Span.Slice(0, nonZeroCount); - var valuesSpan = values.Span.Slice(0, nonZeroCount); - for (int valueIndex = 0; valueIndex < valuesSpan.Length; valueIndex++) - { - while (valueIndex >= compressedCountsSpan[compressedIndex + 1]) - { - compressedIndex++; - Debug.Assert(compressedIndex < compressedCounts.Length); - } - - var index = indicesSpan[valueIndex] + compressedIndex * strides[compressedDimension]; - - denseTensor.SetValue(index, valuesSpan[valueIndex]); - } - - return denseTensor; - } - - /// - /// Creates a copy of this tensor as a new CompressedSparseTensor<T> eliminating any unused space in the backing storage. - /// - /// A copy of this tensor as a CompressedSparseTensor<T>. - public override CompressedSparseTensor ToCompressedSparseTensor() - { - // Create a copy of the backing storage, eliminating any unused space. - var newValues = values.Slice(0, nonZeroCount).ToArray(); - var newIndices = indices.Slice(0, nonZeroCount).ToArray(); - - return new CompressedSparseTensor(newValues, compressedCounts.ToArray(), newIndices, nonZeroCount, dimensions, IsReversedStride); - } - - /// - /// Creates a copy of this tensor as a SparseTensor<T>. - /// - /// A copy of this tensor as a SparseTensor<T>. - public override SparseTensor ToSparseTensor() - { - var sparseTensor = new SparseTensor(dimensions, capacity: NonZeroCount, reverseStride: IsReversedStride); - - var compressedIndex = 0; - - var compressedCountsSpan = compressedCounts.Span; - var indicesSpan = indices.Span.Slice(0, nonZeroCount); - var valuesSpan = values.Span.Slice(0, nonZeroCount); - for (int valueIndex = 0; valueIndex < valuesSpan.Length; valueIndex++) - { - while (valueIndex >= compressedCountsSpan[compressedIndex + 1]) - { - compressedIndex++; - Debug.Assert(compressedIndex < compressedCounts.Length); - } - - var index = indicesSpan[valueIndex] + compressedIndex * strides[compressedDimension]; - - sparseTensor.SetValue(index, valuesSpan[valueIndex]); - } - - return sparseTensor; - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/DenseTensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/DenseTensor.cs deleted file mode 100644 index 5f8715be3d18c8..00000000000000 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/DenseTensor.cs +++ /dev/null @@ -1,178 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Runtime.InteropServices; - -namespace System.Numerics.Tensors -{ - /// - /// Represents a multi-dimensional collection of objects of type T that can be accessed by indices. DenseTensor stores values in a contiguous sequential block of memory where all values are represented. - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - public class DenseTensor : Tensor - { - private readonly Memory memory; - - internal DenseTensor(Array fromArray, bool reverseStride = false) : base(fromArray, reverseStride) - { - // copy initial array - var backingArray = new T[fromArray.Length]; - - int index = 0; - if (reverseStride) - { - // Array is always row-major - var sourceStrides = ArrayUtilities.GetStrides(dimensions); - - foreach (var item in fromArray) - { - var destIndex = ArrayUtilities.TransformIndexByStrides(index++, sourceStrides, false, strides); - backingArray[destIndex] = (T)item!; - } - } - else - { - foreach (var item in fromArray) - { - backingArray[index++] = (T)item!; - } - } - memory = backingArray; - } - - /// - /// Initializes a rank-1 Tensor using the specified . - /// - /// Size of the 1-dimensional tensor - public DenseTensor(int length) : base(length) - { - memory = new T[length]; - } - - /// - /// Initializes a rank-n Tensor using the dimensions specified in . - /// - /// An span of integers that represent the size of each dimension of the DenseTensor to create. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - public DenseTensor(ReadOnlySpan dimensions, bool reverseStride = false) : base(dimensions, reverseStride) - { - memory = new T[Length]; - } - - /// - /// Constructs a new DenseTensor of the specified dimensions, wrapping existing backing memory for the contents. - /// - /// - /// An span of integers that represent the size of each dimension of the DenseTensor to create. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - public DenseTensor(Memory memory, ReadOnlySpan dimensions, bool reverseStride = false) : base(dimensions, reverseStride) - { - this.memory = memory; - - if (Length != memory.Length) - { - throw new ArgumentException(SR.Format(SR.LengthMustMatch, nameof(memory), memory.Length, nameof(dimensions), Length)); - } - } - - /// - /// Memory storing backing values of this tensor. - /// - public Memory Buffer => memory; - - /// - /// Gets the value at the specified index, where index is a linearized version of n-dimension indices using strides. - /// - /// An integer index computed as a dot-product of indices. - /// The value at the specified position in this Tensor. - public override T GetValue(int index) - { - return Buffer.Span[index]; - } - - /// - /// Sets the value at the specified index, where index is a linearized version of n-dimension indices using strides. - /// - /// An integer index computed as a dot-product of indices. - /// The new value to set at the specified position in this Tensor. - public override void SetValue(int index, T value) - { - Buffer.Span[index] = value; - } - - protected override void CopyTo(T[] array, int arrayIndex) - { - if (array is null) - { - throw new ArgumentNullException(nameof(array)); - } - - if (array.Length < arrayIndex + Length) - { - throw new ArgumentException(SR.NumberGreaterThenAvailableSpace, nameof(array)); - } - - Buffer.Span.CopyTo(array.AsSpan(arrayIndex)); - } - - protected override int IndexOf(T item) - { - // TODO: use Span.IndexOf when/if it removes the IEquatable type constraint - if (MemoryMarshal.TryGetArray(Buffer, out var arraySegment)) - { - var result = Array.IndexOf(arraySegment.Array!, item, arraySegment.Offset, arraySegment.Count); - if (result != -1) - { - result -= arraySegment.Offset; - } - return result; - } - else - { - return base.IndexOf(item); - } - } - - /// - /// Creates a shallow copy of this tensor, with new backing storage. - /// - /// A shallow copy of this tensor. - public override Tensor Clone() - { - return new DenseTensor(Buffer.ToArray(), dimensions, IsReversedStride); - } - - /// - /// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor with elements initialized to their default value. - /// - /// Type contained in the returned Tensor. - /// An span of integers that represent the size of each dimension of the DenseTensor to create. - /// A new tensor with the same layout as this tensor but different type and dimensions. - public override Tensor CloneEmpty(ReadOnlySpan dimensions) - { - return new DenseTensor(dimensions, IsReversedStride); - } - - /// - /// Reshapes the current tensor to new dimensions, using the same backing storage. - /// - /// An span of integers that represent the size of each dimension of the DenseTensor to create. - /// A new tensor that reinterprets backing Buffer of this tensor with different dimensions. - public override Tensor Reshape(ReadOnlySpan dimensions) - { - if (dimensions.Length == 0) - { - throw new ArgumentException(SR.DimensionsMustContainElements, nameof(dimensions)); - } - - var newSize = ArrayUtilities.GetProduct(dimensions); - - if (newSize != Length) - { - throw new ArgumentException(SR.Format(SR.CannotReshapeArrayDueToMismatchInLengths, Length, newSize), nameof(dimensions)); - } - - return new DenseTensor(Buffer, dimensions, IsReversedStride); - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/SparseTensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/SparseTensor.cs deleted file mode 100644 index 83948a0b918b7d..00000000000000 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/SparseTensor.cs +++ /dev/null @@ -1,177 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; - -namespace System.Numerics.Tensors -{ - /// - /// Represents a multi-dimensional collection of objects of type T that can be accessed by indices. Unlike other Tensor<T> implementations SparseTensor<T> does not expose its backing storage. It is meant as an intermediate to be used to build other Tensors, such as CompressedSparseTensor. Unlike CompressedSparseTensor where insertions are O(n), insertions to SparseTensor<T> are nominally O(1). - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - public class SparseTensor : Tensor - { - private readonly Dictionary values; - /// - /// Constructs a new SparseTensor of the specified dimensions, initial capacity, and stride ordering. - /// - /// An span of integers that represent the size of each dimension of the SparseTensor to create. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - /// The number of non-zero values this tensor can store without resizing. - public SparseTensor(ReadOnlySpan dimensions, bool reverseStride = false, int capacity = 0) : base(dimensions, reverseStride) - { - values = new Dictionary(capacity); - } - - internal SparseTensor(Dictionary values, ReadOnlySpan dimensions, bool reverseStride = false) : base(dimensions, reverseStride) - { - this.values = values; - } - - internal SparseTensor(Array fromArray, bool reverseStride = false) : base(fromArray, reverseStride) - { - values = new Dictionary(fromArray.Length); - - int index = 0; - if (reverseStride) - { - // Array is always row-major - var sourceStrides = ArrayUtilities.GetStrides(dimensions); - - foreach (T item in fromArray) - { - if (!item!.Equals(Zero)) - { - var destIndex = ArrayUtilities.TransformIndexByStrides(index, sourceStrides, false, strides); - values[destIndex] = item; - } - - index++; - } - } - else - { - foreach (T item in fromArray) - { - if (!item!.Equals(Zero)) - { - values[index] = item; - } - - index++; - } - } - } - - /// - /// Gets the value at the specified index, where index is a linearized version of n-dimension indices using strides. - /// - /// An integer index computed as a dot-product of indices. - /// The value at the specified position in this Tensor. - public override T GetValue(int index) - { - - if (!values.TryGetValue(index, out T? value)) - { - value = Zero; - } - - return value; - } - - /// - /// Sets the value at the specified index, where index is a linearized version of n-dimension indices using strides. - /// - /// An integer index computed as a dot-product of indices. - /// The new value to set at the specified position in this Tensor. - public override void SetValue(int index, T value) - { - if (value!.Equals(Zero)) - { - values.Remove(index); - } - else - { - values[index] = value; - } - } - - /// - /// Get's the number on non-zero values currently being stored in this tensor. - /// - public int NonZeroCount => values.Count; - - /// - /// Creates a shallow copy of this tensor, with new backing storage. - /// - /// A shallow copy of this tensor. - public override Tensor Clone() - { - var valueCopy = new Dictionary(values); - return new SparseTensor(valueCopy, dimensions, IsReversedStride); - } - - /// - /// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor with elements initialized to their default value. - /// - /// Type contained in the returned Tensor. - /// An span of integers that represent the size of each dimension of the SparseTensor to create. - /// A new tensor with the same layout as this tensor but different type and dimensions. - public override Tensor CloneEmpty(ReadOnlySpan dimensions) - { - return new SparseTensor(dimensions, IsReversedStride); - } - - /// - /// Reshapes the current tensor to new dimensions, using the same backing storage. - /// - /// An span of integers that represent the size of each dimension of the SparseTensor to create. - /// A new tensor that reinterprets backing storage of this tensor with different dimensions. - public override Tensor Reshape(ReadOnlySpan dimensions) - { - return new SparseTensor(values, dimensions, IsReversedStride); - } - - /// - /// Creates a copy of this tensor as a DenseTensor<T>. - /// - /// A copy of this tensor as a DenseTensor<T> - public override DenseTensor ToDenseTensor() - { - var denseTensor = new DenseTensor(Dimensions, reverseStride: IsReversedStride); - - // only set non-zero values - foreach (var pair in values) - { - denseTensor.SetValue(pair.Key, pair.Value); - } - - return denseTensor; - } - - /// - /// Creates a copy of this tensor as a new SparseTensor<T> eliminating any unused space in the backing storage. - /// - /// A copy of this tensor as a SparseTensor<T> eliminated any usused space in the backing storage. - public override SparseTensor ToSparseTensor() - { - var valueCopy = new Dictionary(values); - return new SparseTensor(valueCopy, dimensions, IsReversedStride); - } - - /// - /// Creates a copy of this tensor as a CompressedSparseTensor<T>. - /// - /// A copy of this tensor as a CompressedSparseTensor<T>. - public override CompressedSparseTensor ToCompressedSparseTensor() - { - var compressedSparseTensor = new CompressedSparseTensor(dimensions, capacity: NonZeroCount, reverseStride: IsReversedStride); - - foreach (var pair in values) - { - compressedSparseTensor.SetValue(pair.Key, pair.Value); - } - return compressedSparseTensor; - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs deleted file mode 100644 index f63547682dd24b..00000000000000 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs +++ /dev/null @@ -1,1365 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.Text; - -namespace System.Numerics.Tensors -{ - /// - /// Various methods for creating and manipulating Tensor<T> - /// - public static partial class Tensor - { - /// - /// Creates an identity tensor of the specified size. An identity tensor is a two dimensional tensor with 1s in the diagonal. - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - /// Width and height of the identity tensor to create. - /// a by with 1s along the diagonal and zeros elsewhere. - public static Tensor CreateIdentity(int size) - { - return CreateIdentity(size, false, Tensor.One); - } - - /// - /// Creates an identity tensor of the specified size and layout (row vs column major). An identity tensor is a two dimensional tensor with 1s in the diagonal. - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - /// Width and height of the identity tensor to create. - /// >False to indicate that the first dimension is most minor (closest) and the last dimension is most major (farthest): row-major. True to indicate that the last dimension is most minor (closest together) and the first dimension is most major (farthest apart): column-major. - /// a by with 1s along the diagonal and zeros elsewhere. - public static Tensor CreateIdentity(int size, bool columMajor) - { - return CreateIdentity(size, columMajor, Tensor.One); - } - - /// - /// Creates an identity tensor of the specified size and layout (row vs column major) using the specified one value. An identity tensor is a two dimensional tensor with 1s in the diagonal. This may be used in case T is a type that doesn't have a known 1 value. - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - /// Width and height of the identity tensor to create. - /// >False to indicate that the first dimension is most minor (closest) and the last dimension is most major (farthest): row-major. True to indicate that the last dimension is most minor (closest together) and the first dimension is most major (farthest apart): column-major. - /// Value of that is used along the diagonal. - /// a by with 1s along the diagonal and zeros elsewhere. - public static Tensor CreateIdentity(int size, bool columMajor, T oneValue) - { - Span dimensions = stackalloc int[2]; - dimensions[0] = dimensions[1] = size; - - var result = new DenseTensor(dimensions, columMajor); - - for (int i = 0; i < size; i++) - { - result.SetValue(i * size + i, oneValue); - } - - return result; - } - - /// - /// Creates a n+1-rank tensor using the specified n-rank diagonal. Values not on the diagonal will be filled with zeros. - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - /// Tensor representing the diagonal to build the new tensor from. - /// A new tensor of the same layout and order as of one higher rank, with the values of along the diagonal and zeros elsewhere. - public static Tensor CreateFromDiagonal(Tensor diagonal) - { - return CreateFromDiagonal(diagonal, 0); - } - - /// - /// Creates a n+1-dimension tensor using the specified n-dimension diagonal at the specified offset from the center. Values not on the diagonal will be filled with zeros. - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - /// Tensor representing the diagonal to build the new tensor from. - /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. - /// A new tensor of the same layout and order as of one higher rank, with the values of along the specified diagonal and zeros elsewhere. - public static Tensor CreateFromDiagonal(Tensor diagonal, int offset) - { - if (diagonal.Rank < 1) - { - throw new ArgumentException(SR.Format(SR.MustHaveAtLeastOneDimension, nameof(diagonal)), nameof(diagonal)); - } - - int diagonalLength = diagonal.dimensions[0]; - - // TODO: allow specification of axis1 and axis2? - var rank = diagonal.dimensions.Length + 1; - Span dimensions = rank < ArrayUtilities.StackallocMax ? stackalloc int[rank] : new int[rank]; - - // assume square - var axisLength = diagonalLength + Math.Abs(offset); - dimensions[0] = dimensions[1] = axisLength; - - for (int i = 1; i < diagonal.dimensions.Length; i++) - { - dimensions[i + 1] = diagonal.dimensions[i]; - } - - var result = diagonal.CloneEmpty(dimensions); - - var sizePerDiagonal = diagonal.Length / diagonalLength; - - var diagProjectionStride = diagonal.IsReversedStride && diagonal.Rank > 1 ? diagonal.strides[1] : 1; - var resultProjectionStride = result.IsReversedStride && result.Rank > 2 ? result.strides[2] : 1; - - for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++) - { - var resultIndex0 = offset < 0 ? diagIndex - offset : diagIndex; - var resultIndex1 = offset > 0 ? diagIndex + offset : diagIndex; - - var resultBase = resultIndex0 * result.strides[0] + resultIndex1 * result.strides[1]; - var diagBase = diagIndex * diagonal.strides[0]; - - for (int diagProjectionOffset = 0; diagProjectionOffset < sizePerDiagonal; diagProjectionOffset++) - { - result.SetValue(resultBase + diagProjectionOffset * resultProjectionStride, - diagonal.GetValue(diagBase + diagProjectionOffset * diagProjectionStride)); - } - } - - return result; - } - } - - /// - /// Represents a multi-dimensional collection of objects of type T that can be accessed by indices. - /// - /// type contained within the Tensor. Typically a value type such as int, double, float, etc. - [DebuggerDisplay("{GetArrayString(false)}")] - // When we cross-compile for frameworks that expose ICloneable this must implement ICloneable as well. - public abstract class Tensor : IList, IList, IReadOnlyList, IStructuralComparable, IStructuralEquatable - { - internal static T Zero - { - get - { - if (typeof(T) == typeof(bool)) - { - return (T)(object)(false); - } - else if (typeof(T) == typeof(byte)) - { - return (T)(object)(byte)(0); - } - else if (typeof(T) == typeof(char)) - { - return (T)(object)(char)(0); - } - else if (typeof(T) == typeof(decimal)) - { - return (T)(object)(decimal)(0); - } - else if (typeof(T) == typeof(double)) - { - return (T)(object)(double)(0); - } - else if (typeof(T) == typeof(float)) - { - return (T)(object)(float)(0); - } - else if (typeof(T) == typeof(int)) - { - return (T)(object)(int)(0); - } - else if (typeof(T) == typeof(long)) - { - return (T)(object)(long)(0); - } - else if (typeof(T) == typeof(sbyte)) - { - return (T)(object)(sbyte)(0); - } - else if (typeof(T) == typeof(short)) - { - return (T)(object)(short)(0); - } - else if (typeof(T) == typeof(uint)) - { - return (T)(object)(uint)(0); - } - else if (typeof(T) == typeof(ulong)) - { - return (T)(object)(ulong)(0); - } - else if (typeof(T) == typeof(ushort)) - { - return (T)(object)(ushort)(0); - } - - throw new NotSupportedException(); - } - } - - internal static T One - { - get - { - if (typeof(T) == typeof(bool)) - { - return (T)(object)(true); - } - else if (typeof(T) == typeof(byte)) - { - return (T)(object)(byte)(1); - } - else if (typeof(T) == typeof(char)) - { - return (T)(object)(char)(1); - } - else if (typeof(T) == typeof(decimal)) - { - return (T)(object)(decimal)(1); - } - else if (typeof(T) == typeof(double)) - { - return (T)(object)(double)(1); - } - else if (typeof(T) == typeof(float)) - { - return (T)(object)(float)(1); - } - else if (typeof(T) == typeof(int)) - { - return (T)(object)(int)(1); - } - else if (typeof(T) == typeof(long)) - { - return (T)(object)(long)(1); - } - else if (typeof(T) == typeof(sbyte)) - { - return (T)(object)(sbyte)(1); - } - else if (typeof(T) == typeof(short)) - { - return (T)(object)(short)(1); - } - else if (typeof(T) == typeof(uint)) - { - return (T)(object)(uint)(1); - } - else if (typeof(T) == typeof(ulong)) - { - return (T)(object)(ulong)(1); - } - else if (typeof(T) == typeof(ushort)) - { - return (T)(object)(ushort)(1); - } - - throw new NotSupportedException(); - } - } - - internal readonly int[] dimensions; - internal readonly int[] strides; - private readonly bool isReversedStride; - - private readonly long length; - - /// - /// Initialize a 1-dimensional tensor of the specified length - /// - /// Size of the 1-dimensional tensor - protected Tensor(int length) - { - dimensions = new[] { length }; - strides = new[] { 1 }; - isReversedStride = false; - this.length = length; - } - - /// - /// Initialize an n-dimensional tensor with the specified dimensions and layout. ReverseStride=true gives a stride of 1-element witdth to the first dimension (0). ReverseStride=false gives a stride of 1-element width to the last dimension (n-1). - /// - /// An span of integers that represent the size of each dimension of the Tensor to create. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - protected Tensor(ReadOnlySpan dimensions, bool reverseStride) - { - if (dimensions.Length == 0) - { - throw new ArgumentException(SR.DimensionsMustContainElements, nameof(dimensions)); - } - - this.dimensions = new int[dimensions.Length]; - long size = 1; - for (int i = 0; i < dimensions.Length; i++) - { - if (dimensions[i] < 1) - { - throw new ArgumentOutOfRangeException(nameof(dimensions), SR.DimensionsMustBePositiveAndNonZero); - } - this.dimensions[i] = dimensions[i]; - size *= dimensions[i]; - } - - strides = ArrayUtilities.GetStrides(dimensions, reverseStride); - isReversedStride = reverseStride; - - length = size; - } - - /// - /// Initializes tensor with same dimensions as array, content of array is ignored. ReverseStride=true gives a stride of 1-element witdth to the first dimension (0). ReverseStride=false gives a stride of 1-element width to the last dimension (n-1). - /// - /// Array from which to derive dimensions. - /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor. - protected Tensor(Array fromArray, bool reverseStride) - { - if (fromArray is null) - { - throw new ArgumentNullException(nameof(fromArray)); - } - - if (fromArray.Rank == 0) - { - throw new ArgumentException(SR.ArrayMustContainElements, nameof(fromArray)); - } - - dimensions = new int[fromArray.Rank]; - long size = 1; - for (int i = 0; i < dimensions.Length; i++) - { - dimensions[i] = fromArray.GetLength(i); - size *= dimensions[i]; - } - - strides = ArrayUtilities.GetStrides(dimensions, reverseStride); - isReversedStride = reverseStride; - - length = size; - } - - /// - /// Total length of the Tensor. - /// - public long Length => length; - - /// - /// Rank of the tensor: number of dimensions. - /// - public int Rank => dimensions.Length; - - /// - /// True if strides are reversed (AKA Column-major) - /// - public bool IsReversedStride => isReversedStride; - - /// - /// Returns a readonly view of the dimensions of this tensor. - /// - public ReadOnlySpan Dimensions => dimensions; - - /// - /// Returns a readonly view of the strides of this tensor. - /// - public ReadOnlySpan Strides => strides; - - /// - /// Sets all elements in Tensor to . - /// - /// Value to fill - public virtual void Fill(T value) - { - for (int i = 0; i < Length; i++) - { - SetValue(i, value); - } - } - - /// - /// Creates a shallow copy of this tensor, with new backing storage. - /// - /// A shallow copy of this tensor. - public abstract Tensor Clone(); - - /// - /// Creates a new Tensor with the same layout and dimensions as this tensor with elements initialized to their default value. - /// - /// A new Tensor with the same layout and dimensions as this tensor with elements initialized to their default value. - public virtual Tensor CloneEmpty() - { - return CloneEmpty(dimensions); - } - - /// - /// Creates a new Tensor with the specified dimensions and the same layout as this tensor with elements initialized to their default value. - /// - /// An span of integers that represent the size of each dimension of the DenseTensor to create. - /// A new Tensor with the same layout as this tensor and specified with elements initialized to their default value. - public virtual Tensor CloneEmpty(ReadOnlySpan dimensions) - { - return CloneEmpty(dimensions); - } - - /// - /// Creates a new Tensor of a different type with the same layout and size as this tensor with elements initialized to their default value. - /// - /// Type contained within the new Tensor. Typically a value type such as int, double, float, etc. - /// A new Tensor with the same layout and dimensions as this tensor with elements of type initialized to their default value. - public virtual Tensor CloneEmpty() - { - return CloneEmpty(dimensions); - } - - /// - /// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor with elements initialized to their default value. - /// - /// Type contained within the new Tensor. Typically a value type such as int, double, float, etc. - /// An span of integers that represent the size of each dimension of the DenseTensor to create. - /// A new Tensor with the same layout as this tensor of specified with elements of type initialized to their default value. - public abstract Tensor CloneEmpty(ReadOnlySpan dimensions); - - /// - /// Gets the n-1 dimension diagonal from the n dimension tensor. - /// - /// An n-1 dimension tensor with the values from the main diagonal of this tensor. - public Tensor GetDiagonal() - { - return GetDiagonal(0); - } - - /// - /// Gets the n-1 dimension diagonal from the n dimension tensor at the specified offset from center. - /// - /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. - /// An n-1 dimension tensor with the values from the specified diagonal of this tensor. - public Tensor GetDiagonal(int offset) - { - // Get diagonal of first two dimensions for all remaining dimensions - - // diagnonal is as follows: - // { 1, 2, 4 } - // { 8, 3, 9 } - // { 0, 7, 5 } - // The diagonal at offset 0 is { 1, 3, 5 } - // The diagonal at offset 1 is { 2, 9 } - // The diagonal at offset -1 is { 8, 7 } - - if (Rank < 2) - { - throw new InvalidOperationException(SR.Format(SR.CannotComputeDiagonal, nameof(Tensor))); - } - - // TODO: allow specification of axis1 and axis2? - var axisLength0 = dimensions[0]; - var axisLength1 = dimensions[1]; - - // the diagonal will be the length of the smaller axis - // if offset it positive, the length will shift along the second axis - // if the offset is negative, the length will shift along the first axis - // In that way the length of the diagonal will be - // Min(offset < 0 ? axisLength0 + offset : axisLength0, offset > 0 ? axisLength1 - offset : axisLength1) - // To illustrate, consider the following - // { 1, 2, 4, 3, 7 } - // { 8, 3, 9, 2, 6 } - // { 0, 7, 5, 2, 9 } - // The diagonal at offset 0 is { 1, 3, 5 }, Min(3, 5) = 3 - // The diagonal at offset 1 is { 2, 9, 2 }, Min(3, 5 - 1) = 3 - // The diagonal at offset 3 is { 3, 6 }, Min(3, 5 - 3) = 2 - // The diagonal at offset -1 is { 8, 7 }, Min(3 - 1, 5) = 2 - var offsetAxisLength0 = offset < 0 ? axisLength0 + offset : axisLength0; - var offsetAxisLength1 = offset > 0 ? axisLength1 - offset : axisLength1; - - var diagonalLength = Math.Min(offsetAxisLength0, offsetAxisLength1); - - if (diagonalLength <= 0) - { - throw new ArgumentException(SR.Format(SR.CannotComputeDiagonalWithOffset, offset), nameof(offset)); - } - - var newTensorRank = Rank - 1; - var newTensorDimensions = newTensorRank < ArrayUtilities.StackallocMax ? stackalloc int[newTensorRank] : new int[newTensorRank]; - newTensorDimensions[0] = diagonalLength; - - for (int i = 2; i < dimensions.Length; i++) - { - newTensorDimensions[i - 1] = dimensions[i]; - } - - var diagonalTensor = CloneEmpty(newTensorDimensions); - var sizePerDiagonal = diagonalTensor.Length / diagonalTensor.Dimensions[0]; - - var diagProjectionStride = diagonalTensor.IsReversedStride && diagonalTensor.Rank > 1 ? diagonalTensor.strides[1] : 1; - var sourceProjectionStride = IsReversedStride && Rank > 2 ? strides[2] : 1; - - for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++) - { - var sourceIndex0 = offset < 0 ? diagIndex - offset : diagIndex; - var sourceIndex1 = offset > 0 ? diagIndex + offset : diagIndex; - - var sourceBase = sourceIndex0 * strides[0] + sourceIndex1 * strides[1]; - var diagBase = diagIndex * diagonalTensor.strides[0]; - - for (int diagProjectionIndex = 0; diagProjectionIndex < sizePerDiagonal; diagProjectionIndex++) - { - diagonalTensor.SetValue(diagBase + diagProjectionIndex * diagProjectionStride, - GetValue(sourceBase + diagProjectionIndex * sourceProjectionStride)); - } - } - - return diagonalTensor; - } - - /// - /// Gets a tensor representing the elements below and including the diagonal, with the rest of the elements zero-ed. - /// - /// A tensor with the values from this tensor at and below the main diagonal and zeros elsewhere. - public Tensor GetTriangle() - { - return GetTriangle(0, upper: false); - } - - /// - /// Gets a tensor representing the elements below and including the specified diagonal, with the rest of the elements zero-ed. - /// - /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. - /// A tensor with the values from this tensor at and below the specified diagonal and zeros elsewhere. - public Tensor GetTriangle(int offset) - { - return GetTriangle(offset, upper: false); - } - - /// - /// Gets a tensor representing the elements above and including the diagonal, with the rest of the elements zero-ed. - /// - /// A tensor with the values from this tensor at and above the main diagonal and zeros elsewhere. - public Tensor GetUpperTriangle() - { - return GetTriangle(0, upper: true); - } - - /// - /// Gets a tensor representing the elements above and including the specified diagonal, with the rest of the elements zero-ed. - /// - /// Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above. - /// A tensor with the values from this tensor at and above the specified diagonal and zeros elsewhere. - public Tensor GetUpperTriangle(int offset) - { - return GetTriangle(offset, upper: true); - } - - private Tensor GetTriangle(int offset, bool upper) - { - if (Rank < 2) - { - throw new InvalidOperationException(SR.Format(SR.CannotComputeTriangle, nameof(Tensor))); - } - - // Similar to get diagonal except it gets every element below and including the diagonal. - - // TODO: allow specification of axis1 and axis2? - var axisLength0 = dimensions[0]; - var axisLength1 = dimensions[1]; - var diagonalLength = Math.Max(axisLength0, axisLength1); - - var result = CloneEmpty(); - - var projectionSize = Length / (axisLength0 * axisLength1); - var projectionStride = IsReversedStride && Rank > 2 ? strides[2] : 1; - - for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++) - { - // starting point for the tri - var triIndex0 = offset > 0 ? diagIndex - offset : diagIndex; - var triIndex1 = offset > 0 ? diagIndex : diagIndex + offset; - - // for lower triangle, iterate index0 keeping same index1 - // for upper triangle, iterate index1 keeping same index0 - - if (triIndex0 < 0) - { - if (upper) - { - // out of bounds, ignore this diagIndex. - continue; - } - else - { - // set index to 0 so that we can iterate on the remaining index0 values. - triIndex0 = 0; - } - } - - if (triIndex1 < 0) - { - if (upper) - { - // set index to 0 so that we can iterate on the remaining index1 values. - triIndex1 = 0; - } - else - { - // out of bounds, ignore this diagIndex. - continue; - } - } - - while ((triIndex1 < axisLength1) && (triIndex0 < axisLength0)) - { - var baseIndex = triIndex0 * strides[0] + triIndex1 * result.strides[1]; - - for (int projectionIndex = 0; projectionIndex < projectionSize; projectionIndex++) - { - var index = baseIndex + projectionIndex * projectionStride; - - result.SetValue(index, GetValue(index)); - } - - if (upper) - { - triIndex1++; - } - else - { - triIndex0++; - } - } - } - - return result; - } - - /// - /// Reshapes the current tensor to new dimensions, using the same backing storage if possible. - /// - /// An span of integers that represent the size of each dimension of the Tensor to create. - /// A new tensor that reinterprets this tensor with different dimensions. - public abstract Tensor Reshape(ReadOnlySpan dimensions); - - /// - /// Obtains the value at the specified indices - /// - /// A one-dimensional array of integers that represent the indices specifying the position of the element to get. - /// The value at the specified position in this Tensor. - public virtual T this[params int[] indices] - { - get - { - if (indices is null) - { - throw new ArgumentNullException(nameof(indices)); - } - - var span = new ReadOnlySpan(indices); - return this[span]; - } - - set - { - if (indices is null) - { - throw new ArgumentNullException(nameof(indices)); - } - - var span = new ReadOnlySpan(indices); - this[span] = value; - } - } - - /// - /// Obtains the value at the specified indices - /// - /// A span integers that represent the indices specifying the position of the element to get. - /// The value at the specified position in this Tensor. - public virtual T this[ReadOnlySpan indices] - { - get - { - return GetValue(ArrayUtilities.GetIndex(strides, indices)); - } - - set - { - SetValue(ArrayUtilities.GetIndex(strides, indices), value); - } - } - - /// - /// Gets the value at the specified index, where index is a linearized version of n-dimension indices using strides. - /// - /// An integer index computed as a dot-product of indices. - /// The value at the specified position in this Tensor. - public abstract T GetValue(int index); - - /// - /// Sets the value at the specified index, where index is a linearized version of n-dimension indices using strides. - /// - /// An integer index computed as a dot-product of indices. - /// The new value to set at the specified position in this Tensor. - public abstract void SetValue(int index, T value); - - /// - /// The type that implements enumerators for instances. - /// - public struct Enumerator : IEnumerator - { - private readonly Tensor _tensor; - private int _index; - - internal Enumerator(Tensor tensor) - { - Debug.Assert(tensor != null); - - _tensor = tensor; - _index = 0; - Current = default!; - } - - public T Current { get; private set; } - - object? IEnumerator.Current => Current; - - public bool MoveNext() - { - if (_index < _tensor.Length) - { - Current = _tensor.GetValue(_index); - ++_index; - return true; - } - else - { - Current = default!; - return false; - } - } - - /// - /// Resets the enumerator to the beginning. - /// - public void Reset() - { - _index = 0; - Current = default!; - } - - /// - /// Disposes the enumerator. - /// - public void Dispose() { } - } - - /// - /// Gets an enumerator that enumerates the elements of the . - /// - /// An enumerator for the current . - public Enumerator GetEnumerator() => new Enumerator(this); - - #region statics - /// - /// Performs a value comparison of the content and shape of two tensors. Two tensors are equal if they have the same shape and same value at every set of indices. If not equal a tensor is greater or less than another tensor based on the first non-equal element when enumerating in linear order. - /// - /// - /// - /// - public static int Compare(Tensor left, Tensor right) - { - return StructuralComparisons.StructuralComparer.Compare(left, right); - } - - /// - /// Performs a value equality comparison of the content of two tensors. Two tensors are equal if they have the same shape and same value at every set of indices. - /// - /// - /// - /// - public static bool Equals(Tensor left, Tensor right) - { - return StructuralComparisons.StructuralEqualityComparer.Equals(left, right); - } - #endregion - - #region IEnumerable members - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - #endregion - - #region ICollection members - int ICollection.Count => (int)Length; - - bool ICollection.IsSynchronized => false; - - object ICollection.SyncRoot => this; // backingArray.this? - - void ICollection.CopyTo(Array array, int index) - { - if (array is T[] destinationArray) - { - CopyTo(destinationArray, index); - } - else - { - if (array == null) - { - throw new ArgumentNullException(nameof(array)); - } - if (array.Rank != 1) - { - throw new ArgumentException(SR.OnlySingleDimensionalArraysSupported, nameof(array)); - } - if (array.Length < index + Length) - { - throw new ArgumentException(SR.NumberGreaterThenAvailableSpace, nameof(array)); - } - - for (int i = 0; i < length; i++) - { - array.SetValue(GetValue(i), index + i); - } - } - } - #endregion - - #region IList members - object? IList.this[int index] - { - get - { - return GetValue(index); - } - set - { - try - { - SetValue(index, (T)value!); - } - catch (InvalidCastException) - { - throw new ArgumentException(SR.Format(SR.ValueIsNotOfType, value, typeof(T))); - } - } - } - - public bool IsFixedSize => true; - - public bool IsReadOnly => false; - - int IList.Add(object? value) - { - throw new InvalidOperationException(); - } - - void IList.Clear() - { - Fill(default!); - } - - bool IList.Contains(object? value) - { - if (IsCompatibleObject(value!)) - { - return Contains((T)value!); - } - return false; - } - - int IList.IndexOf(object? value) - { - if (IsCompatibleObject(value!)) - { - return IndexOf((T)value!); - } - return -1; - } - - void IList.Insert(int index, object? value) - { - throw new InvalidOperationException(); - } - - void IList.Remove(object? value) - { - throw new InvalidOperationException(); - } - - void IList.RemoveAt(int index) - { - throw new InvalidOperationException(); - } - #endregion - - #region IEnumerable members - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - #endregion - - #region ICollection members - int ICollection.Count => (int)Length; - - void ICollection.Add(T item) - { - throw new InvalidOperationException(); - } - - void ICollection.Clear() - { - Fill(default!); - } - - bool ICollection.Contains(T item) - { - return Contains(item); - } - - /// - /// Determines whether an element is in the Tensor<T>. - /// - /// - /// The object to locate in the Tensor<T>. The value can be null for reference types. - /// - /// - /// true if item is found in the Tensor<T>; otherwise, false. - /// - protected virtual bool Contains(T item) - { - return Length != 0 && IndexOf(item) != -1; - } - - void ICollection.CopyTo(T[] array, int arrayIndex) - { - CopyTo(array, arrayIndex); - } - - /// - /// Copies the elements of the Tensor<T> to an Array, starting at a particular Array index. - /// - /// - /// The one-dimensional Array that is the destination of the elements copied from Tensor<T>. The Array must have zero-based indexing. - /// - /// - /// The zero-based index in array at which copying begins. - /// - protected virtual void CopyTo(T[] array, int arrayIndex) - { - if (array is null) - { - throw new ArgumentNullException(nameof(array)); - } - - if (array.Length < arrayIndex + Length) - { - throw new ArgumentException(SR.NumberGreaterThenAvailableSpace, nameof(array)); - } - - for (int i = 0; i < length; i++) - { - array[arrayIndex + i] = GetValue(i); - } - } - - bool ICollection.Remove(T item) - { - throw new InvalidOperationException(); - } - #endregion - - #region IReadOnlyCollection members - - int IReadOnlyCollection.Count => (int)Length; - - #endregion - - #region IList members - T IList.this[int index] - { - get { return GetValue(index); } - set { SetValue(index, value); } - } - - int IList.IndexOf(T item) - { - return IndexOf(item); - } - - /// - /// Determines the index of a specific item in the Tensor<T>. - /// - /// The object to locate in the Tensor<T>. - /// The index of item if found in the tensor; otherwise, -1. - protected virtual int IndexOf(T item) - { - for (int i = 0; i < Length; i++) - { - if (GetValue(i)!.Equals(item)) - { - return i; - } - } - - return -1; - } - - void IList.Insert(int index, T item) - { - throw new InvalidOperationException(); - } - - void IList.RemoveAt(int index) - { - throw new InvalidOperationException(); - } - #endregion - - #region IReadOnlyList members - - T IReadOnlyList.this[int index] => GetValue(index); - - #endregion - - #region IStructuralComparable members - int IStructuralComparable.CompareTo(object? other, IComparer comparer) - { - if (other == null) - { - return 1; - } - - if (other is Tensor) - { - return CompareTo((Tensor)other, comparer); - } - - var otherArray = other as Array; - - if (otherArray != null) - { - return CompareTo(otherArray, comparer); - } - - throw new ArgumentException(SR.Format(SR.CannotCompare, nameof(Tensor), other.GetType()), nameof(other)); - } - - private int CompareTo(Tensor other, IComparer comparer) - { - if (Rank != other.Rank) - { - throw new ArgumentException(SR.Format(SR.CannotCompareWithRank, nameof(Tensor), Rank, nameof(other), other.Rank), nameof(other)); - } - - for (int i = 0; i < dimensions.Length; i++) - { - if (dimensions[i] != other.dimensions[i]) - { - throw new ArgumentException(SR.Format(SR.CannotCompareWithDifferentDimension, nameof(Tensor), i, dimensions[i], other.dimensions[i]), nameof(other)); - } - } - - int result = 0; - - if (IsReversedStride == other.IsReversedStride) - { - for (int i = 0; i < Length; i++) - { - result = comparer.Compare(GetValue(i), other.GetValue(i)); - if (result != 0) - { - break; - } - } - } - else - { - var indices = Rank < ArrayUtilities.StackallocMax ? stackalloc int[Rank] : new int[Rank]; - for (int i = 0; i < Length; i++) - { - ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); - result = comparer.Compare(this[indices], other[indices]); - if (result != 0) - { - break; - } - } - } - - return result; - } - - private int CompareTo(Array other, IComparer comparer) - { - if (Rank != other.Rank) - { - throw new ArgumentException(SR.Format(SR.CannotCompareWithRank, nameof(Tensor), Rank, nameof(Array), other.Rank), nameof(other)); - } - - for (int i = 0; i < dimensions.Length; i++) - { - var otherDimension = other.GetLength(i); - if (dimensions[i] != otherDimension) - { - throw new ArgumentException(SR.Format(SR.CannotCompareToWithDifferentDimension, nameof(Tensor), nameof(Array), i, dimensions[i], otherDimension), nameof(other)); - } - } - - int result = 0; - var indices = new int[Rank]; - for (int i = 0; i < Length; i++) - { - ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); - - result = comparer.Compare(GetValue(i), other.GetValue(indices)); - - if (result != 0) - { - break; - } - } - - return result; - } - #endregion - - #region IStructuralEquatable members - bool IStructuralEquatable.Equals(object? other, IEqualityComparer comparer) - { - if (other == null) - { - return false; - } - - if (other is Tensor) - { - return Equals((Tensor)other, comparer); - } - - var otherArray = other as Array; - - if (otherArray != null) - { - return Equals(otherArray, comparer); - } - - throw new ArgumentException(SR.Format(SR.CannotCompare, nameof(Tensor), other.GetType()), nameof(other)); - } - - private bool Equals(Tensor other, IEqualityComparer comparer) - { - if (Rank != other.Rank) - { - throw new ArgumentException(SR.Format(SR.CannotCompareWithRank, nameof(Tensor), Rank, nameof(other), other.Rank), nameof(other)); - } - - for (int i = 0; i < dimensions.Length; i++) - { - if (dimensions[i] != other.dimensions[i]) - { - throw new ArgumentException(SR.Format(SR.CannotCompareWithDifferentDimension, nameof(Tensor), i, dimensions[i], other.dimensions[i]), nameof(other)); - } - } - - if (IsReversedStride == other.IsReversedStride) - { - for (int i = 0; i < Length; i++) - { - if (!comparer.Equals(GetValue(i), other.GetValue(i))) - { - return false; - } - } - } - else - { - var indices = Rank < ArrayUtilities.StackallocMax ? stackalloc int[Rank] : new int[Rank]; - for (int i = 0; i < Length; i++) - { - ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); - - if (!comparer.Equals(this[indices], other[indices])) - { - return false; - } - } - } - - return true; - } - - private bool Equals(Array other, IEqualityComparer comparer) - { - if (Rank != other.Rank) - { - throw new ArgumentException(SR.Format(SR.CannotCompareWithRank, nameof(Tensor), Rank, nameof(Array), other.Rank), nameof(other)); - } - - for (int i = 0; i < dimensions.Length; i++) - { - var otherDimension = other.GetLength(i); - if (dimensions[i] != otherDimension) - { - throw new ArgumentException(SR.Format(SR.CannotCompareToWithDifferentDimension, nameof(Tensor), nameof(Array), i, dimensions[i], otherDimension), nameof(other)); - } - } - - var indices = new int[Rank]; - for (int i = 0; i < Length; i++) - { - ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices); - - if (!comparer.Equals(GetValue(i), other.GetValue(indices))) - { - return false; - } - } - - return true; - } - int IStructuralEquatable.GetHashCode(IEqualityComparer comparer) - { - int hashCode = 0; - // this ignores shape, which is fine it just means we'll have hash collisions for things - // with the same content and different shape. - for (int i = 0; i < Length; i++) - { - hashCode ^= comparer.GetHashCode(GetValue(i)!); - } - - return hashCode; - } - #endregion - - #region Translations - - /// - /// Creates a copy of this tensor as a DenseTensor<T>. If this tensor is already a DenseTensor<T> calling this method is equivalent to calling Clone(). - /// - /// - public virtual DenseTensor ToDenseTensor() - { - var denseTensor = new DenseTensor(Dimensions, IsReversedStride); - for (int i = 0; i < Length; i++) - { - denseTensor.SetValue(i, GetValue(i)); - } - return denseTensor; - } - - - /// - /// Creates a copy of this tensor as a SparseTensor<T>. If this tensor is already a SparseTensor<T> calling this method is equivalent to calling Clone(). - /// - /// - public virtual SparseTensor ToSparseTensor() - { - var sparseTensor = new SparseTensor(Dimensions, IsReversedStride); - for (int i = 0; i < Length; i++) - { - sparseTensor.SetValue(i, GetValue(i)); - } - return sparseTensor; - } - - /// - /// Creates a copy of this tensor as a CompressedSparseTensor<T>. If this tensor is already a CompressedSparseTensor<T> calling this method is equivalent to calling Clone(). - /// - /// - public virtual CompressedSparseTensor ToCompressedSparseTensor() - { - var compressedSparseTensor = new CompressedSparseTensor(Dimensions, IsReversedStride); - for (int i = 0; i < Length; i++) - { - compressedSparseTensor.SetValue(i, GetValue(i)); - } - return compressedSparseTensor; - } - - #endregion - - public string GetArrayString(bool includeWhitespace = true) - { - var builder = new StringBuilder(); - - var strides = ArrayUtilities.GetStrides(dimensions); - var indices = new int[Rank]; - var innerDimension = Rank - 1; - var innerLength = dimensions[innerDimension]; - - int indent = 0; - for (int outerIndex = 0; outerIndex < Length; outerIndex += innerLength) - { - ArrayUtilities.GetIndices(strides, false, outerIndex, indices); - - while ((indent < innerDimension) && (indices[indent] == 0)) - { - // start up - if (includeWhitespace) - { - Indent(builder, indent); - } - indent++; - builder.Append('{'); - if (includeWhitespace) - { - builder.AppendLine(); - } - } - - for (int innerIndex = 0; innerIndex < innerLength; innerIndex++) - { - indices[innerDimension] = innerIndex; - - if ((innerIndex == 0)) - { - if (includeWhitespace) - { - Indent(builder, indent); - } - builder.Append('{'); - } - else - { - builder.Append(','); - } - builder.Append(this[indices]); - } - builder.Append('}'); - - for (int i = Rank - 2; i >= 0; i--) - { - var lastIndex = dimensions[i] - 1; - if (indices[i] == lastIndex) - { - // close out - --indent; - if (includeWhitespace) - { - builder.AppendLine(); - Indent(builder, indent); - } - builder.Append('}'); - } - else - { - builder.Append(','); - if (includeWhitespace) - { - builder.AppendLine(); - } - break; - } - } - } - - return builder.ToString(); - } - - private static void Indent(StringBuilder builder, int tabs, int spacesPerTab = 4) - { - for (int tab = 0; tab < tabs; tab++) - { - for (int space = 0; space < spacesPerTab; space++) - { - builder.Append(' '); - } - } - } - - private static bool IsCompatibleObject(object value) - { - // Non-null values are fine. Only accept nulls if T is a class or Nullable. - // Note that default(T) is not equal to null for value types except when T is Nullable. - return ((value is T) || (value == null && default(T) == null)); - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs new file mode 100644 index 00000000000000..03db1abb7f858a --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -0,0 +1,1097 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace System.Numerics.Tensors +{ + /// Performs primitive tensor operations over spans of memory. + public static partial class TensorPrimitives + { + /// Computes the element-wise absolute value of each single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Abs([i]). + /// + /// + /// The absolute value of a is its numeric value without its sign. For example, the absolute value of both 1.2e-03 and -1.2e03 is 1.2e03. + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is the original NaN value with the sign bit removed. + /// + /// + public static void Abs(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Add(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of ( + ) * for the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of and the length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + [i]) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of ( + ) * for the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + [i]) * . + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise result of ( + ) * for the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + ) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); + + /// Computes the element-wise hyperbolic cosine of each single-precision floating-point radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Cosh([i]). + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is also NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Cosh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The cosine similarity of the two tensors. + /// Length of must be same as length of . + /// and must not be empty. + /// + /// + /// This method effectively computes TensorPrimitives.Dot(x, y) / (MathF.Sqrt(TensorPrimitives.SumOfSquares(x)) * MathF.Sqrt(TensorPrimitives.SumOfSquares(y)). + /// + /// + /// If any element in either input tensor is equal to , , or , + /// NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + return CosineSimilarityCore(x, y); + } + + /// Computes the distance between two points, specified as non-empty, equal-length tensors of single-precision floating-point numbers, in Euclidean space. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The Euclidean distance. + /// Length of must be same as length of . + /// and must not be empty. + /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span<float> difference = ...; + /// TensorPrimitives.Subtract(x, y, difference); + /// float result = MathF.Sqrt(TensorPrimitives.SumOfSquares(difference)); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// If any element in either input tensor is equal to , NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Distance(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return MathF.Sqrt(Aggregate(x, y)); + } + + /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] / [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] / . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Divide(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the dot product of two tensors containing single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The dot product. + /// Length of must be same as length of . + /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span<float> products = ...; + /// TensorPrimitives.Multiply(x, y, products); + /// float result = TensorPrimitives.Sum(products); + /// + /// but without requiring additional temporary storage for the intermediate products. It corresponds to the dot method defined by BLAS1. + /// + /// + /// If any of the input elements is equal to , the resulting value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Dot(ReadOnlySpan x, ReadOnlySpan y) => + Aggregate(x, y); + + /// Computes the element-wise result of raising e to the single-precision floating-point number powers in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Exp([i]). + /// + /// + /// If a value equals or , the result stored into the corresponding destination location is set to NaN. + /// If a value equals , the result stored into the corresponding destination location is set to 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Exp(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Searches for the index of the largest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The index of the maximum element in , or -1 if is empty. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + /// is present, the index of the first is returned. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMax(ReadOnlySpan x) + { + if (x.IsEmpty) + { + return -1; + } + + return IndexOfMinMaxCore(x); + } + + /// Searches for the index of the single-precision floating-point number with the largest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The index of the element in with the largest magnitude (absolute value), or -1 if is empty. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + /// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the positive value is considered to have the larger magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMaxMagnitude(ReadOnlySpan x) + { + if (x.IsEmpty) + { + return -1; + } + + return IndexOfMinMaxCore(x); + } + + /// Searches for the index of the smallest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The index of the minimum element in , or -1 if is empty. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value equal to + /// is present, the index of the first is returned. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMin(ReadOnlySpan x) + { + if (x.IsEmpty) + { + return -1; + } + + return IndexOfMinMaxCore(x); + } + + /// Searches for the index of the single-precision floating-point number with the smallest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The index of the element in with the smallest magnitude (absolute value), or -1 if is empty. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to + /// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMinMagnitude(ReadOnlySpan x) + { + if (x.IsEmpty) + { + return -1; + } + + return IndexOfMinMaxCore(x); + } + + /// Computes the element-wise natural (base e) logarithm of single-precision floating-point numbers in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its natural logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the element-wise base 2 logarithm of single-precision floating-point numbers in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log2([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its natural logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log2(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Searches for the largest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The maximum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + /// is present, the first is returned. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Max(ReadOnlySpan x) => + MinMaxCore(x); + + /// Computes the element-wise maximum of the single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Max([i], [i]). + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , + /// that value is stored as the result. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Max(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Searches for the single-precision floating-point number with the largest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the largest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + /// is present, the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the positive value is considered to have the larger magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float MaxMagnitude(ReadOnlySpan x) => + MinMaxCore(x); + + /// Computes the element-wise single-precision floating-point number with the largest magnitude in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MaxMagnitude([i], [i]). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Searches for the smallest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The minimum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value is equal to + /// is present, the first is returned. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Min(ReadOnlySpan x) => + MinMaxCore(x); + + /// Computes the element-wise minimum of the single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Max([i], [i]). + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , + /// that value is stored as the result. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Min(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Searches for the single-precision floating-point number with the smallest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the smallest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to + /// is present, the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float MinMagnitude(ReadOnlySpan x) => + MinMaxCore(x); + + /// Computes the element-wise single-precision floating-point number with the smallest magnitude in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MinMagnitude([i], [i]). + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If either value is equal to , + /// that value is stored as the result. If the two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * . + /// It corresponds to the scal method defined by BLAS1. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of and length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); + + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + . + /// It corresponds to the axpy method defined by BLAS1. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); + + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * ) + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); + + /// Computes the element-wise negation of each single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = -[i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Negate(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the Euclidean norm of the specified tensor of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The norm. + /// + /// + /// This method effectively computes MathF.Sqrt(TensorPrimitives.SumOfSquares(x)). + /// This is often referred to as the Euclidean norm or L2 norm. + /// It corresponds to the nrm2 method defined by BLAS1. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Norm(ReadOnlySpan x) => + MathF.Sqrt(SumOfSquares(x)); + + /// Computes the product of all elements in the specified non-empty tensor of single-precision floating-point numbers. + /// The tensor, represented as a span. + /// The result of multiplying all elements in . + /// Length of must be greater than zero. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Product(ReadOnlySpan x) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return Aggregate(x); + } + + /// Computes the product of the element-wise differences of the single-precision floating-point numbers in the specified non-empty tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The result of multiplying the element-wise subtraction of the elements in the second tensor from the first tensor. + /// Length of both input spans must be greater than zero. + /// and must have the same length. + /// + /// + /// This method effectively computes: + /// + /// Span<float> differences = ...; + /// TensorPrimitives.Subtract(x, y, differences); + /// float result = TensorPrimitives.Product(differences); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return Aggregate(x, y); + } + + /// Computes the product of the element-wise sums of the single-precision floating-point numbers in the specified non-empty tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The result of multiplying the element-wise additions of the elements in each tensor. + /// Length of both input spans must be greater than zero. + /// and must have the same length. + /// + /// + /// This method effectively computes: + /// + /// Span<float> sums = ...; + /// TensorPrimitives.Add(x, y, sums); + /// float result = TensorPrimitives.Product(sums); + /// + /// but without requiring additional temporary storage for the intermediate sums. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + return Aggregate(x, y); + } + + /// Computes the element-wise sigmoid function on the specified non-empty tensor of single-precision floating-point numbers. + /// The tensor, represented as a span. + /// The destination tensor. + /// Destination is too short. + /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = 1f / (1f + .Exp(-[i])). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sigmoid(ReadOnlySpan x, Span destination) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + InvokeSpanIntoSpan(x, destination); + } + + /// Computes the element-wise hyperbolic sine of each single-precision floating-point radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Sinh([i]). + /// + /// + /// If a value is equal to , , or , + /// the corresponding destination location is set to that value. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sinh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers. + /// The tensor, represented as a span. + /// The destination tensor. + /// Destination is too short. + /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes a sum of MathF.Exp(x[i]) for all elements in . + /// It then effectively computes [i] = MathF.Exp([i]) / sum. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void SoftMax(ReadOnlySpan x, Span destination) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + float expSum = Aggregate(x); + + InvokeSpanScalarIntoSpan(x, expSum, destination); + } + + /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] - [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] - . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Subtract(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the sum of all elements in the specified tensor of single-precision floating-point numbers. + /// The tensor, represented as a span. + /// The result of adding all elements in , or zero if is empty. + /// + /// + /// If any of the values in the input is equal to , the result is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Sum(ReadOnlySpan x) => + Aggregate(x); + + /// Computes the sum of the absolute values of every element in the specified tensor of single-precision floating-point numbers. + /// The tensor, represented as a span. + /// The result of adding the absolute value of every element in , or zero if is empty. + /// + /// + /// This method effectively computes: + /// + /// Span<float> absoluteValues = ...; + /// TensorPrimitives.Abs(x, absoluteValues); + /// float result = TensorPrimitives.Sum(absoluteValues); + /// + /// but without requiring intermediate storage for the absolute values. It corresponds to the asum method defined by BLAS1. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float SumOfMagnitudes(ReadOnlySpan x) => + Aggregate(x); + + /// Computes the sum of the square of every element in the specified tensor of single-precision floating-point numbers. + /// The tensor, represented as a span. + /// The result of adding the square of every element in , or zero if is empty. + /// + /// + /// This method effectively computes: + /// + /// Span<float> squaredValues = ...; + /// TensorPrimitives.Multiply(x, x, squaredValues); + /// float result = TensorPrimitives.Sum(squaredValues); + /// + /// but without requiring intermediate storage for the squared values. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float SumOfSquares(ReadOnlySpan x) => + Aggregate(x); + + /// Computes the element-wise hyperbolic tangent of each single-precision floating-point radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Tanh([i]). + /// + /// + /// If a value is equal to , the corresponding destination location is set to -1. + /// If a value is equal to , the corresponding destination location is set to 1. + /// If a value is equal to , the corresponding destination location is set to NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Tanh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Throws an exception if the and spans overlap and don't begin at the same memory location. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ValidateInputOutputSpanNonOverlapping(ReadOnlySpan input, Span output) + { + if (!Unsafe.AreSame(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(output)) && + input.Overlaps(output)) + { + ThrowHelper.ThrowArgument_InputAndDestinationSpanMustNotOverlap(); + } + } + + /// Mask used to handle alignment elements before vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. The Nth row should be used to handle N alignment elements at the + /// beginning of the input, where elements in the vector after that will be zero'd. + /// + /// There actually exists 17 rows in the table with the last row being a repeat of the first. This is + /// done because it allows the main algorithms to use a simplified algorithm when computing the amount + /// of misalignment where we always skip the first 16 elements, even if already aligned, so we don't + /// double process them. This allows us to avoid an additional branch. + /// + private static ReadOnlySpan AlignmentUInt32Mask_16x16 => + [ + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + ]; + + /// Mask used to handle remaining elements after vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the + /// end of the input, where elements in the vector prior to that will be zero'd. + /// + /// Much as with the AlignmentMask table, we actually have 17 rows where the last row is a repeat of + /// the first. Doing this allows us to avoid an additional branch and instead to always process the + /// last 16 elements via a conditional select instead. + /// + private static ReadOnlySpan RemainderUInt32Mask_16x16 => + [ + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + ]; + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs new file mode 100644 index 00000000000000..498e4b58da77ca --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -0,0 +1,11607 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.X86; + +namespace System.Numerics.Tensors +{ + public static unsafe partial class TensorPrimitives + { + /// Defines the threshold, in bytes, at which non-temporal stores will be used. + /// + /// A non-temporal store is one that allows the CPU to bypass the cache when writing to memory. + /// + /// This can be beneficial when working with large amounts of memory where the writes would otherwise + /// cause large amounts of repeated updates and evictions. The hardware optimization manuals recommend + /// the threshold to be roughly half the size of the last level of on-die cache -- that is, if you have approximately + /// 4MB of L3 cache per core, you'd want this to be approx. 1-2MB, depending on if hyperthreading was enabled. + /// + /// However, actually computing the amount of L3 cache per core can be tricky or error prone. Native memcpy + /// algorithms use a constant threshold that is typically around 256KB and we match that here for simplicity. This + /// threshold accounts for most processors in the last 10-15 years that had approx. 1MB L3 per core and support + /// hyperthreading, giving a per core last level cache of approx. 512KB. + /// + private const nuint NonTemporalByteThreshold = 256 * 1024; + + /// + /// Copies to , converting each + /// value to its nearest representable half-precision floating-point value. + /// + /// The source span from which to copy values. + /// The destination span into which the converted values should be written. + /// Destination is too short. + /// + /// + /// This method effectively computes [i] = (Half)[i]. + /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// + /// + public static void ConvertToHalf(ReadOnlySpan source, Span destination) + { + if (source.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref float sourceRef = ref MemoryMarshal.GetReference(source); + ref ushort destinationRef = ref Unsafe.As(ref MemoryMarshal.GetReference(destination)); + int i = 0, twoVectorsFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + twoVectorsFromEnd = source.Length - (Vector512.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector512 lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + Vector512 upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512.Count))); + Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector512.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector512.Count * 2); + + Vector512 lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + Vector512 upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512.Count))); + Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + twoVectorsFromEnd = source.Length - (Vector256.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector256 lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + Vector256 upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256.Count))); + Vector256 halfs = Vector256.Narrow(lower, upper); + halfs.StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector256.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector256.Count * 2); + + Vector256 lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + Vector256 upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256.Count))); + Vector256.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + twoVectorsFromEnd = source.Length - (Vector128.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector128 lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + Vector128 upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128.Count))); + Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + + i += Vector128.Count * 2; + } + while (i <= twoVectorsFromEnd); + + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector128.Count * 2); + + Vector128 lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + Vector128 upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128.Count))); + Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } + + return; + } + } + + while (i < source.Length) + { + Unsafe.Add(ref destinationRef, i) = BitConverter.HalfToUInt16Bits((Half)Unsafe.Add(ref sourceRef, i)); + i++; + } + + // This implements a vectorized version of the `explicit operator Half(float value) operator`. + // See detailed description of the algorithm used here: + // https://github.com/dotnet/runtime/blob/ca8d6f0420096831766ec11c7d400e4f7ccc7a34/src/libraries/System.Private.CoreLib/src/System/Half.cs#L606-L714 + // The cast operator converts a float to a Half represented as a UInt32, then narrows to a UInt16, and reinterpret casts to Half. + // This does the same, with an input VectorXx and an output VectorXx. + // Loop handling two input vectors at a time; each input float is double the size of each output Half, + // so we need two vectors of floats to produce one vector of Halfs. Half isn't supported in VectorXx, + // so we convert the VectorXx to a VectorXx, and the caller then uses this twice, narrows the combination + // into a VectorXx, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`. + +#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948 + const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding + const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1 + const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask + const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2 + const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half + const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half + const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float +#pragma warning restore IDE0059 + + static Vector128 SingleToHalfAsWidenedUInt32_Vector128(Vector128 value) + { + Vector128 bitValue = value.AsUInt32(); + + // Extract sign bit + Vector128 sign = Vector128.ShiftRightLogical(bitValue & Vector128.Create(SingleSignMask), 16); + + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector128 realMask = Vector128.Equals(value, value).AsUInt32(); + + // Clear sign bit + value = Vector128.Abs(value); + + // Rectify values that are Infinity in Half. + value = Vector128.Min(Vector128.Create(MaxHalfValueBelowInfinity), value); + + // Rectify lower exponent + Vector128 exponentOffset0 = Vector128.Max(value, Vector128.Create(MinExp).AsSingle()).AsUInt32(); + + // Extract exponent + exponentOffset0 &= Vector128.Create(SingleBiasedExponentMask); + + // Add exponent by 13 + exponentOffset0 += Vector128.Create(Exponent13); + + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); + + // Only exponent bits will be modified if NaN + Vector128 maskedHalfExponentForNaN = ~realMask & Vector128.Create(ExponentMask); + + // Subtract exponent by 126 + bitValue -= Vector128.Create(Exponent126); + + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector128 newExponent = Vector128.ShiftRightLogical(bitValue, 13); + + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; + + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; + + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; + + // Merge sign bit with possible NaN exponent + Vector128 signAndMaskedExponent = maskedHalfExponentForNaN | sign; + + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; + + // The final result + return bitValue; + } + + static Vector256 SingleToHalfAsWidenedUInt32_Vector256(Vector256 value) + { + Vector256 bitValue = value.AsUInt32(); + + // Extract sign bit + Vector256 sign = Vector256.ShiftRightLogical(bitValue & Vector256.Create(SingleSignMask), 16); + + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector256 realMask = Vector256.Equals(value, value).AsUInt32(); + + // Clear sign bit + value = Vector256.Abs(value); + + // Rectify values that are Infinity in Half. + value = Vector256.Min(Vector256.Create(MaxHalfValueBelowInfinity), value); + + // Rectify lower exponent + Vector256 exponentOffset0 = Vector256.Max(value, Vector256.Create(MinExp).AsSingle()).AsUInt32(); + + // Extract exponent + exponentOffset0 &= Vector256.Create(SingleBiasedExponentMask); + + // Add exponent by 13 + exponentOffset0 += Vector256.Create(Exponent13); + + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); + + // Only exponent bits will be modified if NaN + Vector256 maskedHalfExponentForNaN = ~realMask & Vector256.Create(ExponentMask); + + // Subtract exponent by 126 + bitValue -= Vector256.Create(Exponent126); + + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector256 newExponent = Vector256.ShiftRightLogical(bitValue, 13); + + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; + + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; + + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; + + // Merge sign bit with possible NaN exponent + Vector256 signAndMaskedExponent = maskedHalfExponentForNaN | sign; + + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; + + // The final result + return bitValue; + } + +#if NET8_0_OR_GREATER + static Vector512 SingleToHalfAsWidenedUInt32_Vector512(Vector512 value) + { + Vector512 bitValue = value.AsUInt32(); + + // Extract sign bit + Vector512 sign = Vector512.ShiftRightLogical(bitValue & Vector512.Create(SingleSignMask), 16); + + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector512 realMask = Vector512.Equals(value, value).AsUInt32(); + + // Clear sign bit + value = Vector512.Abs(value); + + // Rectify values that are Infinity in Half. + value = Vector512.Min(Vector512.Create(MaxHalfValueBelowInfinity), value); + + // Rectify lower exponent + Vector512 exponentOffset0 = Vector512.Max(value, Vector512.Create(MinExp).AsSingle()).AsUInt32(); + + // Extract exponent + exponentOffset0 &= Vector512.Create(SingleBiasedExponentMask); + + // Add exponent by 13 + exponentOffset0 += Vector512.Create(Exponent13); + + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); + + // Only exponent bits will be modified if NaN + Vector512 maskedHalfExponentForNaN = ~realMask & Vector512.Create(ExponentMask); + + // Subtract exponent by 126 + bitValue -= Vector512.Create(Exponent126); + + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector512 newExponent = Vector512.ShiftRightLogical(bitValue, 13); + + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; + + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; + + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; + + // Merge sign bit with possible NaN exponent + Vector512 signAndMaskedExponent = maskedHalfExponentForNaN | sign; + + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; + + // The final result + return bitValue; + } +#endif + } + + /// + /// Copies to , converting each half-precision + /// floating-point value to its nearest representable value. + /// + /// The source span from which to copy values. + /// The destination span into which the converted values should be written. + /// Destination is too short. + /// + /// + /// This method effectively computes [i] = (float)[i]. + /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// + /// + public static void ConvertToSingle(ReadOnlySpan source, Span destination) + { + if (source.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ref short sourceRef = ref Unsafe.As(ref MemoryMarshal.GetReference(source)); + ref float destinationRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = source.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != source.Length) + { + i = source.Length - Vector512.Count; + + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = source.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != source.Length) + { + i = source.Length - Vector256.Count; + + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) + { + oneVectorFromEnd = source.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != source.Length) + { + i = source.Length - Vector128.Count; + + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); + } + + return; + } + } + + while (i < source.Length) + { + Unsafe.Add(ref destinationRef, i) = (float)Unsafe.As(ref Unsafe.Add(ref sourceRef, i)); + i++; + } + + // This implements a vectorized version of the `explicit operator float(Half value) operator`. + // See detailed description of the algorithm used here: + // https://github.com/dotnet/runtime/blob/3bf40a378f00cb5bf18ff62796bc7097719b974c/src/libraries/System.Private.CoreLib/src/System/Half.cs#L1010-L1040 + // The cast operator converts a Half represented as uint to a float. This does the same, with an input VectorXx and an output VectorXx. + // The VectorXx is created by reading a vector of Halfs as a VectorXx then widened to two VectorXxs and cast to VectorXxs. + // We loop handling one input vector at a time, producing two output float vectors. + +#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948 + const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single + const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) + const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single + const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half + const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half +#pragma warning restore IDE0059 + + static Vector128 HalfAsWidenedUInt32ToSingle_Vector128(Vector128 value) + { + // Extract sign bit of value + Vector128 sign = value & Vector128.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector128 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector128 offsetExponent = bitValueInProcess & Vector128.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector128 subnormalMask = Vector128.Equals(offsetExponent, Vector128.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector128 infinityOrNaNMask = Vector128.Equals(offsetExponent, Vector128.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector128 maskedExponentLowerBound = subnormalMask & Vector128.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector128 offsetMaskedExponentLowerBound = Vector128.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector128.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector128.ConditionalSelect(Vector128.Equals(infinityOrNaNMask, Vector128.Zero), + offsetMaskedExponentLowerBound, + Vector128.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector128.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector128 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); + } + + static Vector256 HalfAsWidenedUInt32ToSingle_Vector256(Vector256 value) + { + // Extract sign bit of value + Vector256 sign = value & Vector256.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector256 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector256 offsetExponent = bitValueInProcess & Vector256.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector256 subnormalMask = Vector256.Equals(offsetExponent, Vector256.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector256 infinityOrNaNMask = Vector256.Equals(offsetExponent, Vector256.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector256 maskedExponentLowerBound = subnormalMask & Vector256.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector256 offsetMaskedExponentLowerBound = Vector256.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector256.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector256.ConditionalSelect(Vector256.Equals(infinityOrNaNMask, Vector256.Zero), + offsetMaskedExponentLowerBound, + Vector256.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector256.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector256 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); + } + +#if NET8_0_OR_GREATER + static Vector512 HalfAsWidenedUInt32ToSingle_Vector512(Vector512 value) + { + // Extract sign bit of value + Vector512 sign = value & Vector512.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector512 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector512 offsetExponent = bitValueInProcess & Vector512.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector512 subnormalMask = Vector512.Equals(offsetExponent, Vector512.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector512 infinityOrNaNMask = Vector512.Equals(offsetExponent, Vector512.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector512 maskedExponentLowerBound = subnormalMask & Vector512.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector512 offsetMaskedExponentLowerBound = Vector512.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector512.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector512.ConditionalSelect(Vector512.Equals(infinityOrNaNMask, Vector512.Zero), + offsetMaskedExponentLowerBound, + Vector512.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector512.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector512 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); + } +#endif + } + + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. + /// Assumes arguments have already been validated to be non-empty and equal length. + private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) + { + // Compute the same as: + // TensorPrimitives.Dot(x, y) / (Math.Sqrt(TensorPrimitives.SumOfSquares(x)) * Math.Sqrt(TensorPrimitives.SumOfSquares(y))) + // but only looping over each span once. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector512 dotProductVector = Vector512.Zero; + Vector512 xSumOfSquaresVector = Vector512.Zero; + Vector512 ySumOfSquaresVector = Vector512.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = 0; + do + { + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count)); + + Vector512 remainderMask = CreateRemainderMaskSingleVector512(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector512.Sum(dotProductVector) / + (MathF.Sqrt(Vector512.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector512.Sum(ySumOfSquaresVector))); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector256 dotProductVector = Vector256.Zero; + Vector256 xSumOfSquaresVector = Vector256.Zero; + Vector256 ySumOfSquaresVector = Vector256.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = 0; + do + { + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count)); + + Vector256 remainderMask = CreateRemainderMaskSingleVector256(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector256.Sum(dotProductVector) / + (MathF.Sqrt(Vector256.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector256.Sum(ySumOfSquaresVector))); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector128 dotProductVector = Vector128.Zero; + Vector128 xSumOfSquaresVector = Vector128.Zero; + Vector128 ySumOfSquaresVector = Vector128.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = 0; + do + { + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count)); + + Vector128 remainderMask = CreateRemainderMaskSingleVector128(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector128.Sum(dotProductVector) / + (MathF.Sqrt(Vector128.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector128.Sum(ySumOfSquaresVector))); + } + + // Vectorization isn't supported or there are too few elements to vectorize. + // Use a scalar implementation. + float dotProduct = 0f, xSumOfSquares = 0f, ySumOfSquares = 0f; + for (int i = 0; i < x.Length; i++) + { + dotProduct = MathF.FusedMultiplyAdd(x[i], y[i], dotProduct); + xSumOfSquares = MathF.FusedMultiplyAdd(x[i], x[i], xSumOfSquares); + ySumOfSquares = MathF.FusedMultiplyAdd(y[i], y[i], ySumOfSquares); + } + + // Sum(X * Y) / (|X| * |Y|) + return + dotProduct / + (MathF.Sqrt(xSumOfSquares) * MathF.Sqrt(ySumOfSquares)); + } + + /// Performs an aggregation over all elements in to produce a single-precision floating-point value. + /// Specifies the transform operation that should be applied to each element loaded from . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied after the transform is applied to each element. + /// + private static float Aggregate( + ReadOnlySpan x) + where TTransformOperator : struct, IUnaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized512Small(ref xRef, remainder); + } + + return result; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized256Small(ref xRef, remainder); + } + + return result; + } + + if (Vector128.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized128Small(ref xRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available. + // It requires no branches to hit. + + return SoftwareFallback(ref xRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length) + { + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, i))); + } + + return result; + } + + static float Vectorized128(ref float xRef, nuint remainder) + { + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized256Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized512Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } +#endif + } + + /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied to the results of the binary operations on the pair-wise values. + /// + private static float Aggregate( + ReadOnlySpan x, ReadOnlySpan y) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized512Small(ref xRef, ref yRef, remainder); + } + + return result; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized256Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + if (Vector128.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized128Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, ref yRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, ref float yRef, nuint length) + { + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i))); + } + + return result; + } + + static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) + { + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized256Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } +#endif + } + + /// + /// This is the same as + /// with an identity transform, except it early exits on NaN. + /// + private static float MinMaxCore(ReadOnlySpan x) + where TMinMaxOperator : struct, IAggregationOperator + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector512 result = Vector512.LoadUnsafe(ref xRef, 0), current; + if (!Vector512.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector512.LoadUnsafe(ref xRef, (uint)i); + if (!Vector512.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector512.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + if (!Vector512.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector512.ConditionalSelect( + Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector256 result = Vector256.LoadUnsafe(ref xRef, 0), current; + if (!Vector256.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector256.LoadUnsafe(ref xRef, (uint)i); + if (!Vector256.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector256.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + if (!Vector256.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector256.ConditionalSelect( + Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector128 result = Vector128.LoadUnsafe(ref xRef, 0), current; + if (!Vector128.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector128.LoadUnsafe(ref xRef, (uint)i); + if (!Vector128.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector128.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + if (!Vector128.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector128.ConditionalSelect( + Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } + + // Scalar path used when either vectorization is not supported or the input is too small to vectorize. + { + float result = x[0]; + if (float.IsNaN(result)) + { + return result; + } + + for (int i = 1; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return current; + } + + result = TMinMaxOperator.Invoke(result, current); + } + + return result; + } + } + + private static int IndexOfMinMaxCore(ReadOnlySpan x) where TIndexOfMinMax : struct, IIndexOfOperator + { + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the index of the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector512 resultIndex = Vector512.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + Vector512 curIndex = resultIndex; + Vector512 increment = Vector512.Create(Vector512.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector512 result = Vector512.LoadUnsafe(ref xRef); + Vector512 current; + + Vector512 nanMask = ~Vector512.Equals(result, result); + if (nanMask != Vector512.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector512.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector512.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + curIndex += Vector512.Create(x.Length - i); + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector256 resultIndex = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7); + Vector256 curIndex = resultIndex; + Vector256 increment = Vector256.Create(Vector256.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector256 result = Vector256.LoadUnsafe(ref xRef); + Vector256 current; + + Vector256 nanMask = ~Vector256.Equals(result, result); + if (nanMask != Vector256.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector256.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector256.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + curIndex += Vector256.Create(x.Length - i); + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector128 resultIndex = Vector128.Create(0, 1, 2, 3); + Vector128 curIndex = resultIndex; + Vector128 increment = Vector128.Create(Vector128.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector128 result = Vector128.LoadUnsafe(ref xRef); + Vector128 current; + + Vector128 nanMask = ~Vector128.Equals(result, result); + if (nanMask != Vector128.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector128.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector128.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + curIndex += Vector128.Create(x.Length - i); + + current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } + + // Scalar path used when either vectorization is not supported or the input is too small to vectorize. + float curResult = x[0]; + int curIn = 0; + if (float.IsNaN(curResult)) + { + return curIn; + } + + for (int i = 1; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return i; + } + + curIn = TIndexOfMinMax.Invoke(ref curResult, current, curIn, i); + } + + return curIn; + } + + private static int IndexOfFirstMatch(Vector128 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } + + private static int IndexOfFirstMatch(Vector256 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } + +#if NET8_0_OR_GREATER + private static int IndexOfFirstMatch(Vector512 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } +#endif + + /// Performs an element-wise operation on and writes the results to . + /// Specifies the operation to perform on each element loaded from . + private static void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination) + where TUnaryOperator : struct, IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and . + /// + private static void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination) + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from . + /// It is not used with . + /// + /// + /// Specifies the operation to perform on the transformed value from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TTransformOperator : struct, IUnaryOperator + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, y, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)), + y); + } + } + + static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)), + yVec); + Vector512 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))), + yVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + Vector256.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from , , + /// and . + /// + private static void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and + /// with . + /// + private static void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 zVec = Vector512.Create(z); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + zVec); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + zVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + yVec, + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.Create(y), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 FusedMultiplyAdd(Vector128 x, Vector128 y, Vector128 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + if (AdvSimd.IsSupported) + { + return AdvSimd.FusedMultiplyAdd(addend, x, y); + } + + return (x * y) + addend; + } + + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 FusedMultiplyAdd(Vector256 x, Vector256 y, Vector256 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } + +#if NET8_0_OR_GREATER + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 FusedMultiplyAdd(Vector512 x, Vector512 y, Vector512 addend) + { + if (Avx512F.IsSupported) + { + return Avx512F.FusedMultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } +#endif + + /// Aggregates all of the elements in the into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector128 x) where TAggregate : struct, IBinaryOperator + { + // We need to do log2(count) operations to compute the total sum + + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(2, 3, 0, 1))); + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(1, 0, 3, 2))); + + return x.ToScalar(); + } + + /// Aggregates all of the elements in the into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector256 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); + +#if NET8_0_OR_GREATER + /// Aggregates all of the elements in the into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector512 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); +#endif + + /// Gets whether the specified is negative. + private static bool IsNegative(float f) => float.IsNegative(f); + + /// Gets whether each specified is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 IsNegative(Vector128 vector) => + Vector128.LessThan(vector.AsInt32(), Vector128.Zero).AsSingle(); + + /// Gets whether each specified is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 IsNegative(Vector256 vector) => + Vector256.LessThan(vector.AsInt32(), Vector256.Zero).AsSingle(); + +#if NET8_0_OR_GREATER + /// Gets whether each specified is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 IsNegative(Vector512 vector) => + Vector512.LessThan(vector.AsInt32(), Vector512.Zero).AsSingle(); +#endif + + /// Gets whether the specified is positive. + private static bool IsPositive(float f) => float.IsPositive(f); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 IsPositive(Vector128 vector) => + Vector128.GreaterThan(vector.AsInt32(), Vector128.AllBitsSet).AsSingle(); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 IsPositive(Vector256 vector) => + Vector256.GreaterThan(vector.AsInt32(), Vector256.AllBitsSet).AsSingle(); + +#if NET8_0_OR_GREATER + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 IsPositive(Vector512 vector) => + Vector512.GreaterThan(vector.AsInt32(), Vector512.AllBitsSet).AsSingle(); +#endif + + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector128 vector) + { + Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN index value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector128 vector, Vector128 index) + { + Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector256 vector) + { + Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN index value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector256 vector, Vector256 index) + { + Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); + } + +#if NET8_0_OR_GREATER + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector512 vector) + { + Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector512 vector, Vector512 index) + { + Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); + } +#endif + + /// Gets the base 2 logarithm of . + private static float Log2(float x) => MathF.Log2(x); + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 CreateAlignmentMaskSingleVector128(int count) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first four floats in the row + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateAlignmentMaskSingleVector256(int count) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first eight floats in the row + +#if NET8_0_OR_GREATER + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateAlignmentMaskSingleVector512(int count) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // all sixteen floats in the row +#endif + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 CreateRemainderMaskSingleVector128(int count) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((count * 16) + 12)); // last four floats in the row + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateRemainderMaskSingleVector256(int count) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((count * 16) + 8)); // last eight floats in the row + +#if NET8_0_OR_GREATER + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateRemainderMaskSingleVector512(int count) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)(count * 16)); // all sixteen floats in the row +#endif + + /// x + y + private readonly struct AddOperator : IAggregationOperator + { + public static float Invoke(float x, float y) => x + y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; +#endif + + public static float Invoke(Vector128 x) => Vector128.Sum(x); + public static float Invoke(Vector256 x) => Vector256.Sum(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => Vector512.Sum(x); +#endif + + public static float IdentityValue => 0; + } + + /// x - y + private readonly struct SubtractOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x - y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; +#endif + } + + /// (x - y) * (x - y) + private readonly struct SubtractSquaredOperator : IBinaryOperator + { + public static float Invoke(float x, float y) + { + float tmp = x - y; + return tmp * tmp; + } + + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 tmp = x - y; + return tmp * tmp; + } + + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 tmp = x - y; + return tmp * tmp; + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 tmp = x - y; + return tmp * tmp; + } +#endif + } + + /// x * y + private readonly struct MultiplyOperator : IAggregationOperator + { + public static float Invoke(float x, float y) => x * y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + + public static float IdentityValue => 1; + } + + /// x / y + private readonly struct DivideOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x / y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; +#endif + } + + /// MathF.Max(x, y) (but NaNs may not be propagated) + private readonly struct MaxOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => + x == y ? + (IsNegative(x) ? y : x) : + (y > x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Max(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.Max(x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.Max(x, y)); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.Max(x, y)); +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + private interface IIndexOfOperator + { + static abstract int Invoke(ref float result, float current, int resultIndex, int curIndex); + static abstract int Invoke(Vector128 result, Vector128 resultIndex); + static abstract void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex); + static abstract int Invoke(Vector256 result, Vector256 resultIndex); + static abstract void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex); +#if NET8_0_OR_GREATER + static abstract int Invoke(Vector512 result, Vector512 resultIndex); + static abstract void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex); +#endif + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 greaterThanMask = Vector128.GreaterThan(max, current); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 greaterThanMask = Vector256.GreaterThan(max, current); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 greaterThanMask = Vector512.GreaterThan(max, current); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (current > result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 maxMag = Vector128.Abs(max), currentMag = Vector128.Abs(current); + + Vector128 greaterThanMask = Vector128.GreaterThan(maxMag, currentMag); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 maxMag = Vector256.Abs(max), currentMag = Vector256.Abs(current); + + Vector256 greaterThanMask = Vector256.GreaterThan(maxMag, currentMag); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 maxMag = Vector512.Abs(max), currentMag = Vector512.Abs(current); + Vector512 greaterThanMask = Vector512.GreaterThan(maxMag, currentMag); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + /// Returns the index of MathF.Min(x, y) + private readonly struct IndexOfMinOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 lessThanMask = Vector128.LessThan(result, current); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 lessThanMask = Vector256.LessThan(result, current); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 lessThanMask = Vector512.LessThan(result, current); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 minMag = Vector128.Abs(result), currentMag = Vector128.Abs(current); + + Vector128 lessThanMask = Vector128.LessThan(minMag, currentMag); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 minMag = Vector256.Abs(result), currentMag = Vector256.Abs(current); + + Vector256 lessThanMask = Vector256.LessThan(minMag, currentMag); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 minMag = Vector512.Abs(result), currentMag = Vector512.Abs(current); + + Vector512 lessThanMask = Vector512.LessThan(minMag, currentMag); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMinAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + /// MathF.Max(x, y) + private readonly struct MaxPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.Max(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Max(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.Max(x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.Max(x, y)), + y), + x); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.Max(x, y)), + y), + x); +#endif + } + + /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) + private readonly struct MaxMagnitudeOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + xMag == yMag ? + (IsNegative(x) ? y : x) : + (xMag > yMag ? x : y); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(xMag, yMag), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.ConditionalSelect(Vector128.GreaterThan(xMag, yMag), x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)); + } +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// Operator to get x or y based on which has the larger MathF.Abs + private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.MaxMagnitude(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.ConditionalSelect(Vector128.GreaterThan(yMag, xMag), y, x)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)), + y), + x); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)), + y), + x); + } +#endif + } + + /// MathF.Min(x, y) (but NaNs may not be propagated) + private readonly struct MinOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => + x == y ? + (IsNegative(y) ? y : x) : + (y < x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Min(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(y), y, x), + Vector128.Min(x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(y), y, x), + Vector256.Min(x, y)); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(y), y, x), + Vector512.Min(x, y)); +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// MathF.Min(x, y) + private readonly struct MinPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.Min(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Min(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.Min(x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.Min(x, y)), + y), + x); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.Min(x, y)), + y), + x); +#endif + } + + /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) + private readonly struct MinMagnitudeOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag == yMag ? + (IsNegative(y) ? y : x) : + (yMag < xMag ? y : x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(y), y, x), + Vector128.ConditionalSelect(Vector128.LessThan(yMag, xMag), y, x)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), + Vector256.ConditionalSelect(IsNegative(y), y, x), + Vector256.ConditionalSelect(Vector256.LessThan(yMag, xMag), y, x)); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), + Vector512.ConditionalSelect(IsNegative(y), y, x), + Vector512.ConditionalSelect(Vector512.LessThan(yMag, xMag), y, x)); + } +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// Operator to get x or y based on which has the smaller MathF.Abs + private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.MinMagnitude(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.ConditionalSelect(Vector128.LessThan(xMag, yMag), x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.ConditionalSelect(Vector256.LessThan(xMag, yMag), x, y)), + y), + x); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.ConditionalSelect(Vector512.LessThan(xMag, yMag), x, y)), + y), + x); + } +#endif + } + + /// -x + private readonly struct NegateOperator : IUnaryOperator + { + public static float Invoke(float x) => -x; + public static Vector128 Invoke(Vector128 x) => -x; + public static Vector256 Invoke(Vector256 x) => -x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => -x; +#endif + } + + /// (x + y) * z + private readonly struct AddMultiplyOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x + y) * z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; +#endif + } + + /// (x * y) + z + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x * y) + z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; +#endif + } + + /// x + private readonly struct IdentityOperator : IUnaryOperator + { + public static float Invoke(float x) => x; + public static Vector128 Invoke(Vector128 x) => x; + public static Vector256 Invoke(Vector256 x) => x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => x; +#endif + } + + /// x * x + private readonly struct SquaredOperator : IUnaryOperator + { + public static float Invoke(float x) => x * x; + public static Vector128 Invoke(Vector128 x) => x * x; + public static Vector256 Invoke(Vector256 x) => x * x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => x * x; +#endif + } + + /// MathF.Abs(x) + private readonly struct AbsoluteOperator : IUnaryOperator + { + public static float Invoke(float x) => MathF.Abs(x); + public static Vector128 Invoke(Vector128 x) => Vector128.Abs(x); + public static Vector256 Invoke(Vector256 x) => Vector256.Abs(x); +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => Vector512.Abs(x); +#endif + } + + /// MathF.Exp(x) + private readonly struct ExpOperator : IUnaryOperator + { + // This code is based on `vrs4_expf` from amd/aocl-libm-ose + // Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Implementation Notes: + // 1. Argument Reduction: + // e^x = 2^(x/ln2) --- (1) + // + // Let x/ln(2) = z --- (2) + // + // Let z = n + r , where n is an integer --- (3) + // |r| <= 1/2 + // + // From (1), (2) and (3), + // e^x = 2^z + // = 2^(N+r) + // = (2^N)*(2^r) --- (4) + // + // 2. Polynomial Evaluation + // From (4), + // r = z - N + // 2^r = C1 + C2*r + C3*r^2 + C4*r^3 + C5 *r^4 + C6*r^5 + // + // 4. Reconstruction + // Thus, + // e^x = (2^N) * (2^r) + + private const uint V_ARG_MAX = 0x42AE0000; + private const uint V_MASK = 0x7FFFFFFF; + + private const float V_EXPF_MIN = -103.97208f; + private const float V_EXPF_MAX = 88.72284f; + + private const double V_EXPF_HUGE = 6755399441055744; + private const double V_TBL_LN2 = 1.4426950408889634; + + private const double C1 = 1.0000000754895704; + private const double C2 = 0.6931472254087585; + private const double C3 = 0.2402210737432219; + private const double C4 = 0.05550297297702539; + private const double C5 = 0.009676036358193323; + private const double C6 = 0.001341000536524434; + + public static float Invoke(float x) => MathF.Exp(x); + + public static Vector128 Invoke(Vector128 x) + { + // Convert x to double precision + (Vector128 xl, Vector128 xu) = Vector128.Widen(x); + + // x * (64.0 / ln(2)) + Vector128 v_tbl_ln2 = Vector128.Create(V_TBL_LN2); + + Vector128 zl = xl * v_tbl_ln2; + Vector128 zu = xu * v_tbl_ln2; + + Vector128 v_expf_huge = Vector128.Create(V_EXPF_HUGE); + + Vector128 dnl = zl + v_expf_huge; + Vector128 dnu = zu + v_expf_huge; + + // n = int (z) + Vector128 nl = dnl.AsUInt64(); + Vector128 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector128 c1 = Vector128.Create(C1); + Vector128 c2 = Vector128.Create(C2); + Vector128 c3 = Vector128.Create(C3); + Vector128 c4 = Vector128.Create(C4); + Vector128 c5 = Vector128.Create(C5); + Vector128 c6 = Vector128.Create(C6); + + Vector128 rl = zl - dnl; + + Vector128 rl2 = rl * rl; + Vector128 rl4 = rl2 * rl2; + + Vector128 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector128 ru = zu - dnu; + + Vector128 ru2 = ru * ru; + Vector128 ru4 = ru2 * ru2; + + Vector128 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector128 ret = Vector128.Narrow( + (polyl.AsUInt64() + Vector128.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector128.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector128.GreaterThanAny(x.AsUInt32() & Vector128.Create(V_MASK), Vector128.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector128 infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX)); + + ret = Vector128.ConditionalSelect( + infinityMask, + Vector128.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ? 0 : x + ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN))); + } + + return ret; + } + + public static Vector256 Invoke(Vector256 x) + { + // Convert x to double precision + (Vector256 xl, Vector256 xu) = Vector256.Widen(x); + + // x * (64.0 / ln(2)) + Vector256 v_tbl_ln2 = Vector256.Create(V_TBL_LN2); + + Vector256 zl = xl * v_tbl_ln2; + Vector256 zu = xu * v_tbl_ln2; + + Vector256 v_expf_huge = Vector256.Create(V_EXPF_HUGE); + + Vector256 dnl = zl + v_expf_huge; + Vector256 dnu = zu + v_expf_huge; + + // n = int (z) + Vector256 nl = dnl.AsUInt64(); + Vector256 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector256 c1 = Vector256.Create(C1); + Vector256 c2 = Vector256.Create(C2); + Vector256 c3 = Vector256.Create(C3); + Vector256 c4 = Vector256.Create(C4); + Vector256 c5 = Vector256.Create(C5); + Vector256 c6 = Vector256.Create(C6); + + Vector256 rl = zl - dnl; + + Vector256 rl2 = rl * rl; + Vector256 rl4 = rl2 * rl2; + + Vector256 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector256 ru = zu - dnu; + + Vector256 ru2 = ru * ru; + Vector256 ru4 = ru2 * ru2; + + Vector256 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector256 ret = Vector256.Narrow( + (polyl.AsUInt64() + Vector256.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector256.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector256.GreaterThanAny(x.AsUInt32() & Vector256.Create(V_MASK), Vector256.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector256 infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX)); + + ret = Vector256.ConditionalSelect( + infinityMask, + Vector256.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ? 0 : x + ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN))); + } + + return ret; + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + // Convert x to double precision + (Vector512 xl, Vector512 xu) = Vector512.Widen(x); + + // x * (64.0 / ln(2)) + Vector512 v_tbl_ln2 = Vector512.Create(V_TBL_LN2); + + Vector512 zl = xl * v_tbl_ln2; + Vector512 zu = xu * v_tbl_ln2; + + Vector512 v_expf_huge = Vector512.Create(V_EXPF_HUGE); + + Vector512 dnl = zl + v_expf_huge; + Vector512 dnu = zu + v_expf_huge; + + // n = int (z) + Vector512 nl = dnl.AsUInt64(); + Vector512 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector512 c1 = Vector512.Create(C1); + Vector512 c2 = Vector512.Create(C2); + Vector512 c3 = Vector512.Create(C3); + Vector512 c4 = Vector512.Create(C4); + Vector512 c5 = Vector512.Create(C5); + Vector512 c6 = Vector512.Create(C6); + + Vector512 rl = zl - dnl; + + Vector512 rl2 = rl * rl; + Vector512 rl4 = rl2 * rl2; + + Vector512 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector512 ru = zu - dnu; + + Vector512 ru2 = ru * ru; + Vector512 ru4 = ru2 * ru2; + + Vector512 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector512 ret = Vector512.Narrow( + (polyl.AsUInt64() + Vector512.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector512.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector512.GreaterThanAny(x.AsUInt32() & Vector512.Create(V_MASK), Vector512.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector512 infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX)); + + ret = Vector512.ConditionalSelect( + infinityMask, + Vector512.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ? 0 : x + ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN))); + } + + return ret; + } +#endif + } + + /// MathF.Cosh(x) + private readonly struct CoshOperator : IUnaryOperator + { + // This code is based on `vrs4_coshf` from amd/aocl-libm-ose + // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // coshf(|x| > 89.415985107421875) = Infinity + // coshf(Infinity) = infinity + // coshf(-Infinity) = infinity + // + // cosh(x) = (exp(x) + exp(-x))/2 + // cosh(-x) = +cosh(x) + // + // checks for special cases + // if ( asint(x) > infinity) return x with overflow exception and + // return x. + // if x is NaN then raise invalid FP operation exception and return x. + // + // coshf = v/2 * exp(x - log(v)) where v = 0x1.0000e8p-1 + + private const float LOGV = 0.693161f; + private const float HALFV = 1.0000138f; + private const float INVV2 = 0.24999309f; + + public static float Invoke(float x) => MathF.Cosh(x); + + public static Vector128 Invoke(Vector128 x) + { + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + return Vector128.Create(HALFV) * (z + (Vector128.Create(INVV2) / z)); + } + + public static Vector256 Invoke(Vector256 x) + { + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + return Vector256.Create(HALFV) * (z + (Vector256.Create(INVV2) / z)); + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + return Vector512.Create(HALFV) * (z + (Vector512.Create(INVV2) / z)); + } +#endif + } + + /// MathF.Sinh(x) + private readonly struct SinhOperator : IUnaryOperator + { + // Same as cosh, but with `z -` rather than `z +`, and with the sign + // flipped on the result based on the sign of the input. + + private const uint SIGN_MASK = 0x7FFFFFFF; + private const float LOGV = 0.693161f; + private const float HALFV = 1.0000138f; + private const float INVV2 = 0.24999309f; + + public static float Invoke(float x) => MathF.Sinh(x); + + public static Vector128 Invoke(Vector128 x) + { + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + Vector128 result = Vector128.Create(HALFV) * (z - (Vector128.Create(INVV2) / z)); + Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); + } + + public static Vector256 Invoke(Vector256 x) + { + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + Vector256 result = Vector256.Create(HALFV) * (z - (Vector256.Create(INVV2) / z)); + Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + Vector512 result = Vector512.Create(HALFV) * (z - (Vector512.Create(INVV2) / z)); + Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); + } +#endif + } + + /// MathF.Tanh(x) + private readonly struct TanhOperator : IUnaryOperator + { + // This code is based on `vrs4_tanhf` from amd/aocl-libm-ose + // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // To compute vrs4_tanhf(v_f32x4_t x) + // Let y = |x| + // If 0 <= y < 0x1.154246p3 + // Let z = e^(-2.0 * y) - 1 -(1) + // + // Using (1), tanhf(y) can be calculated as, + // tanhf(y) = -z / (z + 2.0) + // + // For other cases, call scalar tanhf() + // + // If x < 0, then we use the identity + // tanhf(-x) = -tanhf(x) + + private const uint SIGN_MASK = 0x7FFFFFFF; + + public static float Invoke(float x) => MathF.Tanh(x); + + public static Vector128 Invoke(Vector128 x) + { + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f); + Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32()).AsSingle(); + } + + public static Vector256 Invoke(Vector256 x) + { + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f); + Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32()).AsSingle(); + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f); + Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32()).AsSingle(); + } +#endif + } + + /// MathF.Log(x) + private readonly struct LogOperator : IUnaryOperator + { + // This code is based on `vrs4_logf` from amd/aocl-libm-ose + // Copyright (C) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // logf(x) + // = logf(x) if x ∈ F and x > 0 + // = x if x = qNaN + // = 0 if x = 1 + // = -inf if x = (-0, 0} + // = NaN otherwise + // + // Assumptions/Expectations + // - ULP is derived to be << 4 (always) + // - Some FPU Exceptions may not be available + // - Performance is at least 3x + // + // Implementation Notes: + // 1. Range Reduction: + // x = 2^n*(1+f) .... (1) + // where n is exponent and is an integer + // (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2) + // + // From (1), taking log on both sides + // log(x) = log(2^n * (1+f)) + // = log(2^n) + log(1+f) + // = n*log(2) + log(1+f) .... (3) + // + // let z = 1 + f + // log(z) = log(k) + log(z) - log(k) + // log(z) = log(kz) - log(k) + // + // From (2), range of z is [1, 2) + // by simply dividing range by 'k', z is in [1/k, 2/k) .... (4) + // Best choice of k is the one which gives equal and opposite values + // at extrema +- -+ + // 1 | 2 | + // --- - 1 = - |--- - 1 | + // k | k | .... (5) + // +- -+ + // + // Solving for k, k = 3/2, + // From (4), using 'k' value, range is therefore [-0.3333, 0.3333] + // + // 2. Polynomial Approximation: + // More information refer to tools/sollya/vrs4_logf.sollya + // + // 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19 + // 6th Deg - Error abs: 0x1.179e97d8p-19 rel: 0x1.db676c1p-17 + + private const uint V_MIN = 0x00800000; + private const uint V_MAX = 0x7F800000; + private const uint V_MASK = 0x007FFFFF; + private const uint V_OFF = 0x3F2AAAAB; + + private const float V_LN2 = 0.6931472f; + + private const float C0 = 0.0f; + private const float C1 = 1.0f; + private const float C2 = -0.5000001f; + private const float C3 = 0.33332965f; + private const float C4 = -0.24999046f; + private const float C5 = 0.20018855f; + private const float C6 = -0.16700386f; + private const float C7 = 0.13902695f; + private const float C8 = -0.1197452f; + private const float C9 = 0.14401625f; + private const float C10 = -0.13657966f; + + public static float Invoke(float x) => MathF.Log(x); + + public static Vector128 Invoke(Vector128 x) + { + Vector128 specialResult = x; + + // x is subnormal or infinity or NaN + Vector128 specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN)); + + if (specialMask != Vector128.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector128 zeroMask = Vector128.Equals(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + zeroMask, + Vector128.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector128 lessThanZeroMask = Vector128.LessThan(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + lessThanZeroMask, + Vector128.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector128 temp = zeroMask + | lessThanZeroMask + | ~Vector128.Equals(x, x) + | Vector128.Equals(x, Vector128.Create(float.PositiveInfinity)); + + // subnormal + Vector128 subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp); + + x = Vector128.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } + + Vector128 vx = x.AsUInt32() - Vector128.Create(V_OFF); + Vector128 n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF); + + Vector128 r = vx.AsSingle() - Vector128.Create(1.0f); + + Vector128 r2 = r * r; + Vector128 r4 = r2 * r2; + Vector128 r8 = r4 * r4; + + Vector128 q = (Vector128.Create(C10) * r2 + (Vector128.Create(C9) * r + Vector128.Create(C8))) + * r8 + (((Vector128.Create(C7) * r + Vector128.Create(C6)) + * r2 + (Vector128.Create(C5) * r + Vector128.Create(C4))) + * r4 + ((Vector128.Create(C3) * r + Vector128.Create(C2)) + * r2 + (Vector128.Create(C1) * r + Vector128.Create(C0)))); + + return Vector128.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector128.Create(V_LN2) + q + ); + } + + public static Vector256 Invoke(Vector256 x) + { + Vector256 specialResult = x; + + // x is subnormal or infinity or NaN + Vector256 specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN)); + + if (specialMask != Vector256.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector256 zeroMask = Vector256.Equals(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + zeroMask, + Vector256.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector256 lessThanZeroMask = Vector256.LessThan(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + lessThanZeroMask, + Vector256.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector256 temp = zeroMask + | lessThanZeroMask + | ~Vector256.Equals(x, x) + | Vector256.Equals(x, Vector256.Create(float.PositiveInfinity)); + + // subnormal + Vector256 subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp); + + x = Vector256.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } + + Vector256 vx = x.AsUInt32() - Vector256.Create(V_OFF); + Vector256 n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF); + + Vector256 r = vx.AsSingle() - Vector256.Create(1.0f); + + Vector256 r2 = r * r; + Vector256 r4 = r2 * r2; + Vector256 r8 = r4 * r4; + + Vector256 q = (Vector256.Create(C10) * r2 + (Vector256.Create(C9) * r + Vector256.Create(C8))) + * r8 + (((Vector256.Create(C7) * r + Vector256.Create(C6)) + * r2 + (Vector256.Create(C5) * r + Vector256.Create(C4))) + * r4 + ((Vector256.Create(C3) * r + Vector256.Create(C2)) + * r2 + (Vector256.Create(C1) * r + Vector256.Create(C0)))); + + return Vector256.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector256.Create(V_LN2) + q + ); + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + Vector512 specialResult = x; + + // x is subnormal or infinity or NaN + Vector512 specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN)); + + if (specialMask != Vector512.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector512 zeroMask = Vector512.Equals(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + zeroMask, + Vector512.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector512 lessThanZeroMask = Vector512.LessThan(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + lessThanZeroMask, + Vector512.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector512 temp = zeroMask + | lessThanZeroMask + | ~Vector512.Equals(x, x) + | Vector512.Equals(x, Vector512.Create(float.PositiveInfinity)); + + // subnormal + Vector512 subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp); + + x = Vector512.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } + + Vector512 vx = x.AsUInt32() - Vector512.Create(V_OFF); + Vector512 n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF); + + Vector512 r = vx.AsSingle() - Vector512.Create(1.0f); + + Vector512 r2 = r * r; + Vector512 r4 = r2 * r2; + Vector512 r8 = r4 * r4; + + Vector512 q = (Vector512.Create(C10) * r2 + (Vector512.Create(C9) * r + Vector512.Create(C8))) + * r8 + (((Vector512.Create(C7) * r + Vector512.Create(C6)) + * r2 + (Vector512.Create(C5) * r + Vector512.Create(C4))) + * r4 + ((Vector512.Create(C3) * r + Vector512.Create(C2)) + * r2 + (Vector512.Create(C1) * r + Vector512.Create(C0)))); + + return Vector512.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector512.Create(V_LN2) + q + ); + } +#endif + } + + /// MathF.Log2(x) + private readonly struct Log2Operator : IUnaryOperator + { + // This code is based on `vrs4_log2f` from amd/aocl-libm-ose + // Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // log2f(x) + // = log2f(x) if x ∈ F and x > 0 + // = x if x = qNaN + // = 0 if x = 1 + // = -inf if x = (-0, 0} + // = NaN otherwise + // + // Assumptions/Expectations + // - Maximum ULP is observed to be at 4 + // - Some FPU Exceptions may not be available + // - Performance is at least 3x + // + // Implementation Notes: + // 1. Range Reduction: + // x = 2^n*(1+f) .... (1) + // where n is exponent and is an integer + // (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2) + // + // From (1), taking log on both sides + // log2(x) = log2(2^n * (1+f)) + // = n + log2(1+f) .... (3) + // + // let z = 1 + f + // log2(z) = log2(k) + log2(z) - log2(k) + // log2(z) = log2(kz) - log2(k) + // + // From (2), range of z is [1, 2) + // by simply dividing range by 'k', z is in [1/k, 2/k) .... (4) + // Best choice of k is the one which gives equal and opposite values + // at extrema +- -+ + // 1 | 2 | + // --- - 1 = - |--- - 1 | + // k | k | .... (5) + // +- -+ + // + // Solving for k, k = 3/2, + // From (4), using 'k' value, range is therefore [-0.3333, 0.3333] + // + // 2. Polynomial Approximation: + // More information refer to tools/sollya/vrs4_logf.sollya + // + // 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19 + + private const uint V_MIN = 0x00800000; + private const uint V_MAX = 0x7F800000; + private const uint V_MASK = 0x007FFFFF; + private const uint V_OFF = 0x3F2AAAAB; + + private const float C0 = 0.0f; + private const float C1 = 1.4426951f; + private const float C2 = -0.72134554f; + private const float C3 = 0.48089063f; + private const float C4 = -0.36084408f; + private const float C5 = 0.2888971f; + private const float C6 = -0.23594281f; + private const float C7 = 0.19948183f; + private const float C8 = -0.22616665f; + private const float C9 = 0.21228963f; + + public static float Invoke(float x) => MathF.Log2(x); + + public static Vector128 Invoke(Vector128 x) + { + Vector128 specialResult = x; + + // x is subnormal or infinity or NaN + Vector128 specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN)); + + if (specialMask != Vector128.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector128 zeroMask = Vector128.Equals(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + zeroMask, + Vector128.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector128 lessThanZeroMask = Vector128.LessThan(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + lessThanZeroMask, + Vector128.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector128 temp = zeroMask + | lessThanZeroMask + | ~Vector128.Equals(x, x) + | Vector128.Equals(x, Vector128.Create(float.PositiveInfinity)); + + // subnormal + Vector128 subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp); + + x = Vector128.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } + + Vector128 vx = x.AsUInt32() - Vector128.Create(V_OFF); + Vector128 n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF); + + Vector128 r = vx.AsSingle() - Vector128.Create(1.0f); + + Vector128 r2 = r * r; + Vector128 r4 = r2 * r2; + Vector128 r8 = r4 * r4; + + Vector128 poly = (Vector128.Create(C9) * r + Vector128.Create(C8)) * r8 + + (((Vector128.Create(C7) * r + Vector128.Create(C6)) * r2 + + (Vector128.Create(C5) * r + Vector128.Create(C4))) * r4 + + ((Vector128.Create(C3) * r + Vector128.Create(C2)) * r2 + + (Vector128.Create(C1) * r + Vector128.Create(C0)))); + + return Vector128.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); + } + + public static Vector256 Invoke(Vector256 x) + { + Vector256 specialResult = x; + + // x is subnormal or infinity or NaN + Vector256 specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN)); + + if (specialMask != Vector256.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector256 zeroMask = Vector256.Equals(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + zeroMask, + Vector256.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector256 lessThanZeroMask = Vector256.LessThan(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + lessThanZeroMask, + Vector256.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector256 temp = zeroMask + | lessThanZeroMask + | ~Vector256.Equals(x, x) + | Vector256.Equals(x, Vector256.Create(float.PositiveInfinity)); + + // subnormal + Vector256 subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp); + + x = Vector256.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } + + Vector256 vx = x.AsUInt32() - Vector256.Create(V_OFF); + Vector256 n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF); + + Vector256 r = vx.AsSingle() - Vector256.Create(1.0f); + + Vector256 r2 = r * r; + Vector256 r4 = r2 * r2; + Vector256 r8 = r4 * r4; + + Vector256 poly = (Vector256.Create(C9) * r + Vector256.Create(C8)) * r8 + + (((Vector256.Create(C7) * r + Vector256.Create(C6)) * r2 + + (Vector256.Create(C5) * r + Vector256.Create(C4))) * r4 + + ((Vector256.Create(C3) * r + Vector256.Create(C2)) * r2 + + (Vector256.Create(C1) * r + Vector256.Create(C0)))); + + return Vector256.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + Vector512 specialResult = x; + + // x is subnormal or infinity or NaN + Vector512 specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN)); + + if (specialMask != Vector512.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector512 zeroMask = Vector512.Equals(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + zeroMask, + Vector512.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector512 lessThanZeroMask = Vector512.LessThan(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + lessThanZeroMask, + Vector512.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector512 temp = zeroMask + | lessThanZeroMask + | ~Vector512.Equals(x, x) + | Vector512.Equals(x, Vector512.Create(float.PositiveInfinity)); + + // subnormal + Vector512 subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp); + + x = Vector512.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } + + Vector512 vx = x.AsUInt32() - Vector512.Create(V_OFF); + Vector512 n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF); + + Vector512 r = vx.AsSingle() - Vector512.Create(1.0f); + + Vector512 r2 = r * r; + Vector512 r4 = r2 * r2; + Vector512 r8 = r4 * r4; + + Vector512 poly = (Vector512.Create(C9) * r + Vector512.Create(C8)) * r8 + + (((Vector512.Create(C7) * r + Vector512.Create(C6)) * r2 + + (Vector512.Create(C5) * r + Vector512.Create(C4))) * r4 + + ((Vector512.Create(C3) * r + Vector512.Create(C2)) * r2 + + (Vector512.Create(C1) * r + Vector512.Create(C0)))); + + return Vector512.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); + } +#endif + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) + { + if (Sse41.IsSupported) + return Sse41.BlendVariable(left, right, ~mask); + + return Vector128.ConditionalSelect(mask, left, right); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) + { + if (Sse41.IsSupported) + return Sse41.BlendVariable(left, right, ~mask); + + return Vector128.ConditionalSelect(mask, left, right); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) + { + if (Avx2.IsSupported) + return Avx2.BlendVariable(left, right, ~mask); + + return Vector256.ConditionalSelect(mask, left, right); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) + { + if (Avx2.IsSupported) + return Avx2.BlendVariable(left, right, ~mask); + + return Vector256.ConditionalSelect(mask, left, right); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) + { + if (Avx512F.IsSupported) + return Avx512F.BlendVariable(left, right, ~mask); + + return Vector512.ConditionalSelect(mask, left, right); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) + { + if (Avx512F.IsSupported) + return Avx512F.BlendVariable(left, right, ~mask); + + return Vector512.ConditionalSelect(mask, left, right); + } +#endif + + /// 1f / (1f + MathF.Exp(-x)) + private readonly struct SigmoidOperator : IUnaryOperator + { + public static float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); + public static Vector128 Invoke(Vector128 x) => Vector128.Create(1f) / (Vector128.Create(1f) + ExpOperator.Invoke(-x)); + public static Vector256 Invoke(Vector256 x) => Vector256.Create(1f) / (Vector256.Create(1f) + ExpOperator.Invoke(-x)); +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => Vector512.Create(1f) / (Vector512.Create(1f) + ExpOperator.Invoke(-x)); +#endif + } + + /// Operator that takes one input value and returns a single value. + private interface IUnaryOperator + { + static abstract float Invoke(float x); + static abstract Vector128 Invoke(Vector128 x); + static abstract Vector256 Invoke(Vector256 x); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x); +#endif + } + + /// Operator that takes two input values and returns a single value. + private interface IBinaryOperator + { + static abstract float Invoke(float x, float y); + static abstract Vector128 Invoke(Vector128 x, Vector128 y); + static abstract Vector256 Invoke(Vector256 x, Vector256 y); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x, Vector512 y); +#endif + } + + /// that specializes horizontal aggregation of all elements in a vector. + private interface IAggregationOperator : IBinaryOperator + { + static abstract float Invoke(Vector128 x); + static abstract float Invoke(Vector256 x); +#if NET8_0_OR_GREATER + static abstract float Invoke(Vector512 x); +#endif + + static virtual float IdentityValue => throw new NotSupportedException(); + } + + /// Operator that takes three input values and returns a single value. + private interface ITernaryOperator + { + static abstract float Invoke(float x, float y, float z); + static abstract Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z); + static abstract Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z); +#if NET8_0_OR_GREATER + static abstract Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z); +#endif + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs new file mode 100644 index 00000000000000..c0039be0a08e2b --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -0,0 +1,3538 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. + /// Assumes arguments have already been validated to be non-empty and equal length. + private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) + { + // Compute the same as: + // TensorPrimitives.Dot(x, y) / (Math.Sqrt(TensorPrimitives.SumOfSquares(x)) * Math.Sqrt(TensorPrimitives.SumOfSquares(y))) + // but only looping over each span once. + + float dotProduct = 0f; + float xSumOfSquares = 0f; + float ySumOfSquares = 0f; + + if (Vector.IsHardwareAccelerated && + Vector.Count <= 16 && // currently never greater than 8, but 16 would occur if/when AVX512 is supported, and logic in remainder handling assumes that maximum + x.Length >= Vector.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector dotProductVector = Vector.Zero; + Vector xSumOfSquaresVector = Vector.Zero; + Vector ySumOfSquaresVector = Vector.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector.Count; + int i = 0; + do + { + Vector xVec = AsVector(ref xRef, i); + Vector yVec = AsVector(ref yRef, i); + + dotProductVector += xVec * yVec; + xSumOfSquaresVector += xVec * xVec; + ySumOfSquaresVector += yVec * yVec; + + i += Vector.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector xVec = AsVector(ref xRef, x.Length - Vector.Count); + Vector yVec = AsVector(ref yRef, x.Length - Vector.Count); + + Vector remainderMask = CreateRemainderMaskSingleVector(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector += xVec * yVec; + xSumOfSquaresVector += xVec * xVec; + ySumOfSquaresVector += yVec * yVec; + } + + // Sum the vector lanes into the scalar result. + for (int e = 0; e < Vector.Count; e++) + { + dotProduct += dotProductVector[e]; + xSumOfSquares += xSumOfSquaresVector[e]; + ySumOfSquares += ySumOfSquaresVector[e]; + } + } + else + { + // Vectorization isn't supported or there are too few elements to vectorize. + // Use a scalar implementation. + for (int i = 0; i < x.Length; i++) + { + dotProduct += x[i] * y[i]; + xSumOfSquares += x[i] * x[i]; + ySumOfSquares += y[i] * y[i]; + } + } + + // Sum(X * Y) / (|X| * |Y|) + return dotProduct / (MathF.Sqrt(xSumOfSquares) * MathF.Sqrt(ySumOfSquares)); + } + + /// Performs an aggregation over all elements in to produce a single-precision floating-point value. + /// Specifies the transform operation that should be applied to each element loaded from . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied after the transform is applied to each element. + /// + private static unsafe float Aggregate( + ReadOnlySpan x, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + where TTransformOperator : struct, IUnaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && transformOp.CanVectorize) + { + float result; + + if (remainder >= (uint)(Vector.Count)) + { + result = Vectorized(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = VectorizedSmall(ref xRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)))); + } + + return result; + } + + static float Vectorized(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + Vector vresult = new Vector(aggregationOp.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = transformOp.Invoke(AsVector(ref xRef)); + Vector end = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector)) - ((nuint)(xPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + nuint blocks = remainder / (nuint)(Vector.Count); + nuint trailing = remainder - (blocks * (nuint)(Vector.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 1))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, end); + break; + } + } + + float result = aggregationOp.IdentityValue; + + for (int i = 0; i < Vector.Count; i++) + { + result = aggregationOp.Invoke(result, vresult[i]); + } + + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float VectorizedSmall(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + switch (remainder) + { + case 7: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 6))); + goto case 6; + } + + case 6: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 5))); + goto case 5; + } + + case 5: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 4))); + goto case 4; + } + + case 4: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 3))); + goto case 3; + } + + case 3: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + } + + /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied to the results of the binary operations on the pair-wise values. + /// + private static unsafe float Aggregate( + ReadOnlySpan x, ReadOnlySpan y, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector.Count)) + { + result = Vectorized(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = VectorizedSmall(ref xRef, ref yRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, ref yRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, ref float yRef, nuint length, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)))); + } + + return result; + } + + static float Vectorized(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + Vector vresult = new Vector(aggregationOp.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = binaryOp.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector)) - ((nuint)(xPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + nuint blocks = remainder / (nuint)(Vector.Count); + nuint trailing = remainder - (blocks * (nuint)(Vector.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 1)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 1))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, end); + break; + } + } + + float result = aggregationOp.IdentityValue; + + for (int i = 0; i < Vector.Count; i++) + { + result = aggregationOp.Invoke(result, vresult[i]); + } + + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float VectorizedSmall(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + switch (remainder) + { + case 7: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6))); + goto case 6; + } + + case 6: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5))); + goto case 5; + } + + case 5: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4))); + goto case 4; + } + + case 4: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3))); + goto case 3; + } + + case 3: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + } + + /// + /// This is the same as + /// with an identity transform, except it early exits on NaN. + /// + private static float MinMaxCore(ReadOnlySpan x, TMinMaxOperator op = default) + where TMinMaxOperator : struct, IBinaryOperator + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + + float result = x[0]; + int i = 0; + + if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector resultVector = AsVector(ref xRef, 0), current; + if (Vector.EqualsAll(resultVector, resultVector)) + { + int oneVectorFromEnd = x.Length - Vector.Count; + i = Vector.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = AsVector(ref xRef, i); + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + resultVector = op.Invoke(resultVector, current); + i += Vector.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = AsVector(ref xRef, x.Length - Vector.Count); + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + resultVector = op.Invoke(resultVector, current); + } + + // Aggregate the lanes in the vector to create the final scalar result. + for (int f = 0; f < Vector.Count; f++) + { + result = op.Invoke(result, resultVector[f]); + } + + return result; + } + } + + // Scalar path used when either vectorization is not supported, the input is too small to vectorize, + // or a NaN is encountered. + Scalar: + for (; (uint)i < (uint)x.Length; i++) + { + float current = x[i]; + + if (float.IsNaN(current)) + { + return current; + } + + result = op.Invoke(result, current); + } + + return result; + } + + private static readonly int[] s_0through7 = [0, 1, 2, 3, 4, 5, 6, 7]; + + private static int IndexOfMinMaxCore(ReadOnlySpan x, TIndexOfMinMaxOperator op = default) + where TIndexOfMinMaxOperator : struct, IIndexOfOperator + { + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the index of the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + + int result; + int i = 0; + + if (Vector.IsHardwareAccelerated && Vector.Count <= 8 && x.Length >= Vector.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + Vector resultIndex = new Vector(s_0through7); + Vector curIndex = resultIndex; + Vector increment = new Vector(Vector.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector resultVector = AsVector(ref xRef, 0), current; + if (Vector.EqualsAll(resultVector, resultVector)) + { + int oneVectorFromEnd = x.Length - Vector.Count; + i = Vector.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = AsVector(ref xRef, i); + curIndex = Vector.Add(curIndex, increment); + + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + op.Invoke(ref resultVector, current, ref resultIndex, curIndex); + i += Vector.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + curIndex = Vector.Add(curIndex, new Vector(x.Length - i)); + + current = AsVector(ref xRef, x.Length - Vector.Count); + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + op.Invoke(ref resultVector, current, ref resultIndex, curIndex); + } + + result = op.Invoke(resultVector, resultIndex); + + return result; + } + } + + // Scalar path used when either vectorization is not supported, the input is too small to vectorize, + // or a NaN is encountered. + Scalar: + float curResult = x[i]; + int curIn = i; + if (float.IsNaN(curResult)) + { + return curIn; + } + + for (; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return i; + } + + curIn = op.Invoke(ref curResult, current, curIn, i); + } + + return curIn; + } + + /// Performs an element-wise operation on and writes the results to . + /// Specifies the operation to perform on each element loaded from . + private static unsafe void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination, TUnaryOperator op = default) + where TUnaryOperator : struct, IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && op.CanVectorize) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length, TUnaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and . + /// + private static unsafe void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length, TBinaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan(x, y, destination, default, op); + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from . + /// It is not used with . + /// + /// + /// Specifies the operation to perform on the transformed value from with . + /// + private static unsafe void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + where TTransformOperator : struct, IUnaryOperator + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && xTransformOp.CanVectorize) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i))), + y); + } + } + + static void Vectorized(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef)), + yVec); + Vector end = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))), + yVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 6)), + y); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 5)), + y); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 4)), + y); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 3)), + y); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = binaryOp.Invoke(xTransformOp.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from , , + /// and . + /// + private static unsafe void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, zRef); + break; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and + /// with . + /// + private static unsafe void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + z); + } + } + + static void Vectorized(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector zVec = new Vector(z); + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + zVec); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + zVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + z); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + z); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + z); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + z); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static unsafe void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + y, + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = op.Invoke(AsVector(ref xRef), + yVec, + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + y, + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + y, + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + y, + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + y, + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// Loads a from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start) => + ref Unsafe.As>(ref start); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, nuint offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, (nint)(offset))); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref int start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + + /// Gets whether the specified is positive. + private static bool IsPositive(float f) => !IsNegative(f); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector IsPositive(Vector vector) => + ((Vector)Vector.GreaterThan(((Vector)vector), Vector.Zero)); + + /// Gets whether the specified is negative. + private static unsafe bool IsNegative(float f) => *(int*)&f < 0; + + /// Gets whether each specified is negative. + private static Vector IsNegative(Vector f) => + (Vector)Vector.LessThan((Vector)f, Vector.Zero); + + /// Gets the base 2 logarithm of . + private static float Log2(float x) => MathF.Log(x, 2); + + /// + /// Gets a vector mask that will be all-ones-set for the first elements + /// and zero for all other elements. + /// + private static Vector CreateAlignmentMaskSingleVector(int count) + { + Debug.Assert(Vector.Count is 4 or 8 or 16); + + return AsVector( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (count * 16)); + } + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + private static Vector CreateRemainderMaskSingleVector(int count) + { + Debug.Assert(Vector.Count is 4 or 8 or 16); + + return AsVector( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (count * 16) + (16 - Vector.Count)); + } + + /// x + y + private readonly struct AddOperator : IAggregationOperator + { + public float Invoke(float x, float y) => x + y; + public Vector Invoke(Vector x, Vector y) => x + y; + public float IdentityValue => 0; + } + + /// x - y + private readonly struct SubtractOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x - y; + public Vector Invoke(Vector x, Vector y) => x - y; + } + + /// (x - y) * (x - y) + private readonly struct SubtractSquaredOperator : IBinaryOperator + { + public float Invoke(float x, float y) + { + float tmp = x - y; + return tmp * tmp; + } + + public Vector Invoke(Vector x, Vector y) + { + Vector tmp = x - y; + return tmp * tmp; + } + } + + /// x * y + private readonly struct MultiplyOperator : IAggregationOperator + { + public float Invoke(float x, float y) => x * y; + public Vector Invoke(Vector x, Vector y) => x * y; + public float IdentityValue => 1; + } + + /// x / y + private readonly struct DivideOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x / y; + public Vector Invoke(Vector x, Vector y) => x / y; + } + + private interface IIndexOfOperator + { + int Invoke(ref float result, float current, int resultIndex, int curIndex); + int Invoke(Vector result, Vector resultIndex); + void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex); + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMax = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (result[i] == curMax && IsNegative(curMax) && !IsNegative(result[i])) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + else if (result[i] > curMax) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector lessThanMask = Vector.GreaterThan(result, current); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (current > result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector maxIndex) + { + float curMax = result[0]; + int curIn = maxIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (MathF.Abs(result[i]) == MathF.Abs(curMax) && IsNegative(curMax) && !IsNegative(result[i])) + { + curMax = result[i]; + curIn = maxIndex[i]; + } + else if (MathF.Abs(result[i]) > MathF.Abs(curMax)) + { + curMax = result[i]; + curIn = maxIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector maxMag = Vector.Abs(result), currentMag = Vector.Abs(current); + + Vector lessThanMask = Vector.GreaterThan(maxMag, currentMag); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + } + + private readonly struct IndexOfMinOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (result[i] == curMin && IsPositive(curMin) && !IsPositive(result[i])) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + else if (result[i] < curMin) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector lessThanMask = Vector.LessThan(result, current); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + } + + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMinAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (MathF.Abs(result[i]) == MathF.Abs(curMin) && IsPositive(curMin) && !IsPositive(result[i])) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + else if (MathF.Abs(result[i]) < MathF.Abs(curMin)) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector minMag = Vector.Abs(result), currentMag = Vector.Abs(current); + + Vector lessThanMask = Vector.LessThan(minMag, currentMag); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + } + + /// MathF.Max(x, y) (but without guaranteed NaN propagation) + private readonly struct MaxOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => + x == y ? + (IsNegative(x) ? y : x) : + (y > x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.Max(x, y)); + } + + /// MathF.Max(x, y) + private readonly struct MaxPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => MathF.Max(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.Max(x, y)), + y), + x); + } + + /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) + private readonly struct MaxMagnitudeOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + yMag == xMag ? + (IsNegative(x) ? y : x) : + (xMag > yMag ? x : y); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) + { + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(xMag, yMag), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.ConditionalSelect(Vector.GreaterThan(xMag, yMag), x, y)); + } + } + + /// Operator to get x or y based on which has the larger MathF.Abs + private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag > yMag || float.IsNaN(xMag) || (xMag == yMag && !IsNegative(x)) ? x : y; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) + { + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(xMag, yMag), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.ConditionalSelect(Vector.GreaterThan(xMag, yMag), x, y)), + y), + x); + } + } + + /// MathF.Min(x, y) (but NaNs may not be propagated) + private readonly struct MinOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => + x == y ? + (IsNegative(y) ? y : x) : + (y < x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(y), y, x), + Vector.Min(x, y)); + } + + /// MathF.Min(x, y) + private readonly struct MinPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => MathF.Min(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), x, y), + Vector.Min(x, y)), + y), + x); + } + + /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) + private readonly struct MinMagnitudeOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + yMag == xMag ? + (IsNegative(y) ? y : x) : + (yMag < xMag ? y : x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) + { + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(yMag, xMag), + Vector.ConditionalSelect(IsNegative(y), y, x), + Vector.ConditionalSelect(Vector.LessThan(yMag, xMag), y, x)); + } + } + + /// Operator to get x or y based on which has the smaller MathF.Abs + private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag < yMag || float.IsNaN(xMag) || (xMag == yMag && IsNegative(x)) ? x : y; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) + { + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + + return + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(yMag, xMag), + Vector.ConditionalSelect(IsNegative(x), x, y), + Vector.ConditionalSelect(Vector.LessThan(xMag, yMag), x, y)), + y), + x); + } + } + + /// -x + private readonly struct NegateOperator : IUnaryOperator + { + public bool CanVectorize => true; + public float Invoke(float x) => -x; + public Vector Invoke(Vector x) => -x; + } + + /// (x + y) * z + private readonly struct AddMultiplyOperator : ITernaryOperator + { + public float Invoke(float x, float y, float z) => (x + y) * z; + public Vector Invoke(Vector x, Vector y, Vector z) => (x + y) * z; + } + + /// (x * y) + z + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public float Invoke(float x, float y, float z) => (x * y) + z; + public Vector Invoke(Vector x, Vector y, Vector z) => (x * y) + z; + } + + /// x + private readonly struct IdentityOperator : IUnaryOperator + { + public bool CanVectorize => true; + public float Invoke(float x) => x; + public Vector Invoke(Vector x) => x; + } + + /// x * x + private readonly struct SquaredOperator : IUnaryOperator + { + public bool CanVectorize => true; + public float Invoke(float x) => x * x; + public Vector Invoke(Vector x) => x * x; + } + + /// MathF.Abs(x) + private readonly struct AbsoluteOperator : IUnaryOperator + { + public bool CanVectorize => true; + public float Invoke(float x) => MathF.Abs(x); + public Vector Invoke(Vector x) => Vector.Abs(x); + } + + /// MathF.Exp(x) + private readonly struct ExpOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Exp(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Sinh(x) + private readonly struct SinhOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Sinh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Cosh(x) + private readonly struct CoshOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Cosh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Tanh(x) + private readonly struct TanhOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Tanh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Log(x) + private readonly struct LogOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Log(x); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Log2(x) + private readonly struct Log2Operator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => Log2(x); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); + } + + /// 1f / (1f + MathF.Exp(-x)) + private readonly struct SigmoidOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); + } + + /// Operator that takes one input value and returns a single value. + private interface IUnaryOperator + { + bool CanVectorize { get; } + float Invoke(float x); + Vector Invoke(Vector x); + } + + /// Operator that takes two input values and returns a single value. + private interface IBinaryOperator + { + float Invoke(float x, float y); + Vector Invoke(Vector x, Vector y); + } + + /// that specializes horizontal aggregation of all elements in a vector. + private interface IAggregationOperator : IBinaryOperator + { + float IdentityValue { get; } + } + + /// Operator that takes three input values and returns a single value. + private interface ITernaryOperator + { + float Invoke(float x, float y, float z); + Vector Invoke(Vector x, Vector y, Vector z); + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs new file mode 100644 index 00000000000000..272991aed44ab8 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +namespace System +{ + internal static class ThrowHelper + { + [DoesNotReturn] + public static void ThrowArgument_DestinationTooShort() => + throw new ArgumentException(SR.Argument_DestinationTooShort, "destination"); + + [DoesNotReturn] + public static void ThrowArgument_SpansMustHaveSameLength() => + throw new ArgumentException(SR.Argument_SpansMustHaveSameLength); + + [DoesNotReturn] + public static void ThrowArgument_SpansMustBeNonEmpty() => + throw new ArgumentException(SR.Argument_SpansMustBeNonEmpty); + + [DoesNotReturn] + public static void ThrowArgument_InputAndDestinationSpanMustNotOverlap() => + throw new ArgumentException(SR.Argument_InputAndDestinationSpanMustNotOverlap, "destination"); + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/NativeMemory.cs b/src/libraries/System.Numerics.Tensors/tests/NativeMemory.cs deleted file mode 100644 index b5c9ef8c2c2c92..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/NativeMemory.cs +++ /dev/null @@ -1,108 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Buffers; -using System.Runtime.InteropServices; -using System.Runtime.CompilerServices; -using System.Threading; - -namespace System.Numerics.Tensors.Tests -{ - public class NativeMemory : MemoryManager - { - private bool disposed = false; - private int refCount = 0; - private IntPtr memory; - private int length; - - public NativeMemory(IntPtr memory, int length) - { - this.memory = memory; - this.length = length; - } - - public unsafe NativeMemory(void* memory, int length) - { - this.memory = (IntPtr)memory; - this.length = length; - } - - ~NativeMemory() - { - Dispose(false); - } - - public static NativeMemory Allocate(int length) - { - // typically this would call into a native method appropriate for the platform - // or the constructors above would be used to wrap the native pointer - IntPtr memory = Marshal.AllocHGlobal(Marshal.SizeOf() * length); - return new NativeMemory(memory, length); - } - - public bool IsDisposed => disposed; - - public unsafe override Span GetSpan() => new Span((void*)memory, length); - - protected bool IsRetained => refCount > 0; - - public override MemoryHandle Pin(int elementIndex = 0) - { - unsafe - { - Retain(); - if ((uint)elementIndex > length) throw new ArgumentOutOfRangeException(nameof(elementIndex)); - void* pointer = Unsafe.Add((void*)memory, elementIndex); - return new MemoryHandle(pointer, default, this); - } - } - - public bool Release() - { - int newRefCount = Interlocked.Decrement(ref refCount); - - if (newRefCount < 0) - { - throw new InvalidOperationException("Unmatched Release/Retain"); - } - - return newRefCount != 0; - } - - public void Retain() - { - if (disposed) - { - throw new ObjectDisposedException(nameof(NativeMemory)); - } - - Interlocked.Increment(ref refCount); - } - - protected override void Dispose(bool disposing) - { - if (disposed) - { - return; - } - - // typically this would call into a native method appropriate for the platform - Marshal.FreeHGlobal(memory); - memory = IntPtr.Zero; - - disposed = true; - } - - protected override bool TryGetArray(out ArraySegment arraySegment) - { - // cannot expose managed array - arraySegment = default; - return false; - } - - public override void Unpin() - { - Release(); - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj b/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj index 042e4791b6a3d8..be4a103d7256ce 100644 --- a/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj +++ b/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj @@ -1,41 +1,21 @@ + - true $(NetCoreAppCurrent);$(NetFrameworkMinimum) + true + - - - True - True - TensorArithmetic.tt - - - - True - True - TensorOperations.tt - - - + - - - TextTemplatingFileGenerator - TensorArithmetic.cs - - - TextTemplatingFileGenerator - TensorOperations.cs - - + + + + - - - - + \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorArithmetic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorArithmetic.cs deleted file mode 100644 index 6dae2ec6854252..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorArithmetic.cs +++ /dev/null @@ -1,16165 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Numerics.Tensors -{ - internal interface ITensorArithmetic - { - T One { get; } - T Zero { get; } - void Add(Tensor left, Tensor right, Tensor result); - void Add(Tensor tensor, T scalar, Tensor result); - void And(Tensor left, Tensor right, Tensor result); - void And(Tensor tensor, T scalar, Tensor result); - void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result); - void Decrement(Tensor tensor, Tensor result); - void Divide(Tensor left, Tensor right, Tensor result); - void Divide(Tensor tensor, T scalar, Tensor result); - void Equals(Tensor left, Tensor right, Tensor result); - void GreaterThan(Tensor left, Tensor right, Tensor result); - void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result); - void Increment(Tensor tensor, Tensor result); - void LeftShift(Tensor tensor, int value, Tensor result); - void LessThan(Tensor left, Tensor right, Tensor result); - void LessThanOrEqual(Tensor left, Tensor right, Tensor result); - void Modulo(Tensor left, Tensor right, Tensor result); - void Modulo(Tensor tensor, T scalar, Tensor result); - void Multiply(Tensor left, Tensor right, Tensor result); - void Multiply(Tensor tensor, T scalar, Tensor result); - void NotEquals(Tensor left, Tensor right, Tensor result); - void Or(Tensor left, Tensor right, Tensor result); - void Or(Tensor tensor, T scalar, Tensor result); - void RightShift(Tensor tensor, int value, Tensor result); - void Subtract(Tensor left, Tensor right, Tensor result); - void Subtract(Tensor tensor, T scalar, Tensor result); - void UnaryMinus(Tensor tensor, Tensor result); - void UnaryPlus(Tensor tensor, Tensor result); - void Xor(Tensor left, Tensor right, Tensor result); - void Xor(Tensor tensor, T scalar, Tensor result); - } - - internal static class TensorArithmetic - { - public static ITensorArithmetic Instance => TensorArithmetic.GetArithmetic(); - } - - internal static class TensorArithmetic - { - public static ITensorArithmetic GetArithmetic() - { - if (typeof(T) == typeof(bool)) - { - return (ITensorArithmetic)new BoolArithmetic(); - } - else if (typeof(T) == typeof(byte)) - { - return (ITensorArithmetic)new ByteArithmetic(); - } - else if (typeof(T) == typeof(char)) - { - return (ITensorArithmetic)new CharArithmetic(); - } - else if (typeof(T) == typeof(decimal)) - { - return (ITensorArithmetic)new DecimalArithmetic(); - } - else if (typeof(T) == typeof(double)) - { - return (ITensorArithmetic)new DoubleArithmetic(); - } - else if (typeof(T) == typeof(float)) - { - return (ITensorArithmetic)new FloatArithmetic(); - } - else if (typeof(T) == typeof(int)) - { - return (ITensorArithmetic)new IntArithmetic(); - } - else if (typeof(T) == typeof(long)) - { - return (ITensorArithmetic)new LongArithmetic(); - } - else if (typeof(T) == typeof(sbyte)) - { - return (ITensorArithmetic)new SByteArithmetic(); - } - else if (typeof(T) == typeof(short)) - { - return (ITensorArithmetic)new ShortArithmetic(); - } - else if (typeof(T) == typeof(uint)) - { - return (ITensorArithmetic)new UIntArithmetic(); - } - else if (typeof(T) == typeof(ulong)) - { - return (ITensorArithmetic)new ULongArithmetic(); - } - else if (typeof(T) == typeof(ushort)) - { - return (ITensorArithmetic)new UShortArithmetic(); - } - return null; - } - } - - internal class BoolArithmetic : ITensorArithmetic - { - public bool One => true; - public bool Zero => false; - - public void Add(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Add(Tensor tensor, bool scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (bool)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, bool scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (bool)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - throw new NotSupportedException(); - } - public void Decrement(Tensor tensor, Tensor result) - { - throw new NotSupportedException(); - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Divide(Tensor tensor, bool scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Increment(Tensor tensor, Tensor result) - { - throw new NotSupportedException(); - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Modulo(Tensor tensor, bool scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Multiply(Tensor tensor, bool scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (bool)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, bool scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (bool)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Subtract(Tensor tensor, bool scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - throw new NotSupportedException(); - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (bool)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, bool scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (bool)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Add(DenseTensor tensor, bool scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (bool)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (bool)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, bool scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (bool)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (bool)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Divide(DenseTensor tensor, bool scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - throw new NotSupportedException(); - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Modulo(DenseTensor tensor, bool scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Multiply(DenseTensor tensor, bool scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (bool)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (bool)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, bool scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (bool)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (bool)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Subtract(DenseTensor tensor, bool scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (bool)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (bool)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, bool scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (bool)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (bool)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class ByteArithmetic : ITensorArithmetic - { - public byte One => 1; - public byte Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - byte sum = (byte)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (byte)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, byte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (byte)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - byte sum = (byte)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (byte)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, byte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (byte)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (byte)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class CharArithmetic : ITensorArithmetic - { - public char One => (char)1; - public char Zero => (char)0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - char sum = (char)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (char)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, char scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (char)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - char sum = (char)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (char)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, char scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (char)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (char)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class DecimalArithmetic : ITensorArithmetic - { - public decimal One => 1; - public decimal Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, decimal scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void And(Tensor tensor, decimal scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - decimal sum = (decimal)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (decimal)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, decimal scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, decimal scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, decimal scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Or(Tensor tensor, decimal scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, decimal scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (decimal)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Xor(Tensor tensor, decimal scalar, Tensor result) - { - throw new NotSupportedException(); - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, decimal scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void And(DenseTensor tensor, decimal scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - decimal sum = (decimal)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (decimal)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, decimal scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, decimal scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, decimal scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Or(DenseTensor tensor, decimal scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, decimal scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (decimal)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (decimal)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Xor(DenseTensor tensor, decimal scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - } - internal class DoubleArithmetic : ITensorArithmetic - { - public double One => 1.0; - public double Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, double scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void And(Tensor tensor, double scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - double sum = (double)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (double)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, double scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, double scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, double scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Or(Tensor tensor, double scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, double scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (double)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Xor(Tensor tensor, double scalar, Tensor result) - { - throw new NotSupportedException(); - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, double scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void And(DenseTensor tensor, double scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - double sum = (double)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (double)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, double scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, double scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, double scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Or(DenseTensor tensor, double scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, double scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (double)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (double)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Xor(DenseTensor tensor, double scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - } - internal class FloatArithmetic : ITensorArithmetic - { - public float One => 1.0f; - public float Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, float scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void And(Tensor tensor, float scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - float sum = (float)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (float)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, float scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, float scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, float scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Or(Tensor tensor, float scalar, Tensor result) - { - throw new NotSupportedException(); - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - throw new NotSupportedException(); - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, float scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (float)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - throw new NotSupportedException(); - } - public void Xor(Tensor tensor, float scalar, Tensor result) - { - throw new NotSupportedException(); - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, float scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void And(DenseTensor tensor, float scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - float sum = (float)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (float)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, float scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, float scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, float scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Or(DenseTensor tensor, float scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, float scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (float)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (float)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - throw new NotSupportedException(); - } - public void Xor(DenseTensor tensor, float scalar, DenseTensor result) - { - throw new NotSupportedException(); - } - } - internal class IntArithmetic : ITensorArithmetic - { - public int One => 1; - public int Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - int sum = (int)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (int)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, int scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (int)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - int sum = (int)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (int)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, int scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (int)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (int)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class LongArithmetic : ITensorArithmetic - { - public long One => 1; - public long Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - long sum = (long)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (long)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, long scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (long)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - long sum = (long)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (long)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, long scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (long)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (long)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class SByteArithmetic : ITensorArithmetic - { - public sbyte One => 1; - public sbyte Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - sbyte sum = (sbyte)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (sbyte)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, sbyte scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (sbyte)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - sbyte sum = (sbyte)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (sbyte)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, sbyte scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (sbyte)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (sbyte)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class ShortArithmetic : ITensorArithmetic - { - public short One => 1; - public short Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - short sum = (short)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (short)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)-tensor[indices]; - } - - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, short scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (short)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - short sum = (short)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (short)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)-tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)-tensorSpan[op1Index]; - - } - } - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, short scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (short)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (short)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class UIntArithmetic : ITensorArithmetic - { - public uint One => 1; - public uint Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - uint sum = (uint)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (uint)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, uint scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (uint)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - uint sum = (uint)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (uint)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, uint scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (uint)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (uint)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class ULongArithmetic : ITensorArithmetic - { - public ulong One => 1; - public ulong Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - ulong sum = (ulong)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (ulong)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, ulong scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ulong)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - ulong sum = (ulong)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (ulong)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, ulong scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ulong)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ulong)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } - internal class UShortArithmetic : ITensorArithmetic - { - public ushort One => 1; - public ushort Zero => 0; - - public void Add(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] + right[indices]); - } - - } - public void Add(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] + scalar); - } - - } - public void And(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] & right[indices]); - } - - } - public void And(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] & scalar); - } - - } - public void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - ushort sum = (ushort)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (ushort)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } - } - public void Decrement(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]--; - } - - } - public void Divide(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] / right[indices]); - } - - } - public void Divide(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] / scalar); - } - - } - public void Equals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] == right[indices]; - } - - } - public void GreaterThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] > right[indices]; - } - - } - public void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] >= right[indices]; - } - - } - public void Increment(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices]++; - } - - } - public void LeftShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] << value); - } - - } - public void LessThan(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] < right[indices]; - } - - } - public void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] <= right[indices]; - } - - } - public void Modulo(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] % right[indices]); - } - - } - public void Modulo(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] % scalar); - } - - } - public void Multiply(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] * right[indices]); - } - - } - public void Multiply(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] * scalar); - } - - } - public void NotEquals(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = left[indices] != right[indices]; - } - - } - public void Or(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] | right[indices]); - } - - } - public void Or(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] | scalar); - } - - } - public void RightShift(Tensor tensor, int value, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] >> value); - } - - } - public void Subtract(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] - right[indices]); - } - - } - public void Subtract(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] - scalar); - } - - } - public void UnaryMinus(Tensor tensor, Tensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(Tensor tensor, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)+tensor[indices]; - } - - } - public void Xor(Tensor left, Tensor right, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(left[indices] ^ right[indices]); - } - - } - public void Xor(Tensor tensor, ushort scalar, Tensor result) - { - - Span indices = new Span(new int[result.Rank]); - for (int i = 0; i < result.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - result[indices] = (ushort)(tensor[indices] ^ scalar); - } - - } - - public void Add(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] + rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] + rightSpan[op2Index]); - - } - } - } - public void Add(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] + scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] + scalar); - - } - } - } - public void And(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] & rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] & rightSpan[op2Index]); - - } - } - } - public void And(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] & scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] & scalar); - - } - } - } - public void Contract(DenseTensor left, DenseTensor right, int[] leftAxes, int[] rightAxes, DenseTensor result) - { - var summingDimensions = new int[leftAxes.Length]; - for (int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - ushort sum = (ushort)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (ushort)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } - } - public void Decrement(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]--; - } - } - public void Divide(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] / rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] / rightSpan[op2Index]); - - } - } - } - public void Divide(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] / scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] / scalar); - - } - } - } - public void Equals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] == rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] == rightSpan[op2Index]; - - } - } - } - public void GreaterThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] > rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] > rightSpan[op2Index]; - - } - } - } - public void GreaterThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] >= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] >= rightSpan[op2Index]; - - } - } - } - public void Increment(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i]++; - } - } - public void LeftShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] << value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] << value); - - } - } - } - public void LessThan(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] < rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] < rightSpan[op2Index]; - - } - } - } - public void LessThanOrEqual(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] <= rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] <= rightSpan[op2Index]; - - } - } - } - public void Modulo(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] % rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] % rightSpan[op2Index]); - - } - } - } - public void Modulo(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] % scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] % scalar); - - } - } - } - public void Multiply(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] * rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] * rightSpan[op2Index]); - - } - } - } - public void Multiply(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] * scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] * scalar); - - } - } - } - public void NotEquals(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = leftSpan[i] != rightSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = leftSpan[op1Index] != rightSpan[op2Index]; - - } - } - } - public void Or(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] | rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] | rightSpan[op2Index]); - - } - } - } - public void Or(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] | scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] | scalar); - - } - } - } - public void RightShift(DenseTensor tensor, int value, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] >> value); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] >> value); - - } - } - } - public void Subtract(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] - rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] - rightSpan[op2Index]); - - } - } - } - public void Subtract(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] - scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] - scalar); - - } - } - } - public void UnaryMinus(DenseTensor tensor, DenseTensor result) - { - throw new NotSupportedException(); - } - public void UnaryPlus(DenseTensor tensor, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)+tensorSpan[i]; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)+tensorSpan[op1Index]; - - } - } - } - public void Xor(DenseTensor left, DenseTensor right, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - if ((result.IsReversedStride == left.IsReversedStride) && (result.IsReversedStride == right.IsReversedStride)) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(leftSpan[i] ^ rightSpan[i]); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref left.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - ref int op2Index = ref right.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - !left.IsReversedStride ? left.strides : - right.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - left.IsReversedStride ? left.strides : - right.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(leftSpan[op1Index] ^ rightSpan[op2Index]); - - } - } - } - public void Xor(DenseTensor tensor, ushort scalar, DenseTensor result) - { - - var resultSpan = result.Buffer.Span; - var tensorSpan = tensor.Buffer.Span; - if (result.IsReversedStride == tensor.IsReversedStride) - { - for (int i = 0; i < resultSpan.Length; i++) - { - resultSpan[i] = (ushort)(tensorSpan[i] ^ scalar); - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref result.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref tensor.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !result.IsReversedStride ? result.strides : - tensor.strides; - var columnMajorStrides = result.IsReversedStride ? result.strides : - tensor.strides; - for (;rowMajorIndex < resultSpan.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - resultSpan[resultIndex] = (ushort)(tensorSpan[op1Index] ^ scalar); - - } - } - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorArithmetic.tt b/src/libraries/System.Numerics.Tensors/tests/TensorArithmetic.tt deleted file mode 100644 index 91efa47c1ab768..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorArithmetic.tt +++ /dev/null @@ -1,237 +0,0 @@ -<#@ template debug="false" hostspecific="false" language="C#" #> -<#@ assembly name="System.Core" #> -<#@ output extension=".cs" #> -<#@ include file="TensorTemplate.ttinclude" #>// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Numerics.Tensors -{ - internal interface ITensorArithmetic - { - T One { get; } - T Zero { get; } -<# foreach (MethodConfiguration method in methodConfiguration) { #> - <#= method.GetResultMethodSignature("Tensor", "T")#>; -<# } #> - } - - internal static class TensorArithmetic - { - public static ITensorArithmetic Instance => TensorArithmetic.GetArithmetic(); - } - - internal static class TensorArithmetic - { - public static ITensorArithmetic GetArithmetic() - { -<# foreach (TypeConfiguration type in typeConfiguration) { #> - <#=GenerateIfStatementHeader(type)#> - { - return (ITensorArithmetic)new <#=type.ClassPrefix#>Arithmetic(); - } -<# } #> - return null; - } - } - -<# foreach (TypeConfiguration type in typeConfiguration) { #> - internal class <#=type.ClassPrefix#>Arithmetic : ITensorArithmetic<<#=type.TypeName#>> - { - public <#=type.TypeName#> One => <#=type.OneLiteral#>; - public <#=type.TypeName#> Zero => <#=type.ZeroLiteral#>; - -<# foreach (MethodConfiguration method in methodConfiguration) { #> - public <#= method.GetResultMethodSignature("Tensor", type.TypeName)#> - { -<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #> - throw new NotSupportedException(); -<# } else if (method.Operator != null) { #> - - Span indices = new Span(new int[result.Rank]); - for(int i = 0; i < <#= method.ResultName #>.Length; i++) - { - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices); - <#=method.GetElementOperation(type.TypeName, "[indices]")#>; - } - -<# } else if (method.MethodName == "Contract") {#> - var leftIndices = new int[left.Rank]; - var rightIndices = new int[right.Rank]; - var resultIndices = new int[result.Rank]; - - var summingDimensions = new int[leftAxes.Length]; - for(int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - for (int resultIndex = 0; resultIndex < result.Length; resultIndex++) - { - <#=type.TypeName#> sum = (<#=type.TypeName#>)0; - - ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices); - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - // todo, make this more efficient - ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices); - ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices); - - sum += (<#=type.TypeName#>)(left[leftIndices] * right[rightIndices]); - } - - result[resultIndices] = sum; - } -<# } #> - } -<# } #> - -<# foreach (MethodConfiguration method in methodConfiguration) { #> - public <#= method.GetResultMethodSignature("DenseTensor", type.TypeName)#> - { -<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #> - throw new NotSupportedException(); -<# } else if (method.Operator != null) { #> - -<# if (method.MethodType == MethodType.UnaryInPlace) { #> - var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span; - var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span; - for(int i = 0; i < <#=method.ResultName #>Span.Length; i++) - { - <#=method.GetElementOperation(type.TypeName, "Span[i]")#>; - } -<# } else {#> - var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span; - var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span; -<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#> - var <#=method.Op2Name #>Span = <#=method.Op2Name #>.Buffer.Span; -<# } #> - if <#= method.GetLinearOperationCheck() #> - { - for(int i = 0; i < <#= method.ResultName #>Span.Length; i++) - { - <#=method.GetElementOperation(type.TypeName, "Span[i]")#>; - } - } - else - { - int rowMajorIndex = 0; - int colMajorIndex = 0; - - ref int resultIndex = ref <#= method.ResultName #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - ref int op1Index = ref <#= method.Op1Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - -<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#> - ref int op2Index = ref <#= method.Op2Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex; - - var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : - !<#= method.Op1Name #>.IsReversedStride ? <#= method.Op1Name #>.strides : - <#= method.Op2Name #>.strides; - var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : - <#= method.Op1Name #>.IsReversedStride ? <#= method.Op1Name #>.strides : - <#= method.Op2Name #>.strides; -<# } else {#> - var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : - <#= method.Op1Name #>.strides; - var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides : - <#= method.Op1Name #>.strides; -<# } #> - for(;rowMajorIndex < <#= method.ResultName #>Span.Length; rowMajorIndex++) - { - colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides); - - <#=method.GetElementOperation(type.TypeName, "Span[resultIndex]", "Span[op1Index]", "Span[op2Index]")#>; - - } - } -<# } #> -<# } else if (method.MethodName == "Contract") {#> - var summingDimensions = new int[leftAxes.Length]; - for(int i = 0; i < leftAxes.Length; i++) - { - summingDimensions[i] = left.dimensions[leftAxes[i]]; - } - - var summingStrides = ArrayUtilities.GetStrides(summingDimensions); - int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions); - - var resultStrides = result.strides; - - // translates from result index to left non-summing dimensions' index portion - // since left non-summing dimensions are given precedence in result, the end is zero-padded - int[] leftNonSummingStrides = new int[result.Rank]; - - // translates from summing index to left summing dimensions' index portion - int[] leftSummingStrides = new int[leftAxes.Length]; - ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0); - - // translates from result index to right non-summing dimensions' index portion - int[] rightNonSummingStrides = new int[result.Rank]; - // right non-summing dimensions appear after left non-summing dimensions. - int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length); - - // translates from summing index to right summing dimensions' index portion - int[] rightSummingStrides = new int[rightAxes.Length]; - ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0); - - var resultSpan = result.Buffer.Span; - var leftSpan = left.Buffer.Span; - var rightSpan = right.Buffer.Span; - - for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++) - { - <#=type.TypeName#> sum = (<#=type.TypeName#>)0; - - int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides); - int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides); - - for (int summingIndex = 0; summingIndex < summingLength; summingIndex++) - { - int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides); - int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides); - - int leftIndex = leftIndexNonSumming + leftIndexSumming; - int rightIndex = rightIndexNonSumming + rightIndexSumming; - - sum += (<#=type.TypeName#>)(leftSpan[leftIndex] * rightSpan[rightIndex]); - } - - resultSpan[resultIndex] = sum; - } -<# } #> - } -<# } #> - } -<# } #> -} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorExtensions.cs b/src/libraries/System.Numerics.Tensors/tests/TensorExtensions.cs deleted file mode 100644 index 2aa79f0e120589..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorExtensions.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Numerics.Tensors -{ - public static partial class TensorExtensions - { - private static int[] s_zeroArray = new[] { 0 }; - private static int[] s_oneArray = new[] { 1 }; - - internal static Tensor MatrixMultiply(this Tensor left, Tensor right) - { - if (left.Rank != 2) - { - throw new InvalidOperationException($"{nameof(MatrixMultiply)} is only valid for a {nameof(Tensor)} of {nameof(left.Rank)} 2."); - } - - if (right.Rank != 2) - { - throw new ArgumentException($"{nameof(Tensor)} {nameof(right)} must have {nameof(left.Rank)} 2.", nameof(right)); - } - - if (left.dimensions[1] != right.dimensions[0]) - { - throw new ArgumentException($"{nameof(Tensor)} {nameof(right)} must have first dimension of {left.dimensions[1]}.", nameof(right)); - } - - return TensorOperations.Contract(left, right, s_oneArray, s_zeroArray); - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorOperations.cs b/src/libraries/System.Numerics.Tensors/tests/TensorOperations.cs deleted file mode 100644 index 009ad006b88c53..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorOperations.cs +++ /dev/null @@ -1,738 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Numerics.Tensors -{ - public static partial class TensorOperations - { - internal static void ValidateBinaryArgs(Tensor left, Tensor right) - { - if (left.Rank != right.Rank || left.Length != right.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); - } - - for (int i = 0; i < left.Rank; i++) - { - if (left.dimensions[i] != right.dimensions[i]) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - } - } - - internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) - { - if (left.Rank != right.Rank || left.Length != right.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.Rank != result.Rank || left.Length != result.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(result)); - } - - if (left.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); - } - - for (int i = 0; i < result.Rank; i++) - { - if (left.dimensions[i] != right.dimensions[i]) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.dimensions[i] != result.dimensions[i]) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - } - } - - internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) - { - if (left.Rank != right.Rank || left.Length != right.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.Rank != result.Rank || left.Length != result.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(result)); - } - - if (left.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); - } - - for (int i = 0; i < result.Rank; i++) - { - if (left.dimensions[i] != right.dimensions[i]) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.dimensions[i] != result.dimensions[i]) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - } - } - - internal static void ValidateArgs(Tensor tensor) - { - if (tensor.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); - } - } - - internal static void ValidateArgs(Tensor tensor, Tensor result) - { - if (tensor.Rank != result.Rank || tensor.Length != result.Length) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - - if (tensor.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); - } - - for (int i = 0; i < result.Rank; i++) - { - if (tensor.dimensions[i] != result.dimensions[i]) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - } - } - - internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes) - { - if (leftAxes == null) - { - throw new ArgumentNullException(nameof(left)); - } - - if (rightAxes == null) - { - throw new ArgumentNullException(nameof(left)); - } - - if (leftAxes.Length != rightAxes.Length) - { - throw new ArgumentException($"{nameof(leftAxes)} and {nameof(rightAxes)} must have the same length, but were {leftAxes.Length} and {rightAxes.Length}, respectively."); - } - - for (int i = 0; i < leftAxes.Length; i++) - { - var leftAxis = leftAxes[i]; - - if (leftAxis >= left.Rank) - { - throw new ArgumentOutOfRangeException($"{nameof(leftAxes)}[{i}] was set to axis index {leftAxis} which exceeds the Rank of {left}."); - } - - var leftDimension = left.dimensions[leftAxis]; - - var rightAxis = rightAxes[i]; - - if (rightAxis >= right.Rank) - { - throw new ArgumentOutOfRangeException($"{nameof(rightAxes)}[{i}] was set to axis index {rightAxis} which exceeds the Rank of {right}."); - } - - var rightDimension = right.dimensions[rightAxis]; - - if (leftDimension != rightDimension) - { - throw new ArgumentOutOfRangeException($"Tensors may only be contracted on axes of the same length, but {nameof(leftAxes)} index {i} was length {leftDimension} and {nameof(rightAxes)} index {i} was length {rightDimension}."); - } - } - - var leftNonSummingDimensions = left.Rank - leftAxes.Length; - var rightNonSummingDimensions = right.Rank - rightAxes.Length; - var resultDimensions = new int[leftNonSummingDimensions + rightNonSummingDimensions]; - int dimensionsIndex = 0; - - Action, int[]> fillDimensions = (tensor, axes) => - { - for (int i = 0; i < tensor.Rank; i++) - { - var skip = false; - foreach (var contractionIndex in axes) - { - if (contractionIndex == i) - { - skip = true; - break; - } - } - - if (!skip) - { - resultDimensions[dimensionsIndex++] = tensor.dimensions[i]; - } - } - }; - - fillDimensions(left, leftAxes); - fillDimensions(right, rightAxes); - - return resultDimensions; - } - - internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var expectedDimensions = ValidateContractArgs(left, right, leftAxes, rightAxes); - - if (result.Rank != expectedDimensions.Length) - { - throw new ArgumentException($"{nameof(result)} should have {expectedDimensions.Length} dimensions but had {result.Rank}."); - } - - for (int i = 0; i < expectedDimensions.Length; i++) - { - if (result.dimensions[i] != expectedDimensions[i]) - { - throw new ArgumentException($"{nameof(result)} dimension {i} should be {expectedDimensions[i]} but was {result.dimensions[i]}."); - } - } - - return expectedDimensions; - } - - internal static void Add(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Add(left, right, result); - } - - internal static Tensor Add(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Add(left, right, result); - - return result; - } - - internal static void Add(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Add(tensor, scalar, result); - } - - internal static Tensor Add(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.Add(tensor, scalar, result); - - return result; - } - - internal static void And(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.And(left, right, result); - } - - internal static Tensor And(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.And(left, right, result); - - return result; - } - - internal static void And(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.And(tensor, scalar, result); - } - - internal static Tensor And(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.And(tensor, scalar, result); - - return result; - } - - internal static void Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - ValidateContractArgs(left, right, leftAxes, rightAxes, result); - - TensorArithmetic.Instance.Contract(left, right, leftAxes, rightAxes, result); - } - - internal static Tensor Contract(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes) - { - var resultDimensions = ValidateContractArgs(left, right, leftAxes, rightAxes); - - var result = left.CloneEmpty(resultDimensions); - - TensorArithmetic.Instance.Contract(left, right, leftAxes, rightAxes, result); - - return result; - } - - internal static void Decrement(Tensor tensor, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Decrement(tensor, result); - } - - internal static Tensor Decrement(Tensor tensor) - { - ValidateArgs(tensor); - - var result = tensor.Clone(); - - TensorArithmetic.Instance.Decrement(tensor, result); - - return result; - } - - internal static void Divide(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Divide(left, right, result); - } - - internal static Tensor Divide(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Divide(left, right, result); - - return result; - } - - internal static void Divide(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Divide(tensor, scalar, result); - } - - internal static Tensor Divide(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.Divide(tensor, scalar, result); - - return result; - } - - internal static void Equals(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Equals(left, right, result); - } - - internal static Tensor Equals(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Equals(left, right, result); - - return result; - } - - internal static void GreaterThan(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.GreaterThan(left, right, result); - } - - internal static Tensor GreaterThan(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.GreaterThan(left, right, result); - - return result; - } - - internal static void GreaterThanOrEqual(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.GreaterThanOrEqual(left, right, result); - } - - internal static Tensor GreaterThanOrEqual(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.GreaterThanOrEqual(left, right, result); - - return result; - } - - internal static void Increment(Tensor tensor, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Increment(tensor, result); - } - - internal static Tensor Increment(Tensor tensor) - { - ValidateArgs(tensor); - - var result = tensor.Clone(); - - TensorArithmetic.Instance.Increment(tensor, result); - - return result; - } - - internal static void LeftShift(Tensor tensor, int value, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.LeftShift(tensor, value, result); - } - - internal static Tensor LeftShift(Tensor tensor, int value) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.LeftShift(tensor, value, result); - - return result; - } - - internal static void LessThan(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.LessThan(left, right, result); - } - - internal static Tensor LessThan(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.LessThan(left, right, result); - - return result; - } - - internal static void LessThanOrEqual(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.LessThanOrEqual(left, right, result); - } - - internal static Tensor LessThanOrEqual(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.LessThanOrEqual(left, right, result); - - return result; - } - - internal static void Modulo(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Modulo(left, right, result); - } - - internal static Tensor Modulo(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Modulo(left, right, result); - - return result; - } - - internal static void Modulo(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Modulo(tensor, scalar, result); - } - - internal static Tensor Modulo(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.Modulo(tensor, scalar, result); - - return result; - } - - internal static void Multiply(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Multiply(left, right, result); - } - - internal static Tensor Multiply(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Multiply(left, right, result); - - return result; - } - - internal static void Multiply(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Multiply(tensor, scalar, result); - } - - internal static Tensor Multiply(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.Multiply(tensor, scalar, result); - - return result; - } - - internal static void NotEquals(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.NotEquals(left, right, result); - } - - internal static Tensor NotEquals(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.NotEquals(left, right, result); - - return result; - } - - internal static void Or(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Or(left, right, result); - } - - internal static Tensor Or(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Or(left, right, result); - - return result; - } - - internal static void Or(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Or(tensor, scalar, result); - } - - internal static Tensor Or(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.Or(tensor, scalar, result); - - return result; - } - - internal static void RightShift(Tensor tensor, int value, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.RightShift(tensor, value, result); - } - - internal static Tensor RightShift(Tensor tensor, int value) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.RightShift(tensor, value, result); - - return result; - } - - internal static void Subtract(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Subtract(left, right, result); - } - - internal static Tensor Subtract(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Subtract(left, right, result); - - return result; - } - - internal static void Subtract(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Subtract(tensor, scalar, result); - } - - internal static Tensor Subtract(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.Subtract(tensor, scalar, result); - - return result; - } - - internal static void UnaryMinus(Tensor tensor, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.UnaryMinus(tensor, result); - } - - internal static Tensor UnaryMinus(Tensor tensor) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.UnaryMinus(tensor, result); - - return result; - } - - internal static void UnaryPlus(Tensor tensor, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.UnaryPlus(tensor, result); - } - - internal static Tensor UnaryPlus(Tensor tensor) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.UnaryPlus(tensor, result); - - return result; - } - - internal static void Xor(Tensor left, Tensor right, Tensor result) - { - ValidateBinaryArgs(left, right, result); - - TensorArithmetic.Instance.Xor(left, right, result); - } - - internal static Tensor Xor(Tensor left, Tensor right) - { - ValidateBinaryArgs(left, right); - - var result = left.CloneEmpty(); - - TensorArithmetic.Instance.Xor(left, right, result); - - return result; - } - - internal static void Xor(Tensor tensor, T scalar, Tensor result) - { - ValidateArgs(tensor, result); - - TensorArithmetic.Instance.Xor(tensor, scalar, result); - } - - internal static Tensor Xor(Tensor tensor, T scalar) - { - ValidateArgs(tensor); - - var result = tensor.CloneEmpty(); - - TensorArithmetic.Instance.Xor(tensor, scalar, result); - - return result; - } - - } -} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorOperations.tt b/src/libraries/System.Numerics.Tensors/tests/TensorOperations.tt deleted file mode 100644 index 6b96c0bb554a5a..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorOperations.tt +++ /dev/null @@ -1,239 +0,0 @@ -<#@ template debug="false" hostspecific="false" language="C#" #> -<#@ assembly name="System.Core" #> -<#@ output extension=".cs" #> -<#@ include file="TensorTemplate.ttinclude" #>// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Numerics.Tensors -{ - public static partial class TensorOperations - { - internal static void ValidateBinaryArgs(Tensor left, Tensor right) - { - if (left.Rank != right.Rank || left.Length != right.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); - } - - for (int i = 0; i < left.Rank; i++) - { - if (left.dimensions[i] != right.dimensions[i]) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - } - } - - internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) - { - if (left.Rank != right.Rank || left.Length != right.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.Rank != result.Rank || left.Length != result.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(result)); - } - - if (left.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); - } - - for (int i = 0; i < result.Rank; i++) - { - if (left.dimensions[i] != right.dimensions[i]) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.dimensions[i] != result.dimensions[i]) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - } - } - - internal static void ValidateBinaryArgs(Tensor left, Tensor right, Tensor result) - { - if (left.Rank != right.Rank || left.Length != right.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.Rank != result.Rank || left.Length != result.Length) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(result)); - } - - if (left.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(left)); - } - - for (int i = 0; i < result.Rank; i++) - { - if (left.dimensions[i] != right.dimensions[i]) - { - throw new ArgumentException("Operands must have matching dimensions", nameof(right)); - } - - if (left.dimensions[i] != result.dimensions[i]) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - } - } - - internal static void ValidateArgs(Tensor tensor) - { - if (tensor.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); - } - } - - internal static void ValidateArgs(Tensor tensor, Tensor result) - { - if (tensor.Rank != result.Rank || tensor.Length != result.Length) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - - if (tensor.Rank == 0) - { - throw new ArgumentException($"Cannot operate on Tensor with {nameof(Tensor.Rank)} of 0.", nameof(tensor)); - } - - for (int i = 0; i < result.Rank; i++) - { - if (tensor.dimensions[i] != result.dimensions[i]) - { - throw new ArgumentException("Operands and result must have matching dimensions", nameof(result)); - } - } - } - - internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes) - { - if (leftAxes == null) - { - throw new ArgumentNullException(nameof(left)); - } - - if (rightAxes == null) - { - throw new ArgumentNullException(nameof(left)); - } - - if (leftAxes.Length != rightAxes.Length) - { - throw new ArgumentException($"{nameof(leftAxes)} and {nameof(rightAxes)} must have the same length, but were {leftAxes.Length} and {rightAxes.Length}, respectively."); - } - - for (int i = 0; i < leftAxes.Length; i++) - { - var leftAxis = leftAxes[i]; - - if (leftAxis >= left.Rank) - { - throw new ArgumentOutOfRangeException($"{nameof(leftAxes)}[{i}] was set to axis index {leftAxis} which exceeds the Rank of {left}."); - } - - var leftDimension = left.dimensions[leftAxis]; - - var rightAxis = rightAxes[i]; - - if (rightAxis >= right.Rank) - { - throw new ArgumentOutOfRangeException($"{nameof(rightAxes)}[{i}] was set to axis index {rightAxis} which exceeds the Rank of {right}."); - } - - var rightDimension = right.dimensions[rightAxis]; - - if (leftDimension != rightDimension) - { - throw new ArgumentOutOfRangeException($"Tensors may only be contracted on axes of the same length, but {nameof(leftAxes)} index {i} was length {leftDimension} and {nameof(rightAxes)} index {i} was length {rightDimension}."); - } - } - - var leftNonSummingDimensions = left.Rank - leftAxes.Length; - var rightNonSummingDimensions = right.Rank - rightAxes.Length; - var resultDimensions = new int[leftNonSummingDimensions + rightNonSummingDimensions]; - int dimensionsIndex = 0; - - Action, int[]> fillDimensions = (tensor, axes) => - { - for (int i = 0; i < tensor.Rank; i++) - { - var skip = false; - foreach (var contractionIndex in axes) - { - if (contractionIndex == i) - { - skip = true; - break; - } - } - - if (!skip) - { - resultDimensions[dimensionsIndex++] = tensor.dimensions[i]; - } - } - }; - - fillDimensions(left, leftAxes); - fillDimensions(right, rightAxes); - - return resultDimensions; - } - - internal static int[] ValidateContractArgs(Tensor left, Tensor right, int[] leftAxes, int[] rightAxes, Tensor result) - { - var expectedDimensions = ValidateContractArgs(left, right, leftAxes, rightAxes); - - if (result.Rank != expectedDimensions.Length) - { - throw new ArgumentException($"{nameof(result)} should have {expectedDimensions.Length} dimensions but had {result.Rank}."); - } - - for (int i = 0; i < expectedDimensions.Length; i++) - { - if (result.dimensions[i] != expectedDimensions[i]) - { - throw new ArgumentException($"{nameof(result)} dimension {i} should be {expectedDimensions[i]} but was {result.dimensions[i]}."); - } - } - - return expectedDimensions; - } - -<# foreach (MethodConfiguration method in methodConfiguration) { #> - internal static <#= method.GetGenericResultMethodSignature("Tensor", "T")#> - { - <#= method.GetValidationMethod(true) #> - - TensorArithmetic.Instance.<#=method.MethodName#>(<#=method.GetCallArguments()#>, <#= method.ResultName #>); - } - - internal static <#= method.GetGenericMethodSignature("Tensor", "T")#> - { - <#= method.GetValidationMethod(false) #> - - var <#= method.ResultName #> = <#=method.InitializeResult("T")#>; - - TensorArithmetic.Instance.<#=method.MethodName#>(<#=method.GetCallArguments()#>, <#= method.ResultName #>); - - return <#= method.ResultName #>; - } - -<# } #> - } -} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs new file mode 100644 index 00000000000000..09aa13ae35800f --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -0,0 +1,2992 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using Xunit; +using Xunit.Sdk; + +#pragma warning disable xUnit1025 // reporting duplicate test cases due to not distinguishing 0.0 from -0.0 + +namespace System.Numerics.Tensors.Tests +{ + public static partial class TensorPrimitivesTests + { + #region Test Utilities + public static IEnumerable TensorLengthsIncluding0 => + TensorLengths.Concat(new object[][] { [0] }); + + public static IEnumerable TensorLengths => + from length in Enumerable.Range(1, 256) + select new object[] { length }; + + public static IEnumerable VectorLengthAndIteratedRange(float min, float max, float increment) + { + foreach (int length in new[] { 4, 8, 16 }) + { + for (float f = min; f <= max; f += increment) + { + yield return new object[] { length, f }; + } + } + } + + private static readonly Random s_random = new Random(20230828); + + private static BoundedMemory CreateTensor(int size) => BoundedMemory.Allocate(size); + + private static BoundedMemory CreateAndFillTensor(int size) + { + BoundedMemory tensor = CreateTensor(size); + FillTensor(tensor.Span); + return tensor; + } + + private static void FillTensor(Span tensor) + { + for (int i = 0; i < tensor.Length; i++) + { + tensor[i] = NextSingle(); + } + } + + private static float NextSingle() => + // For testing purposes, get a mix of negative and positive values. + (float)((s_random.NextDouble() * 2) - 1); + + private static void AssertEqualTolerance(double expected, double actual, double tolerance = 0.00001f) + { + double diff = Math.Abs(expected - actual); + if (diff > tolerance && + diff > Math.Max(Math.Abs(expected), Math.Abs(actual)) * tolerance) + { + throw new EqualException(expected, actual); + } + } + + private static unsafe float MathFMaxMagnitude(float x, float y) + { + float ax = MathF.Abs(x), ay = MathF.Abs(y); + return (ax > ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x >= 0) ? x : y; + } + + private static unsafe float MathFMinMagnitude(float x, float y) + { + float ax = MathF.Abs(x), ay = MathF.Abs(y); + return (ax < ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x < 0) ? x : y; + } + + private static unsafe float UInt32ToSingle(uint i) => *(float*)&i; + + private static unsafe float SingleToUInt32(float f) => *(uint*)&f; + + /// Gets a variety of special values (e.g. NaN). + private static IEnumerable GetSpecialValues() + { + // NaN + yield return UInt32ToSingle(0xFFC0_0000); // -qNaN / float.NaN + yield return UInt32ToSingle(0xFFFF_FFFF); // -qNaN / all-bits-set + yield return UInt32ToSingle(0x7FC0_0000); // +qNaN + yield return UInt32ToSingle(0xFFA0_0000); // -sNaN + yield return UInt32ToSingle(0x7FA0_0000); // +sNaN + + // +Infinity, -Infinity + yield return float.PositiveInfinity; + yield return float.NegativeInfinity; + + // +Zero, -Zero + yield return +0.0f; + yield return -0.0f; + + // Subnormals + yield return +float.Epsilon; + yield return -float.Epsilon; + yield return UInt32ToSingle(0x007F_FFFF); + yield return UInt32ToSingle(0x807F_FFFF); + + // Normals + yield return UInt32ToSingle(0x0080_0000); + yield return UInt32ToSingle(0x8080_0000); + yield return UInt32ToSingle(0x7F7F_FFFF); // MaxValue + yield return UInt32ToSingle(0xFF7F_FFFF); // MinValue + } + + /// + /// Runs the specified action for each special value. Before the action is invoked, + /// the value is stored into a random position in , and the original + /// value is subsequently restored. + /// + private static void RunForEachSpecialValue(Action action, BoundedMemory x) + { + foreach (float value in GetSpecialValues()) + { + int pos = s_random.Next(x.Length); + float orig = x[pos]; + x[pos] = value; + + action(); + + x[pos] = orig; + } + } + + /// + /// Loads a variety of special values (e.g. NaN) into random positions in + /// and related values into the corresponding positions in . + /// + private static void SetSpecialValues(Span x, Span y) + { + int pos; + + // NaNs + pos = s_random.Next(x.Length); + x[pos] = float.NaN; + y[pos] = UInt32ToSingle(0x7FC0_0000); + + // +Infinity, -Infinity + pos = s_random.Next(x.Length); + x[pos] = float.PositiveInfinity; + y[pos] = float.NegativeInfinity; + + // +Zero, -Zero + pos = s_random.Next(x.Length); + x[pos] = +0.0f; + y[pos] = -0.0f; + + // +Epsilon, -Epsilon + pos = s_random.Next(x.Length); + x[pos] = +float.Epsilon; + y[pos] = -float.Epsilon; + + // Same magnitude, opposite sign + pos = s_random.Next(x.Length); + x[pos] = +5.0f; + y[pos] = -5.0f; + } + #endregion + + #region Abs + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Abs(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Abs(x, destination); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(MathF.Abs(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Abs_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Abs(x, x); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(MathF.Abs(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Abs_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(x, destination)); + } + + [Fact] + public static void Abs_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(0, 5))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(2, 5))); + } + #endregion + + #region Add + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Add(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] + y[i], destination[i]); + } + + float[] xOrig = x.Span.ToArray(); + + // Validate that the destination can be the same as an input. + TensorPrimitives.Add(x, x, x); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Add(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Add_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Add(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Add_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + } + + [Fact] + public static void Add_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(4, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(6, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Add(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] + y, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Add(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] + y, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Add_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + } + + [Fact] + public static void Add_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } + #endregion + + #region AddMultiply + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_ThreeTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.AddMultiply(x, y, multiplier, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] + y[i]) * multiplier[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_ThreeTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.AddMultiply(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + xOrig[i]) * xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddMultiply_ThreeTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(x, z, y, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(z, x, y, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddMultiply_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + } + + [Fact] + public static void AddMultiply_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorTensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float multiplier = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.AddMultiply(x, y, multiplier, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] + y[i]) * multiplier, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float multiplier = NextSingle(); + + TensorPrimitives.AddMultiply(x, x, multiplier, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + xOrig[i]) * multiplier, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddMultiply_TensorTensorScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + float multiplier = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(y, x, multiplier, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddMultiply_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float multiplier = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + } + + [Fact] + public static void AddMultiply_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorScalarTensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.AddMultiply(x, y, multiplier, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] + y) * multiplier[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.AddMultiply(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + y) * xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddMultiply_TensorScalarTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(z, y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void AddMultiply_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + } + + [Fact] + public static void AddMultiply_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Cosh + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Cosh(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Cosh(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Cosh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Cosh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Cosh(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Cosh_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Cosh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] + public static void Cosh_ValueRange(int vectorLength, float element) + { + float[] x = new float[vectorLength]; + float[] dest = new float[vectorLength]; + + x.AsSpan().Fill(element); + TensorPrimitives.Cosh(x, dest); + + float expected = MathF.Cosh(element); + foreach (float actual in dest) + { + AssertEqualTolerance(expected, actual); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Cosh_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); + } + + [Fact] + public static void Cosh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region CosineSimilarity + [Theory] + [MemberData(nameof(TensorLengths))] + public static void CosineSimilarity_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + + Assert.Throws(() => TensorPrimitives.CosineSimilarity(x, y)); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(y, x)); + } + + [Fact] + public static void CosineSimilarity_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(CreateTensor(1), ReadOnlySpan.Empty)); + } + + [Theory] + [InlineData(new float[] { 3, 2, 0, 5 }, new float[] { 1, 0, 0, 0 }, 0.48666f)] + [InlineData(new float[] { 1, 1, 1, 1, 1, 0 }, new float[] { 1, 1, 1, 1, 0, 1 }, 0.80f)] + public static void CosineSimilarity_KnownValues(float[] x, float[] y, float expectedResult) + { + AssertEqualTolerance(expectedResult, TensorPrimitives.CosineSimilarity(x, y)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void CosineSimilarity(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float dot = 0f, squareX = 0f, squareY = 0f; + for (int i = 0; i < x.Length; i++) + { + dot += x[i] * y[i]; + squareX += x[i] * x[i]; + squareY += y[i] * y[i]; + } + + AssertEqualTolerance(dot / (MathF.Sqrt(squareX) * MathF.Sqrt(squareY)), TensorPrimitives.CosineSimilarity(x, y)); + } + #endregion + + #region Distance + [Fact] + public static void Distance_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.Distance(CreateTensor(1), ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Distance_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + + Assert.Throws(() => TensorPrimitives.Distance(x, y)); + Assert.Throws(() => TensorPrimitives.Distance(y, x)); + } + + [Theory] + [InlineData(new float[] { 3, 2 }, new float[] { 4, 1 }, 1.4142f)] + [InlineData(new float[] { 0, 4 }, new float[] { 6, 2 }, 6.3245f)] + [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 5.19615f)] + [InlineData(new float[] { 5, 1, 6, 10 }, new float[] { 7, 2, 8, 4 }, 6.7082f)] + public static void Distance_KnownValues(float[] x, float[] y, float expectedResult) + { + AssertEqualTolerance(expectedResult, TensorPrimitives.Distance(x, y)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Distance(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float distance = 0f; + for (int i = 0; i < x.Length; i++) + { + distance += (x[i] - y[i]) * (x[i] - y[i]); + } + + AssertEqualTolerance(MathF.Sqrt(distance), TensorPrimitives.Distance(x, y)); + } + #endregion + + #region Divide + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Divide(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] / y[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Divide(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] / xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Divide_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Divide(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Divide_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); + } + + [Fact] + public static void Divide_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Divide(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] / y, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Divide(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] / y, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Divide_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); + } + + [Fact] + public static void Divide_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Dot + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + + Assert.Throws(() => TensorPrimitives.Dot(x, y)); + Assert.Throws(() => TensorPrimitives.Dot(y, x)); + } + + [Theory] + [InlineData(new float[] { 1, 3, -5 }, new float[] { 4, -2, -1 }, 3)] + [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 32)] + [InlineData(new float[] { 1, 2, 3, 10, 8 }, new float[] { 4, 5, 6, -2, 7 }, 68)] + [InlineData(new float[] { }, new float[] { }, 0)] + public static void Dot_KnownValues(float[] x, float[] y, float expectedResult) + { + AssertEqualTolerance(expectedResult, TensorPrimitives.Dot(x, y)); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Dot(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float dot = 0f; + for (int i = 0; i < x.Length; i++) + { + dot += x[i] * y[i]; + } + + AssertEqualTolerance(dot, TensorPrimitives.Dot(x, y)); + } + #endregion + + #region Exp + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Exp(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Exp(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Exp_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Exp(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Exp_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Exp(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(x[i]), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Exp_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(x, destination)); + } + + [Fact] + public static void Exp_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region IndexOfMax + [Fact] + public static void IndexOfMax_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMax(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMax(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)) + 1; + Assert.Equal(expected, TensorPrimitives.IndexOfMax(x)); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMax_FirstNaNReturned(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = float.NaN; + x[tensorLength - 1] = float.NaN; + Assert.Equal(expected, TensorPrimitives.IndexOfMax(x)); + } + } + + [Fact] + public static void IndexOfMax_Negative0LesserThanPositive0() + { + Assert.Equal(1, TensorPrimitives.IndexOfMax([-0f, +0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f])); + Assert.Equal(4, TensorPrimitives.IndexOfMax([-0f, -0f, -0f, -0f, +0f, +0f, +0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMax([+0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMax([-1, -0f])); + Assert.Equal(2, TensorPrimitives.IndexOfMax([-1, -0f, 1])); + } + #endregion + + #region IndexOfMaxMagnitude + [Fact] + public static void IndexOfMaxMagnitude_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitude(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMaxMagnitude(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory), Math.Abs) + 1; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMaxMagnitude_FirstNaNReturned(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = float.NaN; + x[tensorLength - 1] = float.NaN; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); + } + } + + [Fact] + public static void IndexOfMaxMagnitude_Negative0LesserThanPositive0() + { + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-0f, -0f, -0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f, +0f, +0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([+0f, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f])); + Assert.Equal(2, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f, 1])); + } + #endregion + + #region IndexOfMin + [Fact] + public static void IndexOfMin_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMin(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMin(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)) - 1; + Assert.Equal(expected, TensorPrimitives.IndexOfMin(x)); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMin_FirstNaNReturned(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = float.NaN; + x[tensorLength - 1] = float.NaN; + Assert.Equal(expected, TensorPrimitives.IndexOfMin(x)); + } + } + + [Fact] + public static void IndexOfMin_Negative0LesserThanPositive0() + { + Assert.Equal(0, TensorPrimitives.IndexOfMin([-0f, +0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f, -0f, -0f, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f, 1])); + } + #endregion + + #region IndexOfMinMagnitude + [Fact] + public static void IndexOfMinMagnitude_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMinMagnitude(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMinMagnitude(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + for (int i = 0; i < x.Length; i++) + { + x[i] = i % 2 == 0 ? 42 : -42; + } + + x[expected] = -41; + + Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x)); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void IndexOfMinMagnitude_FirstNaNReturned(int tensorLength) + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = float.NaN; + x[tensorLength - 1] = float.NaN; + Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitude(x)); + } + } + + [Fact] + public static void IndexOfMinMagnitude_Negative0LesserThanPositive0() + { + Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, -0f, -0f, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMinMagnitude([-0f, +0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([+0f, -0f, -0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f, 1])); + } + #endregion + + #region Log + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Log(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Log(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Log(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i]), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(x, destination)); + } + + [Fact] + public static void Log_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Log2 + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log2(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Log2(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log2_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Log2(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(xOrig[i], 2), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log2_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Log2(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log2_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(x, destination)); + } + + [Fact] + public static void Log2_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Max + [Fact] + public static void Max_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Max(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_Tensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + Assert.Equal(Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Max(x)); + + float max = float.NegativeInfinity; + foreach (float f in x.Span) + { + max = Math.Max(max, f); + } + + Assert.Equal(max, TensorPrimitives.Max(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float max = float.NegativeInfinity; + foreach (float f in x.Span) + { + max = Math.Max(max, f); + } + + Assert.Equal(max, TensorPrimitives.Max(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.Max(x)); + } + } + + [Fact] + public static void Max_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(+0f, TensorPrimitives.Max([-0f, +0f])); + Assert.Equal(+0f, TensorPrimitives.Max([+0f, -0f])); + Assert.Equal(-0f, TensorPrimitives.Max([-1, -0f])); + Assert.Equal(1, TensorPrimitives.Max([-1, -0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Max_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Max(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Max_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.Max(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.Max(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.Max(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]); + } + + TensorPrimitives.Max(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Max(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Max(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(x, y, destination)); + } + + [Fact] + public static void Max_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region MaxMagnitude + [Fact] + public static void MaxMagnitude_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.MaxMagnitude(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_Tensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float maxMagnitude = x[0]; + foreach (float f in x.Span) + { + maxMagnitude = MathFMaxMagnitude(maxMagnitude, f); + } + + Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float maxMagnitude = x[0]; + foreach (float f in x.Span) + { + maxMagnitude = MathFMaxMagnitude(maxMagnitude, f); + } + + Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.MaxMagnitude(x)); + } + } + + [Fact] + public static void MaxMagnitude_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([-0f, +0f])); + Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([+0f, -0f])); + Assert.Equal(-1, TensorPrimitives.MaxMagnitude([-1, -0f])); + Assert.Equal(1, TensorPrimitives.MaxMagnitude([-1, -0f, 1])); + Assert.Equal(0f, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -0f, 0f])); + Assert.Equal(1, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -1, -0f, 0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MaxMagnitude_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MaxMagnitude(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MaxMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MaxMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MaxMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.MaxMagnitude(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]); + } + + TensorPrimitives.MaxMagnitude(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MaxMagnitude(x, y, destination)); + Assert.Throws(() => TensorPrimitives.MaxMagnitude(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(x, y, destination)); + } + + [Fact] + public static void MaxMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Min + [Fact] + public static void Min_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Min(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Min(x)); + + float min = float.PositiveInfinity; + foreach (float f in x.Span) + { + min = Math.Min(min, f); + } + + Assert.Equal(min, TensorPrimitives.Min(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float min = float.PositiveInfinity; + foreach (float f in x.Span) + { + min = Math.Min(min, f); + } + + Assert.Equal(min, TensorPrimitives.Min(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.Min(x)); + } + } + + [Fact] + public static void Min_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(-0f, TensorPrimitives.Min([-0f, +0f])); + Assert.Equal(-0f, TensorPrimitives.Min([+0f, -0f])); + Assert.Equal(-1, TensorPrimitives.Min([-1, -0f])); + Assert.Equal(-1, TensorPrimitives.Min([-1, -0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Min_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Min(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Min_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.Min(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.Min(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.Min(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]); + } + + TensorPrimitives.Min(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Min(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Min(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(x, y, destination)); + } + + [Fact] + public static void Min_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region MinMagnitude + [Fact] + public static void MinMagnitude_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.MinMagnitude(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float minMagnitude = x[0]; + foreach (float f in x.Span) + { + minMagnitude = MathFMinMagnitude(minMagnitude, f); + } + + Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float minMagnitude = x[0]; + foreach (float f in x.Span) + { + minMagnitude = MathFMinMagnitude(minMagnitude, f); + } + + Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.MinMagnitude(x)); + } + } + + [Fact] + public static void MinMagnitude_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(0, TensorPrimitives.MinMagnitude([-0f, +0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([+0f, -0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MinMagnitude_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MinMagnitude(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MinMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MinMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MinMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.MinMagnitude(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); + } + + TensorPrimitives.MinMagnitude(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MinMagnitude(x, y, destination)); + Assert.Throws(() => TensorPrimitives.MinMagnitude(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(x, y, destination)); + } + + [Fact] + public static void MinMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Multiply + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Multiply(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] * y[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Multiply(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] * xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Multiply(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + } + + [Fact] + public static void Multiply_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Multiply(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] * y, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Multiply(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] * y, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + } + + [Fact] + public static void Multiply_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } + #endregion + + #region MultiplyAdd + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_ThreeTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y[i]) + addend[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_ThreeTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.MultiplyAdd(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * xOrig[i]) + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_ThreeTensors_ThrowsForMismatchedLengths_x_y(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, z, y, destination)); + Assert.Throws(() => TensorPrimitives.MultiplyAdd(z, x, y, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorTensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float addend = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y[i]) + addend, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float addend = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, x, addend, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * xOrig[i]) + addend, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float addend = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorScalarTensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y) + addend[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * y) + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Negate + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Negate(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Negate(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(-x[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Negate_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Negate(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(-xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Negate_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); + } + + [Fact] + public static void Negate_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Norm + [Theory] + [InlineData(new float[] { 1, 2, 3 }, 3.7416575f)] + [InlineData(new float[] { 3, 4 }, 5)] + [InlineData(new float[] { 3 }, 3)] + [InlineData(new float[] { 3, 4, 1, 2 }, 5.477226)] + [InlineData(new float[] { }, 0f)] + public static void Norm_KnownValues(float[] x, float expectedResult) + { + AssertEqualTolerance(expectedResult, TensorPrimitives.Norm(x)); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Norm(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float sumOfSquares = 0f; + for (int i = 0; i < x.Length; i++) + { + sumOfSquares += x[i] * x[i]; + } + + AssertEqualTolerance(MathF.Sqrt(sumOfSquares), TensorPrimitives.Norm(x)); + } + #endregion + + #region Product + [Fact] + public static void Product_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Product(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Product(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float f = x[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i]; + } + + AssertEqualTolerance(f, TensorPrimitives.Product(x)); + } + + [Theory] + [InlineData(1, new float[] { 1 })] + [InlineData(-2, new float[] { 1, -2 })] + [InlineData(-6, new float[] { 1, -2, 3 })] + [InlineData(24, new float[] { 1, -2, 3, -4 })] + [InlineData(120, new float[] { 1, -2, 3, -4, 5 })] + [InlineData(-720, new float[] { 1, -2, 3, -4, 5, -6 })] + [InlineData(0, new float[] { 1, -2, 3, -4, 5, -6, 0 })] + [InlineData(0, new float[] { 0, 1, -2, 3, -4, 5, -6 })] + [InlineData(0, new float[] { 1, -2, 3, 0, -4, 5, -6 })] + [InlineData(float.NaN, new float[] { 1, -2, 3, float.NaN, -4, 5, -6 })] + public static void Product_KnownValues(float expected, float[] input) + { + Assert.Equal(expected, TensorPrimitives.Product(input)); + } + #endregion + + #region ProductOfDifferences + [Fact] + public static void ProductOfDifferences_ThrowsForEmptyAndMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(43), CreateTensor(44))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ProductOfDifferences(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float f = x[0] - y[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i] - y[i]; + } + AssertEqualTolerance(f, TensorPrimitives.ProductOfDifferences(x, y)); + } + + [Theory] + [InlineData(0, new float[] {0 }, new float[] {0})] + [InlineData(0, new float[] {1 }, new float[] {1})] + [InlineData(1, new float[] {1 }, new float[] {0})] + [InlineData(-1, new float[] {0 }, new float[] {1})] + [InlineData(-1, new float[] {1, 2, 3, 4, 5 }, new float[] {2, 3, 4, 5, 6})] + [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] {0, 0, 0, 0, 0})] + [InlineData(-120, new float[] {0, 0, 0, 0, 0 }, new float[] {1, 2, 3, 4, 5})] + [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] {0, 0, 0, 0, 0})] + public static void ProductOfDifferences_KnownValues(float expected, float[] x, float[] y) + { + Assert.Equal(expected, TensorPrimitives.ProductOfDifferences(x, y)); + + } + #endregion + + #region ProductOfSums + [Fact] + public static void ProductOfSums_ThrowsForEmptyAndMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(43), CreateTensor(44))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ProductOfSums(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float f = x[0] + y[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i] + y[i]; + } + AssertEqualTolerance(f, TensorPrimitives.ProductOfSums(x, y)); + } + + [Theory] + [InlineData(0, new float[] {0 }, new float[] { 0 })] + [InlineData(1, new float[] {0 }, new float[] { 1 })] + [InlineData(1, new float[] {1 }, new float[] { 0 })] + [InlineData(2, new float[] {1 }, new float[] { 1 })] + [InlineData(10395, new float[] {1, 2, 3, 4, 5 }, new float[] { 2, 3, 4, 5, 6 })] + [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] + [InlineData(120, new float[] {0, 0, 0, 0, 0 }, new float[] { 1, 2, 3, 4, 5 })] + [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] + public static void ProductOfSums_KnownValues(float expected, float[] x, float[] y) + { + Assert.Equal(expected, TensorPrimitives.ProductOfSums(x, y)); + } + #endregion + + #region Sigmoid + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Sigmoid(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Sigmoid(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-xOrig[i])), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Sigmoid(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); + } + }, x); + } + + [Theory] + [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] + [InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })] + [InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })] + public static void Sigmoid_KnownValues(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length); + TensorPrimitives.Sigmoid(x, dest); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + } + + [Theory] + [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] + public static void Sigmoid_DestinationLongerThanSource(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length + 1); + + TensorPrimitives.Sigmoid(x, dest); + + float originalLast = dest[dest.Length - 1]; + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + Assert.Equal(originalLast, dest[dest.Length - 1]); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(x, destination)); + } + + [Fact] + public static void Sigmoid_ThrowsForEmptyInput() + { + AssertExtensions.Throws(() => TensorPrimitives.Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); + } + + [Fact] + public static void Sigmoid_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Sinh + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Sinh(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Sinh(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Sinh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Sinh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sinh_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Sinh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] + public static void Sinh_ValueRange(int vectorLengths, float element) + { + float[] x = new float[vectorLengths]; + float[] dest = new float[vectorLengths]; + + x.AsSpan().Fill(element); + TensorPrimitives.Sinh(x, dest); + + float expected = MathF.Sinh(element); + foreach (float actual in dest) + { + AssertEqualTolerance(expected, actual); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sinh_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); + } + + [Fact] + public static void Sinh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region SoftMax + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.SoftMax(x, destination); + + float expSum = MemoryMarshal.ToEnumerable(x.Memory).Sum(MathF.Exp); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(x[i]) / expSum, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.SoftMax(x, x); + + float expSum = xOrig.Sum(MathF.Exp); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(xOrig[i]) / expSum, x[i]); + } + } + + [Theory] + [InlineData(new float[] { 3, 1, .2f }, new float[] { 0.8360188f, 0.11314284f, 0.05083836f })] + [InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })] + [InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })] + [InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f })] + public static void SoftMax_KnownValues(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length); + TensorPrimitives.SoftMax(x, dest); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + } + + [Fact] + public static void SoftMax_DestinationLongerThanSource() + { + float[] x = [3, 1, .2f]; + float[] expectedResult = [0.8360188f, 0.11314284f, 0.05083836f]; + using BoundedMemory dest = CreateTensor(x.Length + 1); + TensorPrimitives.SoftMax(x, dest); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(x, destination)); + } + + [Fact] + public static void SoftMax_ThrowsForEmptyInput() + { + AssertExtensions.Throws(() => TensorPrimitives.SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); + } + + [Fact] + public static void SoftMax_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Subtract + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Subtract(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] - y[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Subtract(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] - xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Subtract_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Subtract(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Subtract_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); + } + + [Fact] + public static void Subtract_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Subtract(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] - y, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Subtract(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] - y, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Subtract_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); + } + + [Fact] + public static void Subtract_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } + #endregion + + #region Sum + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sum(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + AssertEqualTolerance(MemoryMarshal.ToEnumerable(x.Memory).Sum(), TensorPrimitives.Sum(x)); + + float sum = 0; + foreach (float f in x.Span) + { + sum += f; + } + AssertEqualTolerance(sum, TensorPrimitives.Sum(x)); + } + + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(6, new float[] { 1, 2, 3 })] + [InlineData(0, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void Sum_KnownValues(float expected, float[] x) + { + Assert.Equal(expected, TensorPrimitives.Sum(x)); + } + #endregion + + #region SumOfMagnitudes + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SumOfMagnitudes(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), MathF.Abs), TensorPrimitives.SumOfMagnitudes(x)); + + float sum = 0; + foreach (float f in x.Span) + { + sum += MathF.Abs(f); + } + AssertEqualTolerance(sum, TensorPrimitives.SumOfMagnitudes(x)); + } + + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(6, new float[] { 1, 2, 3 })] + [InlineData(6, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void SumOfMagnitudes_KnownValues(float expected, float[] x) + { + Assert.Equal(expected, TensorPrimitives.SumOfMagnitudes(x)); + } + #endregion + + #region SumOfSquares + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SumOfSquares(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), v => v * v), TensorPrimitives.SumOfSquares(x)); + + float sum = 0; + foreach (float f in x.Span) + { + sum += f * f; + } + AssertEqualTolerance(sum, TensorPrimitives.SumOfSquares(x)); + } + + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(14, new float[] { 1, 2, 3 })] + [InlineData(18, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void SumOfSquares_KnownValues(float expected, float[] x) + { + Assert.Equal(expected, TensorPrimitives.SumOfSquares(x)); + } + #endregion + + #region Tanh + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Tanh(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Tanh(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Tanh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Tanh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Tanh_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Tanh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -11f, 11f, 0.2f })] + public static void Tanh_ValueRange(int vectorLengths, float element) + { + float[] x = new float[vectorLengths]; + float[] dest = new float[vectorLengths]; + + x.AsSpan().Fill(element); + TensorPrimitives.Tanh(x, dest); + + float expected = MathF.Tanh(element); + foreach (float actual in dest) + { + AssertEqualTolerance(expected, actual); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Tanh_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); + } + + [Fact] + public static void Tanh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs new file mode 100644 index 00000000000000..06ab341db16242 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs @@ -0,0 +1,142 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using Xunit; + +namespace System.Numerics.Tensors.Tests +{ + public static partial class TensorPrimitivesTests + { + #region ConvertToHalf + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void ConvertToHalf(int tensorLength) + { + using BoundedMemory source = CreateAndFillTensor(tensorLength); + foreach (int destLength in new[] { source.Length, source.Length + 1 }) + { + using BoundedMemory destination = BoundedMemory.Allocate(destLength); + destination.Span.Fill(Half.Zero); + + TensorPrimitives.ConvertToHalf(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((Half)source[i], destination[i]); + } + + if (destination.Length > source.Length) + { + for (int i = source.Length; i < destination.Length; i++) + { + Assert.Equal(Half.Zero, destination[i]); + } + } + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToHalf_SpecialValues(int tensorLength) + { + using BoundedMemory source = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + // NaN, infinities, and 0s + source[s_random.Next(source.Length)] = float.NaN; + source[s_random.Next(source.Length)] = float.PositiveInfinity; + source[s_random.Next(source.Length)] = float.NegativeInfinity; + source[s_random.Next(source.Length)] = 0; + source[s_random.Next(source.Length)] = float.NegativeZero; + + TensorPrimitives.ConvertToHalf(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((Half)source[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory source = CreateAndFillTensor(tensorLength); + Half[] destination = new Half[source.Length - 1]; + + AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToHalf(source, destination)); + } + #endregion + + #region ConvertToSingle + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void ConvertToSingle(int tensorLength) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + for (int i = 0; i < source.Length; i++) + { + source[i] = (Half)s_random.NextSingle(); + } + + foreach (int destLength in new[] { source.Length, source.Length + 1 }) + { + using BoundedMemory destination = CreateTensor(destLength); + destination.Span.Fill(0f); + + TensorPrimitives.ConvertToSingle(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((float)source[i], destination[i]); + } + + if (destination.Length > source.Length) + { + for (int i = source.Length; i < destination.Length; i++) + { + Assert.Equal(0f, destination[i]); + } + } + } + } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToSingle_SpecialValues(int tensorLength) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + for (int i = 0; i < source.Length; i++) + { + source[i] = (Half)s_random.NextSingle(); + } + + using BoundedMemory destination = CreateTensor(tensorLength); + + // NaN, infinities, and 0s + source[s_random.Next(source.Length)] = Half.NaN; + source[s_random.Next(source.Length)] = Half.PositiveInfinity; + source[s_random.Next(source.Length)] = Half.NegativeInfinity; + source[s_random.Next(source.Length)] = Half.Zero; + source[s_random.Next(source.Length)] = Half.NegativeZero; + + TensorPrimitives.ConvertToSingle(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((float)source[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToSingle_ThrowsForTooShortDestination(int tensorLength) + { + Half[] source = new Half[tensorLength]; + using BoundedMemory destination = CreateTensor(source.Length - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToSingle(source, destination)); + } + #endregion + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTemplate.ttinclude b/src/libraries/System.Numerics.Tensors/tests/TensorTemplate.ttinclude deleted file mode 100644 index 9448791a5db6c4..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTemplate.ttinclude +++ /dev/null @@ -1,328 +0,0 @@ -<#@ import namespace="System.Linq" #> -<#@ import namespace="System.Text" #> -<#@ import namespace="System.Collections.Generic" #> -<#+ - public class TypeConfiguration - { - public TypeConfiguration(string typeName, string classPrefix = null, string oneLiteral = "1", string zeroLiteral = "0", bool supportsNumeric = true, bool supportsBitwise = true, IEnumerable unsupportedMethods = null) - { - TypeName = typeName; - ClassPrefix = classPrefix ?? char.ToUpper(typeName[0]) + typeName.Substring(1); - OneLiteral = oneLiteral; - ZeroLiteral = zeroLiteral; - SupportsNumeric = supportsNumeric; - SupportsBitwise = supportsBitwise; - UnsupportedMethods = new HashSet(unsupportedMethods ?? Enumerable.Empty()); - } - - public string TypeName { get; } - public string ClassPrefix { get; } - public string OneLiteral { get; } - public string ZeroLiteral { get; } - - public bool SupportsNumeric { get; } - public bool SupportsBitwise { get; } - public ISet UnsupportedMethods { get; } - } - - public string GenerateIfStatementHeader(TypeConfiguration type) - { - string keyword = (type == typeConfiguration[0]) ? "if" : "else if"; - return $"{keyword} (typeof(T) == typeof({type.TypeName}))"; - } - - public TypeConfiguration[] typeConfiguration = new [] - { - new TypeConfiguration("bool", oneLiteral:"true", zeroLiteral:"false", supportsNumeric: false, unsupportedMethods: new[] {"LeftShift", "RightShift"}), - new TypeConfiguration("byte"), - new TypeConfiguration("char", oneLiteral:"(char)1", zeroLiteral:"(char)0"), - new TypeConfiguration("decimal", supportsBitwise: false), - new TypeConfiguration("double", oneLiteral:"1.0", supportsBitwise: false), - new TypeConfiguration("float", oneLiteral:"1.0f", supportsBitwise: false), - new TypeConfiguration("int"), - new TypeConfiguration("long"), - new TypeConfiguration("sbyte", classPrefix:"SByte"), - new TypeConfiguration("short"), - new TypeConfiguration("uint", classPrefix:"UInt", unsupportedMethods: new[] {"UnaryMinus"}), - new TypeConfiguration("ulong", classPrefix:"ULong", unsupportedMethods: new[] {"UnaryMinus"}), - new TypeConfiguration("ushort", classPrefix:"UShort", unsupportedMethods: new[] {"UnaryMinus"}) - }; - - public enum MethodType - { - Unary, - UnaryInPlace, - BinaryScalar, - BinaryInt, - Binary, - Comparison, - Contraction - } - - public class MethodConfiguration - { - public MethodConfiguration(string methodName, MethodType methodType, string op = null, bool isNumeric = false, bool isBitwise = false) - { - MethodName = methodName; - MethodType = methodType; - Operator = op; - IsNumeric = isNumeric; - IsBitwise = isBitwise; - } - - public string ResultName => "result"; - - public string Op1Name - { - get - { - switch (MethodType) - { - case MethodType.Unary: - case MethodType.UnaryInPlace: - case MethodType.BinaryScalar: - case MethodType.BinaryInt: - return "tensor"; - case MethodType.Binary: - case MethodType.Comparison: - case MethodType.Contraction: - return "left"; - default: - throw new ArgumentException(); - }; - } - } - - public string Op2Name - { - get - { - switch (MethodType) - { - case MethodType.BinaryScalar: - return "scalar"; - case MethodType.BinaryInt: - return "value"; - case MethodType.Binary: - case MethodType.Comparison: - case MethodType.Contraction: - return "right"; - case MethodType.Unary: - case MethodType.UnaryInPlace: - default: - throw new ArgumentException(); - }; - } - } - - public string MethodName { get; } - public MethodType MethodType { get; } - public string Operator { get; } - - public string GetGenericMethodSignature(string tensorType, string genericType) - { - var resultType = GetResultType(tensorType, genericType); - var arguments = GetMethodArguments(tensorType, genericType); - - return $"{resultType} {MethodName}<{genericType}>({arguments})"; - } - - public string GetGenericResultMethodSignature(string tensorType, string genericType) - { - var resultType = GetResultType(tensorType, genericType); - var arguments = GetMethodArguments(tensorType, genericType); - - return $"void {MethodName}<{genericType}>({arguments}, {resultType} {ResultName})"; - } - - public string GetResultMethodSignature(string tensorType, string genericType) - { - var resultType = GetResultType(tensorType, genericType); - var arguments = GetMethodArguments(tensorType, genericType); - - return $"void {MethodName}({arguments}, {resultType} {ResultName})"; - } - - public string GetMethodArguments(string tensorType, string genericType) - { - switch (MethodType) - { - case MethodType.Unary: - case MethodType.UnaryInPlace: - return $"{tensorType}<{genericType}> {Op1Name}"; - case MethodType.BinaryScalar: - return $"{tensorType}<{genericType}> {Op1Name}, {genericType} {Op2Name}"; - case MethodType.BinaryInt: - return $"{tensorType}<{genericType}> {Op1Name}, int {Op2Name}"; - case MethodType.Binary: - case MethodType.Comparison: - return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}"; - case MethodType.Contraction: - return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}, int[] leftAxes, int[] rightAxes"; - default: - throw new ArgumentException(); - } - } - - public string GetCallArguments() - { - switch (MethodType) - { - case MethodType.Unary: - case MethodType.UnaryInPlace: - return $"{Op1Name}"; - case MethodType.BinaryScalar: - case MethodType.BinaryInt: - case MethodType.Binary: - case MethodType.Comparison: - return $"{Op1Name}, {Op2Name}"; - case MethodType.Contraction: - return "left, right, leftAxes, rightAxes"; - default: - throw new ArgumentException(); - } - } - - public string GetValidationMethod(bool includeResult) - { - var suffix = includeResult ? ", result" : ""; - switch (MethodType) - { - case MethodType.Unary: - case MethodType.UnaryInPlace: - case MethodType.BinaryScalar: - case MethodType.BinaryInt: - return $"ValidateArgs({Op1Name}{suffix});"; - case MethodType.Binary: - case MethodType.Comparison: - return $"ValidateBinaryArgs({Op1Name}, {Op2Name}{suffix});"; - case MethodType.Contraction: - return $"var resultDimensions = ValidateContractArgs({Op1Name}, {Op2Name}, leftAxes, rightAxes{suffix});"; - default: - throw new ArgumentException(); - } - } - - public string GetResultType(string tensorType, string typeName) - { - switch (MethodType) - { - case MethodType.Unary: - case MethodType.UnaryInPlace: - case MethodType.BinaryScalar: - case MethodType.BinaryInt: - case MethodType.Binary: - case MethodType.Contraction: - return $"{tensorType}<{typeName}>"; - case MethodType.Comparison: - return $"{tensorType}"; - default: - throw new ArgumentException(); - } - } - - public string GetLinearOperationCheck() - { - switch (MethodType) - { - case MethodType.Unary: - case MethodType.BinaryScalar: - case MethodType.BinaryInt: - return $"({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride)"; - case MethodType.Binary: - case MethodType.Comparison: - return $"(({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride) && ({ResultName}.IsReversedStride == {Op2Name}.IsReversedStride))"; - case MethodType.UnaryInPlace: - default: - throw new ArgumentException(); - } - } - - - public string GetElementOperation(string typeName, string access) - { - return GetElementOperation(typeName, access, access, access); - } - - public string GetElementOperation(string typeName, string resultAccess, string leftAccess, string rightAccess) - { - switch (MethodType) - { - case MethodType.Unary: - return $"{ResultName}{resultAccess} = ({typeName}){Operator}{Op1Name}{leftAccess}"; - case MethodType.UnaryInPlace: - return $"{ResultName}{resultAccess}{Operator}"; - case MethodType.BinaryScalar: - case MethodType.BinaryInt: - return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name})"; - case MethodType.Binary: - return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess})"; - case MethodType.Comparison: - return $"{ResultName}{resultAccess} = {Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess}"; - default: - throw new ArgumentException(); - - } - } - - public string InitializeResult(string typeName) - { - switch (MethodType) - { - case MethodType.UnaryInPlace: - return $"{Op1Name}.Clone()"; - case MethodType.Unary: - case MethodType.BinaryScalar: - case MethodType.BinaryInt: - case MethodType.Binary: - return $"{Op1Name}.CloneEmpty()"; - case MethodType.Comparison: - return $"{Op1Name}.CloneEmpty()"; - case MethodType.Contraction: - return $"{Op1Name}.CloneEmpty(resultDimensions)"; - default: - throw new ArgumentException(); - } - } - - public bool IsNumeric { get; } - public bool IsBitwise { get; } - } - - - public MethodConfiguration[] methodConfiguration = new [] - { - new MethodConfiguration("Add", MethodType.Binary, "+", isNumeric:true), - new MethodConfiguration("Add", MethodType.BinaryScalar, "+", isNumeric:true), - new MethodConfiguration("UnaryPlus", MethodType.Unary, "+", isNumeric:true), - new MethodConfiguration("Subtract", MethodType.Binary, "-", isNumeric:true), - new MethodConfiguration("Subtract", MethodType.BinaryScalar, "-", isNumeric:true), - new MethodConfiguration("UnaryMinus", MethodType.Unary, "-", isNumeric:true), - new MethodConfiguration("Increment", MethodType.UnaryInPlace, "++", isNumeric:true), - new MethodConfiguration("Decrement", MethodType.UnaryInPlace, "--", isNumeric:true), - new MethodConfiguration("Multiply", MethodType.Binary, "*", isNumeric:true), // element-wise product, not matrix product - new MethodConfiguration("Multiply", MethodType.BinaryScalar, "*", isNumeric:true), - new MethodConfiguration("Divide", MethodType.Binary, "/", isNumeric:true), - new MethodConfiguration("Divide", MethodType.BinaryScalar, "/", isNumeric:true), - new MethodConfiguration("Modulo", MethodType.Binary, "%", isNumeric:true), - new MethodConfiguration("Modulo", MethodType.BinaryScalar, "%", isNumeric:true), - new MethodConfiguration("And", MethodType.Binary, "&", isBitwise: true), - new MethodConfiguration("And", MethodType.BinaryScalar, "&", isBitwise: true), - new MethodConfiguration("Or", MethodType.Binary, "|", isBitwise: true), - new MethodConfiguration("Or", MethodType.BinaryScalar, "|", isBitwise: true), - new MethodConfiguration("Xor", MethodType.Binary, "^", isBitwise: true), - new MethodConfiguration("Xor", MethodType.BinaryScalar, "^", isBitwise: true), - new MethodConfiguration("LeftShift", MethodType.BinaryInt, "<<", isBitwise: true), - new MethodConfiguration("RightShift", MethodType.BinaryInt, ">>", isBitwise: true), - - // Note all of these are element-wise operations not testing the operation on the entire Tensor - new MethodConfiguration("Equals", MethodType.Comparison, "=="), - new MethodConfiguration("NotEquals", MethodType.Comparison, "!="), - new MethodConfiguration("GreaterThanOrEqual", MethodType.Comparison, ">=", isNumeric:true), - new MethodConfiguration("LessThanOrEqual", MethodType.Comparison, "<=", isNumeric:true), - new MethodConfiguration("GreaterThan", MethodType.Comparison, ">", isNumeric:true), - new MethodConfiguration("LessThan", MethodType.Comparison, "<", isNumeric:true), - - new MethodConfiguration("Contract", MethodType.Contraction, isNumeric:true), - }.OrderBy(m => m.MethodName).ToArray(); -#> diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs deleted file mode 100644 index 27c7ba75e4e259..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs +++ /dev/null @@ -1,2486 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using Xunit; - -namespace System.Numerics.Tensors.Tests -{ - public class TensorTests : TensorTestsBase - { - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void ConstructTensorFromArrayRank1(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray(new[] { 0, 1, 2 }); - - Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); - Assert.Equal(0, tensor[0]); - Assert.Equal(1, tensor[1]); - Assert.Equal(2, tensor[2]); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void ConstructTensorFromArrayRank2(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray(new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); - Assert.Equal(0, tensor[0, 0]); - Assert.Equal(1, tensor[0, 1]); - Assert.Equal(2, tensor[0, 2]); - Assert.Equal(3, tensor[1, 0]); - Assert.Equal(4, tensor[1, 1]); - Assert.Equal(5, tensor[1, 2]); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void ConstructTensorFromArrayRank3(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray(new[, ,] - { - { - {0, 1, 2}, - {3, 4, 5} - }, - { - {6, 7 ,8 }, - {9, 10 ,11 }, - }, - { - {12, 13 ,14 }, - {15, 16 ,17 }, - }, - { - {18, 19 ,20 }, - {21, 22 ,23 }, - } - }); - - Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); - - Assert.Equal(0, tensor[0, 0, 0]); - Assert.Equal(1, tensor[0, 0, 1]); - Assert.Equal(2, tensor[0, 0, 2]); - Assert.Equal(3, tensor[0, 1, 0]); - Assert.Equal(4, tensor[0, 1, 1]); - Assert.Equal(5, tensor[0, 1, 2]); - - Assert.Equal(6, tensor[1, 0, 0]); - Assert.Equal(7, tensor[1, 0, 1]); - Assert.Equal(8, tensor[1, 0, 2]); - Assert.Equal(9, tensor[1, 1, 0]); - Assert.Equal(10, tensor[1, 1, 1]); - Assert.Equal(11, tensor[1, 1, 2]); - - Assert.Equal(12, tensor[2, 0, 0]); - Assert.Equal(13, tensor[2, 0, 1]); - Assert.Equal(14, tensor[2, 0, 2]); - Assert.Equal(15, tensor[2, 1, 0]); - Assert.Equal(16, tensor[2, 1, 1]); - Assert.Equal(17, tensor[2, 1, 2]); - - Assert.Equal(18, tensor[3, 0, 0]); - Assert.Equal(19, tensor[3, 0, 1]); - Assert.Equal(20, tensor[3, 0, 2]); - Assert.Equal(21, tensor[3, 1, 0]); - Assert.Equal(22, tensor[3, 1, 1]); - Assert.Equal(23, tensor[3, 1, 2]); - } - - [Fact] - public void ConstructDenseTensorFromPointer() - { - using (var nativeMemory = NativeMemoryFromArray(Enumerable.Range(0, 24).ToArray())) - { - var dimensions = new[] { 4, 2, 3 }; - var tensor = new DenseTensor(nativeMemory.Memory, dimensions, false); - - Assert.Equal(0, tensor[0, 0, 0]); - Assert.Equal(1, tensor[0, 0, 1]); - Assert.Equal(2, tensor[0, 0, 2]); - Assert.Equal(3, tensor[0, 1, 0]); - Assert.Equal(4, tensor[0, 1, 1]); - Assert.Equal(5, tensor[0, 1, 2]); - - Assert.Equal(6, tensor[1, 0, 0]); - Assert.Equal(7, tensor[1, 0, 1]); - Assert.Equal(8, tensor[1, 0, 2]); - Assert.Equal(9, tensor[1, 1, 0]); - Assert.Equal(10, tensor[1, 1, 1]); - Assert.Equal(11, tensor[1, 1, 2]); - - Assert.Equal(12, tensor[2, 0, 0]); - Assert.Equal(13, tensor[2, 0, 1]); - Assert.Equal(14, tensor[2, 0, 2]); - Assert.Equal(15, tensor[2, 1, 0]); - Assert.Equal(16, tensor[2, 1, 1]); - Assert.Equal(17, tensor[2, 1, 2]); - - Assert.Equal(18, tensor[3, 0, 0]); - Assert.Equal(19, tensor[3, 0, 1]); - Assert.Equal(20, tensor[3, 0, 2]); - Assert.Equal(21, tensor[3, 1, 0]); - Assert.Equal(22, tensor[3, 1, 1]); - Assert.Equal(23, tensor[3, 1, 2]); - } - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void ConstructSparseTensor(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray(new[,] - { - {0, 0, 0, 0}, - {5, 8, 0, 0}, - {0, 0, 3, 0}, - {0, 6, 0, 0} - }); - - Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); - - - Assert.Equal(0, tensor[0, 0]); - Assert.Equal(0, tensor[0, 1]); - Assert.Equal(0, tensor[0, 2]); - Assert.Equal(0, tensor[0, 3]); - - - Assert.Equal(5, tensor[1, 0]); - Assert.Equal(8, tensor[1, 1]); - Assert.Equal(0, tensor[1, 2]); - Assert.Equal(0, tensor[1, 3]); - - - Assert.Equal(0, tensor[2, 0]); - Assert.Equal(0, tensor[2, 1]); - Assert.Equal(3, tensor[2, 2]); - Assert.Equal(0, tensor[2, 3]); - - - Assert.Equal(0, tensor[3, 0]); - Assert.Equal(6, tensor[3, 1]); - Assert.Equal(0, tensor[3, 2]); - Assert.Equal(0, tensor[3, 3]); - - if (tensorConstructor.TensorType == TensorType.CompressedSparse) - { - var compressedSparseTensor = (CompressedSparseTensor)tensor; - - Assert.Equal(4, compressedSparseTensor.NonZeroCount); - - int[] expectedValues, expectedCompressedCounts, expectedIndices; - - if (compressedSparseTensor.IsReversedStride) - { - // csc - expectedValues = new[] { 5, 8, 6, 3 }; - expectedCompressedCounts = new[] { 0, 1, 3, 4, 4 }; - expectedIndices = new[] { 1, 1, 3, 2 }; - } - else - { - // csr - expectedValues = new[] { 5, 8, 3, 6 }; - expectedCompressedCounts = new[] { 0, 0, 2, 3, 4 }; - expectedIndices = new[] { 0, 1, 2, 1 }; - } - Assert.Equal(expectedValues, compressedSparseTensor.Values.Slice(0, compressedSparseTensor.NonZeroCount).ToArray()); - Assert.Equal(expectedCompressedCounts, compressedSparseTensor.CompressedCounts.ToArray()); - Assert.Equal(expectedIndices, compressedSparseTensor.Indices.Slice(0, compressedSparseTensor.NonZeroCount).ToArray()); - } - } - - [Theory()] - [InlineData(false)] - [InlineData(true)] - public void ConstructCompressedSparseTensorFromPointers(bool isReversedStride) - { - int[] values, compressedCounts, indices; - if (isReversedStride) - { - // csc - values = new[] { 5, 8, 6, 3 }; - compressedCounts = new[] { 0, 1, 3, 4, 4 }; - indices = new[] { 1, 1, 3, 2 }; - } - else - { - // csr - values = new[] { 5, 8, 3, 6 }; - compressedCounts = new[] { 0, 0, 2, 3, 4 }; - indices = new[] { 0, 1, 2, 1 }; - } - int[] dimensions = new[] { 4, 4 }; - - using (var valuesMemory = NativeMemoryFromArray(values)) - using (var compressedCountsMemory = NativeMemoryFromArray(compressedCounts)) - using (var indicesMemory = NativeMemoryFromArray(indices)) - { - var tensor = new CompressedSparseTensor(valuesMemory.Memory, - compressedCountsMemory.Memory, - indicesMemory.Memory, - values.Length, - dimensions, - isReversedStride); - - Assert.Equal(0, tensor[0, 0]); - Assert.Equal(0, tensor[0, 1]); - Assert.Equal(0, tensor[0, 2]); - Assert.Equal(0, tensor[0, 3]); - - - Assert.Equal(5, tensor[1, 0]); - Assert.Equal(8, tensor[1, 1]); - Assert.Equal(0, tensor[1, 2]); - Assert.Equal(0, tensor[1, 3]); - - - Assert.Equal(0, tensor[2, 0]); - Assert.Equal(0, tensor[2, 1]); - Assert.Equal(3, tensor[2, 2]); - Assert.Equal(0, tensor[2, 3]); - - - Assert.Equal(0, tensor[3, 0]); - Assert.Equal(6, tensor[3, 1]); - Assert.Equal(0, tensor[3, 2]); - Assert.Equal(0, tensor[3, 3]); - } - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void ConstructFromDimensions(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromDimensions(new[] { 2, 3, 4 }); - Assert.Equal(3, tensor.Rank); - Assert.Equal(3, tensor.Dimensions.Length); - Assert.Equal(2, tensor.Dimensions[0]); - Assert.Equal(3, tensor.Dimensions[1]); - Assert.Equal(4, tensor.Dimensions[2]); - Assert.Equal(24, tensor.Length); - Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); - - //Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: null)); - Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: new int[0])); - - Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: new[] { 1, 0 })); - Assert.Throws("dimensions", () => tensorConstructor.CreateFromDimensions(dimensions: new[] { 1, -1 })); - - // ensure dimensions are immutable - var dimensions = new[] { 1, 2, 3 }; - tensor = tensorConstructor.CreateFromDimensions(dimensions: dimensions); - dimensions[0] = dimensions[1] = dimensions[2] = 0; - Assert.Equal(1, tensor.Dimensions[0]); - Assert.Equal(2, tensor.Dimensions[1]); - Assert.Equal(3, tensor.Dimensions[2]); - } - - [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNonZeroLowerBoundArraySupported))] - [MemberData(nameof(GetSingleTensorConstructors))] - public void ConstructTensorFromArrayRank3WithLowerBounds(TensorConstructor tensorConstructor) - { - var dimensions = new[] { 2, 3, 4 }; - var lowerBounds = new[] { 0, 5, 200 }; - var arrayWithLowerBounds = Array.CreateInstance(typeof(int), dimensions, lowerBounds); - - int value = 0; - for (int x = lowerBounds[0]; x < lowerBounds[0] + dimensions[0]; x++) - { - for (int y = lowerBounds[1]; y < lowerBounds[1] + dimensions[1]; y++) - { - for (int z = lowerBounds[2]; z < lowerBounds[2] + dimensions[2]; z++) - { - arrayWithLowerBounds.SetValue(value++, x, y, z); - } - } - } - - var tensor = tensorConstructor.CreateFromArray(arrayWithLowerBounds); - - var expected = tensorConstructor.CreateFromArray(new[, ,] - { - { - { 0, 1, 2, 3 }, - { 4, 5, 6, 7 }, - { 8, 9, 10, 11 } - }, - { - { 12, 13, 14, 15 }, - { 16, 17, 18, 19 }, - { 20, 21, 22, 23 } - } - } - ); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(expected, tensor)); - Assert.Equal(tensorConstructor.IsReversedStride, tensor.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void StructurallyEqualTensor(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var arr = new[, ,] - { - { - {0, 1, 2}, - {3, 4, 5} - }, - { - {6, 7 ,8 }, - {9, 10 ,11 }, - }, - { - {12, 13 ,14 }, - {15, 16 ,17 }, - }, - { - {18, 19 ,20 }, - {21, 22 ,23 }, - } - }; - var tensor = leftConstructor.CreateFromArray(arr); - var tensor2 = rightConstructor.CreateFromArray(arr); - - Assert.Equal(0, StructuralComparisons.StructuralComparer.Compare(tensor, tensor2)); - Assert.Equal(0, StructuralComparisons.StructuralComparer.Compare(tensor2, tensor)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, tensor2)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor2, tensor)); - // Issue: should Tensors with different layout be structurally equal? - if (leftConstructor.IsReversedStride == leftConstructor.IsReversedStride) - { - Assert.Equal(StructuralComparisons.StructuralEqualityComparer.GetHashCode(tensor), StructuralComparisons.StructuralEqualityComparer.GetHashCode(tensor2)); - } - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void StructurallyEqualArray(TensorConstructor tensorConstructor) - { - var arr = new[, ,] - { - { - {0, 1, 2}, - {3, 4, 5} - }, - { - {6, 7 ,8 }, - {9, 10 ,11 }, - }, - { - {12, 13 ,14 }, - {15, 16 ,17 }, - }, - { - {18, 19 ,20 }, - {21, 22 ,23 }, - } - }; - var tensor = tensorConstructor.CreateFromArray(arr); - - Assert.Equal(0, StructuralComparisons.StructuralComparer.Compare(tensor, arr)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, arr)); - - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetDiagonalSquare(TensorConstructor tensorConstructor) - { - var arr = new[,] - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var diag = tensor.GetDiagonal(); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1, 3, 5 })); - diag = tensor.GetDiagonal(1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 2, 9 })); - diag = tensor.GetDiagonal(2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 4 })); - Assert.Throws("offset", () => tensor.GetDiagonal(3)); - - diag = tensor.GetDiagonal(-1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 8, 7 })); - diag = tensor.GetDiagonal(-2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1 })); - Assert.Throws("offset", () => tensor.GetDiagonal(-3)); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetDiagonalRectangle(TensorConstructor tensorConstructor) - { - var arr = new[,] - { - { 1, 2, 4, 3, 7 }, - { 8, 3, 9, 2, 6 }, - { 1, 7, 5, 2, 9 } - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var diag = tensor.GetDiagonal(); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1, 3, 5 })); - diag = tensor.GetDiagonal(1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 2, 9, 2 })); - diag = tensor.GetDiagonal(2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 4, 2, 9 })); - diag = tensor.GetDiagonal(3); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 3, 6 })); - diag = tensor.GetDiagonal(4); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 7 })); - Assert.Throws("offset", () => tensor.GetDiagonal(5)); - - diag = tensor.GetDiagonal(-1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 8, 7 })); - diag = tensor.GetDiagonal(-2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, new[] { 1 })); - Assert.Throws("offset", () => tensor.GetDiagonal(-3)); - Assert.Throws("offset", () => tensor.GetDiagonal(-4)); - Assert.Throws("offset", () => tensor.GetDiagonal(-5)); - } - - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetDiagonalCube(TensorConstructor tensorConstructor) - { - var arr = new[, ,] - { - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }, - { - { 4, 5, 7 }, - { 1, 6, 2 }, - { 3, 0, 8 }, - }, - { - { 5, 6, 1 }, - { 2, 2, 3 }, - { 4, 9, 4 }, - }, - - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var diag = tensor.GetDiagonal(); - var expected = new[,] - { - { 1, 2, 4 }, - { 1, 6, 2 }, - { 4, 9, 4 } - }; - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(diag, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, diag.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetTriangleSquare(TensorConstructor tensorConstructor) - { - var arr = new[,] - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var tri = tensor.GetTriangle(0); - Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); - - var expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 0, 0 }, - { 8, 3, 0 }, - { 1, 7, 5 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 0 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(2); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetTriangle(3); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetTriangle(200); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetTriangle(-1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0 }, - { 8, 0, 0 }, - { 1, 7, 0 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(-2); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 1, 0, 0 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - }); - tri = tensor.GetTriangle(-3); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - // same as -3, should it be an exception? - tri = tensor.GetTriangle(-4); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(-300); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetTriangleRectangle(TensorConstructor tensorConstructor) - { - var arr = new[,] - { - { 1, 2, 4, 3, 7 }, - { 8, 3, 9, 2, 6 }, - { 1, 7, 5, 2, 9 } - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var tri = tensor.GetTriangle(0); - var expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 0, 0, 0, 0 }, - { 8, 3, 0, 0, 0 }, - { 1, 7, 5, 0, 0 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); - - tri = tensor.GetTriangle(1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 0, 0, 0 }, - { 8, 3, 9, 0, 0 }, - { 1, 7, 5, 2, 0 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(2); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4, 0, 0 }, - { 8, 3, 9, 2, 0 }, - { 1, 7, 5, 2, 9 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(3); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4, 3, 0 }, - { 8, 3, 9, 2, 6 }, - { 1, 7, 5, 2, 9 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetTriangle(4); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4, 3, 7 }, - { 8, 3, 9, 2, 6 }, - { 1, 7, 5, 2, 9 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - // same as 4, should it be an exception? - tri = tensor.GetTriangle(5); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(1000); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetTriangle(-1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0, 0, 0 }, - { 8, 0, 0, 0, 0 }, - { 1, 7, 0, 0, 0 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0 }, - { 1, 0, 0, 0, 0 } - }); - tri = tensor.GetTriangle(-2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0 } - }); - tri = tensor.GetTriangle(-3); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetTriangle(-4); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(-5); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetTriangle(-100); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetTriangleCube(TensorConstructor tensorConstructor) - { - var arr = new[, ,] - { - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }, - { - { 4, 5, 7 }, - { 1, 6, 2 }, - { 3, 0, 8 }, - }, - { - { 5, 6, 1 }, - { 2, 2, 3 }, - { 4, 9, 4 }, - }, - - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var tri = tensor.GetTriangle(0); - var expected = tensorConstructor.CreateFromArray(new[, ,] - { - { - { 1, 2, 4 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - }, - { - { 4, 5, 7 }, - { 1, 6, 2 }, - { 0, 0, 0 }, - }, - { - { 5, 6, 1 }, - { 2, 2, 3 }, - { 4, 9, 4 }, - }, - - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetUpperTriangleSquare(TensorConstructor tensorConstructor) - { - var arr = new[,] - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var tri = tensor.GetUpperTriangle(0); - - var expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4 }, - { 0, 3, 9 }, - { 0, 0, 5 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); - - tri = tensor.GetUpperTriangle(1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 2, 4 }, - { 0, 0, 9 }, - { 0, 0, 0 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(2); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 4 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetUpperTriangle(3); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetUpperTriangle(4); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(42); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetUpperTriangle(-1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 0, 7, 5 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(-2); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetUpperTriangle(-3); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(-300); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetUpperTriangleRectangle(TensorConstructor tensorConstructor) - { - var arr = new[,] - { - { 1, 2, 4, 3, 7 }, - { 8, 3, 9, 2, 6 }, - { 1, 7, 5, 2, 9 } - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var tri = tensor.GetUpperTriangle(0); - var expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4, 3, 7 }, - { 0, 3, 9, 2, 6 }, - { 0, 0, 5, 2, 9 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); - tri = tensor.GetUpperTriangle(1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 2, 4, 3, 7 }, - { 0, 0, 9, 2, 6 }, - { 0, 0, 0, 2, 9 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(2); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 4, 3, 7 }, - { 0, 0, 0, 2, 6 }, - { 0, 0, 0, 0, 9 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(3); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0, 3, 7 }, - { 0, 0, 0, 0, 6 }, - { 0, 0, 0, 0, 0 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetUpperTriangle(4); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0, 0, 7 }, - { 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - expected = tensorConstructor.CreateFromArray(new[,] - { - { 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0 } - }); - tri = tensor.GetUpperTriangle(5); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(6); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(1000); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetUpperTriangle(-1); - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4, 3, 7 }, - { 8, 3, 9, 2, 6 }, - { 0, 7, 5, 2, 9 } - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - expected = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 4, 3, 7 }, - { 8, 3, 9, 2, 6 }, - { 1, 7, 5, 2, 9 } - }); - tri = tensor.GetUpperTriangle(-2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - - tri = tensor.GetUpperTriangle(-3); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(-4); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - tri = tensor.GetUpperTriangle(-100); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetUpperTriangleCube(TensorConstructor tensorConstructor) - { - var arr = new[, ,] - { - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }, - { - { 4, 5, 7 }, - { 1, 6, 2 }, - { 3, 0, 8 }, - }, - { - { 5, 6, 1 }, - { 2, 2, 3 }, - { 4, 9, 4 }, - }, - - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var tri = tensor.GetUpperTriangle(0); - var expected = tensorConstructor.CreateFromArray(new[, ,] - { - { - { 1, 2, 4 }, - { 8, 3, 9 }, - { 1, 7, 5 }, - }, - { - { 0, 0, 0 }, - { 1, 6, 2 }, - { 3, 0, 8 }, - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 4, 9, 4 }, - }, - - }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tri, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, tri.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void Reshape(TensorConstructor tensorConstructor) - { - var arr = new[,] - { - { 1, 2, 3 }, - { 4, 5, 6 } - }; - - var tensor = tensorConstructor.CreateFromArray(arr); - var actual = tensor.Reshape(new[] { 3, 2 }); - - var expected = tensorConstructor.IsReversedStride ? - new[,] - { - { 1, 5 }, - { 4, 3 }, - { 2, 6 } - } : - new[,] - { - { 1, 2 }, - { 3, 4 }, - { 5, 6 } - }; - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Fact] - public void Identity() - { - var actual = Tensor.CreateIdentity(3); - - var expected = new[,] - { - {1.0, 0, 0 }, - {0, 1.0, 0 }, - {0, 0, 1.0 } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void CreateWithDiagonal(TensorConstructor tensorConstructor) - { - var diagonal = tensorConstructor.CreateFromArray(new[] { 1, 2, 3, 4, 5 }); - var actual = Tensor.CreateFromDiagonal(diagonal); - - var expected = new[,] - { - {1, 0, 0, 0, 0 }, - {0, 2, 0, 0, 0 }, - {0, 0, 3, 0, 0 }, - {0, 0, 0, 4, 0 }, - {0, 0, 0, 0, 5 } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void CreateWithDiagonal3D(TensorConstructor tensorConstructor) - { - var diagonal = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 3, 4, 5 }, - { 1, 2, 3, 4, 5 }, - { 1, 2, 3, 4, 5 } - }); - var actual = Tensor.CreateFromDiagonal(diagonal); - var expected = new[, ,] - { - { - {1, 2, 3, 4, 5 }, - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 } - }, - { - {0, 0, 0, 0, 0 }, - {1, 2, 3, 4, 5 }, - {0, 0, 0, 0, 0 } - }, - { - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 }, - {1, 2, 3, 4, 5 } - } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void CreateWithDiagonalAndOffset(TensorConstructor tensorConstructor) - { - var diagonal = tensorConstructor.CreateFromArray(new[] { 1, 2, 3, 4 }); - var actual = Tensor.CreateFromDiagonal(diagonal, 1); - - var expected = new[,] - { - {0, 1, 0, 0, 0 }, - {0, 0, 2, 0, 0 }, - {0, 0, 0, 3, 0 }, - {0, 0, 0, 0, 4 }, - {0, 0, 0, 0, 0 } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - diagonal = tensorConstructor.CreateFromArray(new[] { 1, 2, 3, 4 }); - actual = Tensor.CreateFromDiagonal(diagonal, -1); - - expected = new[,] - { - {0, 0, 0, 0, 0 }, - {1, 0, 0, 0, 0 }, - {0, 2, 0, 0, 0 }, - {0, 0, 3, 0, 0 }, - {0, 0, 0, 4, 0 } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - diagonal = tensorConstructor.CreateFromArray(new[] { 1 }); - actual = Tensor.CreateFromDiagonal(diagonal, -4); - expected = new[,] - { - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 }, - {1, 0, 0, 0, 0 } - }; - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - diagonal = tensorConstructor.CreateFromArray(new[] { 1 }); - actual = Tensor.CreateFromDiagonal(diagonal, 4); - expected = new[,] - { - {0, 0, 0, 0, 1 }, - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 }, - {0, 0, 0, 0, 0 } - }; - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void CreateWithDiagonalAndOffset3D(TensorConstructor tensorConstructor) - { - var diagonal = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 3 }, - { 1, 2, 3 }, - { 1, 2, 3 } - }); - var actual = Tensor.CreateFromDiagonal(diagonal, 1); - - var expected = new[, ,] - { - { - { 0, 0, 0 }, - { 1, 2, 3 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 1, 2, 3 }, - { 0, 0, 0 } - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 1, 2, 3 } - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - diagonal = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 3 }, - { 1, 2, 3 }, - { 1, 2, 3 } - }); - actual = Tensor.CreateFromDiagonal(diagonal, -1); - - expected = new[, ,] - { - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 1, 2, 3 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 0, 0, 0 }, - { 1, 2, 3 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 1, 2, 3 }, - { 0, 0, 0 } - } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - diagonal = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 3 } - }); - actual = Tensor.CreateFromDiagonal(diagonal, 3); - - expected = new[, ,] - { - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 1, 2, 3 }, - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - diagonal = tensorConstructor.CreateFromArray(new[,] - { - { 1, 2, 3 } - }); - actual = Tensor.CreateFromDiagonal(diagonal, -3); - - expected = new[, ,] - { - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - }, - { - { 1, 2, 3 }, - { 0, 0, 0 }, - { 0, 0, 0 }, - { 0, 0, 0 } - } - }; - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void Add(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - var right = rightConstructor.CreateFromArray( - new[,] - { - { 6, 7 ,8 }, - { 9, 10 ,11 }, - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - { 6, 8, 10 }, - { 12, 14, 16 }, - }); - - var actual = TensorOperations.Add(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void AddScalar(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - { 1, 2, 3 }, - { 4, 5, 6 }, - }); - - var actual = TensorOperations.Add(tensor, 1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void UnaryPlus(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensor; - - var actual = TensorOperations.UnaryPlus(tensor); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.False(ReferenceEquals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void Subtract(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - var right = rightConstructor.CreateFromArray( - new[,] - { - { 6, 7 ,8 }, - { 9, 10 ,11 }, - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - { -6, -6, -6 }, - { -6, -6, -6}, - }); - - var actual = TensorOperations.Subtract(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void SubtractScalar(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - var expected = tensorConstructor.CreateFromArray( - new[,] - { - { -1, 0, 1 }, - { 2, 3, 4 }, - }); - - var actual = TensorOperations.Subtract(tensor, 1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void UnaryMinus(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {0, -1, -2}, - {-3, -4, -5} - }); - - var actual = TensorOperations.UnaryMinus(tensor); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.False(ReferenceEquals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void PrefixIncrement(TensorConstructor tensorConstructor) - { - Tensor tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expectedResult = tensorConstructor.CreateFromArray( - new[,] - { - {1, 2, 3}, - {4, 5, 6} - }); - - var expectedTensor = expectedResult; - - tensor = TensorOperations.Increment(tensor); - var actual = tensor; - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); - Assert.True(ReferenceEquals(tensor, actual)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void PostfixIncrement(TensorConstructor tensorConstructor) - { - Tensor tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - // returns original value - var expectedResult = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - // increments operand - var expectedTensor = tensorConstructor.CreateFromArray( - new[,] - { - {1, 2, 3}, - {4, 5, 6} - }); - - var actual = tensor; - tensor = TensorOperations.Increment(tensor); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); - Assert.False(ReferenceEquals(tensor, actual)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void PrefixDecrement(TensorConstructor tensorConstructor) - { - Tensor tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expectedResult = tensorConstructor.CreateFromArray( - new[,] - { - {-1, 0, 1}, - {2, 3, 4} - }); - - var expectedTensor = expectedResult; - - tensor = TensorOperations.Decrement(tensor); - var actual = tensor; - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); - Assert.True(ReferenceEquals(tensor, actual)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void PostfixDecrement(TensorConstructor tensorConstructor) - { - Tensor tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - // returns original value - var expectedResult = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - // decrements operand - var expectedTensor = tensorConstructor.CreateFromArray( - new[,] - { - {-1, 0, 1}, - {2, 3, 4} - }); - - var actual = tensor; - tensor = TensorOperations.Decrement(tensor); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expectedResult)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(tensor, expectedTensor)); - Assert.False(ReferenceEquals(tensor, actual)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void Multiply(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - var right = rightConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 4}, - {9, 16, 25} - }); - - var actual = TensorOperations.Multiply(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void MultiplyScalar(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {0, 2, 4}, - {6, 8, 10} - }); - - var actual = TensorOperations.Multiply(tensor, 2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void Divide(TensorConstructor dividendConstructor, TensorConstructor divisorConstructor) - { - var dividend = dividendConstructor.CreateFromArray( - new[,] - { - {0, 1, 4}, - {9, 16, 25} - }); - - var divisor = divisorConstructor.CreateFromArray( - new[,] - { - {1, 1, 2}, - {3, 4, 5} - }); - - var expected = divisorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var actual = TensorOperations.Divide(dividend, divisor); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(dividendConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void DivideScalar(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 2, 4}, - {6, 8, 10} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var actual = TensorOperations.Divide(tensor, 2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void Modulo(TensorConstructor dividendConstructor, TensorConstructor divisorConstructor) - { - var dividend = dividendConstructor.CreateFromArray( - new[,] - { - {0, 3, 8}, - {11, 14, 17} - }); - - var divisor = divisorConstructor.CreateFromArray( - new[,] - { - {1, 2, 3}, - {4, 5, 6} - }); - - var expected = dividendConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var actual = TensorOperations.Modulo(dividend, divisor); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(dividendConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void ModuloScalar(TensorConstructor tensorConstructor) - { - var tensor = tensorConstructor.CreateFromArray( - new[,] - { - {0, 3, 4}, - {7, 8, 9} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 0}, - {1, 0, 1} - }); - - var actual = TensorOperations.Modulo(tensor, 2); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void And(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 3}, - {7, 15, 31} - }); - - var right = rightConstructor.CreateFromArray( - new[,] - { - {1, 1, 3}, - {2, 4, 8} - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 3}, - {2, 4, 8} - }); - - var actual = TensorOperations.And(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void AndScalar(TensorConstructor tensorConstructor) - { - var left = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 3}, - {5, 15, 31} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {0, 0, 0}, - {4, 4, 20} - }); - - var actual = TensorOperations.And(left, 20); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void Or(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 3}, - {7, 14, 31} - }); - - var right = rightConstructor.CreateFromArray( - new[,] - { - {1, 2, 4}, - {2, 4, 8} - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - {1, 3, 7}, - {7, 14, 31} - }); - - var actual = TensorOperations.Or(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void OrScalar(TensorConstructor tensorConstructor) - { - var left = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {1, 1, 3}, - {3, 5, 5} - }); - - var actual = TensorOperations.Or(left, 1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void Xor(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 3}, - {7, 14, 31} - }); - - var right = rightConstructor.CreateFromArray( - new[,] - { - {1, 2, 4}, - {2, 4, 8} - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - {1, 3, 7}, - {5, 10, 23} - }); - - var actual = TensorOperations.Xor(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void XorScalar(TensorConstructor tensorConstructor) - { - var left = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {1, 0, 3}, - {2, 5, 4} - }); - - var actual = TensorOperations.Xor(left, 1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void LeftShift(TensorConstructor tensorConstructor) - { - var left = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {0, 2, 4}, - {6, 8, 10} - }); - - var actual = TensorOperations.LeftShift(left, 1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetSingleTensorConstructors))] - public void RightShift(TensorConstructor tensorConstructor) - { - var left = tensorConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var expected = tensorConstructor.CreateFromArray( - new[,] - { - {0, 0, 1}, - {1, 2, 2} - }); - - var actual = TensorOperations.RightShift(left, 1); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(tensorConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void ElementWiseEquals(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - var right = rightConstructor.CreateFromArray( - new[,] - { - {0, 1, -2}, - {2, 3, 5} - }); - - var expected = new[,] - { - {true, true, false }, - {false, false, true} - }.ToTensor(); - - var actual = TensorOperations.Equals(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory()] - [MemberData(nameof(GetDualTensorConstructors))] - public void ElementWiseNotEquals(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - var right = rightConstructor.CreateFromArray( - new[,] - { - {0, 1, -2}, - {2, 3, 5} - }); - - var expected = new[,] - { - {false, false, true}, - {true, true, false} - }.ToTensor(); - - var actual = TensorOperations.NotEquals(left, right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.Equal(leftConstructor.IsReversedStride, actual.IsReversedStride); - } - - [Theory] - [MemberData(nameof(GetDualTensorConstructors))] - public void MatrixMultiply(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {0, 1, 2}, - {3, 4, 5} - }); - - var right = rightConstructor.CreateFromArray( - new[,] - { - {0, 1, 2, 3, 4}, - {5, 6, 7, 8, 9}, - {10, 11, 12, 13, 14} - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - {0*0 + 1*5 + 2*10, 0*1 + 1*6 + 2*11, 0*2 + 1*7 + 2*12, 0*3 + 1*8 + 2*13, 0*4 + 1*9 + 2*14}, - {3*0 + 4*5 + 5*10, 3*1 + 4*6 + 5*11, 3*2 + 4*7 + 5*12, 3*3 + 4*8 + 5*13, 3*4 + 4*9 + 5*14} - }); - - var actual = left.MatrixMultiply(right); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - - [Theory] - [MemberData(nameof(GetDualTensorConstructors))] - public void Contract(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[, ,] - { - { - {0, 1}, - {2, 3} - }, - { - {4, 5}, - {6, 7} - }, - { - {8, 9}, - {10, 11} - } - }); - - var right = rightConstructor.CreateFromArray( - new[, ,] - { - { - {0, 1}, - {2, 3}, - {4, 5} - }, - { - {6, 7}, - {8, 9}, - {10, 11} - }, - { - {12, 13}, - {14, 15}, - {16, 17} - }, - { - {18, 19}, - {20, 21}, - {22, 23} - } - }); - - // contract a 3*2*2 with a 4*3*2 tensor, summing on (3*2)*2 and 4*(3*2) to produce a 2*4 tensor - var expected = leftConstructor.CreateFromArray( - new[,] - { - {110, 290, 470, 650}, - {125, 341, 557, 773}, - }); - var actual = TensorOperations.Contract(left, right, new[] { 0, 1 }, new[] { 1, 2 }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - // contract a 3*2*2 with a 4*3*2 tensor, summing on (3)*2*(2) and 4*(3*2) to produce a 2*4 tensor - expected = leftConstructor.CreateFromArray( - new[,] - { - {101, 263, 425, 587}, - {131, 365, 599, 833}, - }); - actual = TensorOperations.Contract(left, right, new[] { 0, 2 }, new[] { 1, 2 }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - - [Theory] - [MemberData(nameof(GetDualTensorConstructors))] - public void ContractWithSingleLengthDimension(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[,] - { - {1, 2, 3}, - {4, 5, 6}, - }); - - var right = rightConstructor.CreateFromArray( - new[,] - { - { 1, 2 }, - { 3, 4 }, - { 5, 6 } - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - { 22, 28 }, - { 49, 64 } - }); - - // contract a 2*3 with a 3*2 tensor, summing on 2*(3) and (3)*2 to produce a 2*2 tensor - var actual = TensorOperations.Contract(left, right, new[] { 1 }, new[] { 0 }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - - // contract a 1*2*3*1 with a 3*2 tensor, summing on 1*2*(3)*1 and (3)*2 to produce a 1*2*1*2 tensor - var reshapedLeft = left.Reshape(new int[] { 1, 2, 3, 1 }); - var reshapedExpected = expected.Reshape(new int[] { 1, 2, 1, 2 }); - actual = TensorOperations.Contract(reshapedLeft, right, new[] { 2 }, new[] { 0 }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, reshapedExpected)); - - } - - [Theory] - [MemberData(nameof(GetDualTensorConstructors))] - public void ContractMismatchedDimensions(TensorConstructor leftConstructor, TensorConstructor rightConstructor) - { - var left = leftConstructor.CreateFromArray( - new[] { 0, 1, 2, 3 }); - - var right = rightConstructor.CreateFromArray( - new[,] - { - { 0 }, - { 1 }, - { 2 } - }); - - var expected = leftConstructor.CreateFromArray( - new[,] - { - {0,0,0}, - {0,1,2}, - {0,2,4}, - {0,3,6}, - }); - - Assert.Throws(() => TensorOperations.Contract(left, right, new int[] { }, new[] { 1 })); - - // reshape to include dimension of length 1. - var leftReshaped = left.Reshape(new[] { 1, (int)left.Length }); - - var actual = TensorOperations.Contract(leftReshaped, right, new[] { 0 }, new[] { 1 }); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void GetArrayString(TensorConstructor constructor) - { - var tensor = constructor.CreateFromArray( - new[, ,] - { - { - {0, 1}, - {2, 3}, - {4, 5} - }, - { - {6, 7}, - {8, 9}, - {10, 11} - }, - { - {12, 13}, - {14, 15}, - {16, 17} - }, - { - {18, 19}, - {20, 21}, - {22, 23} - } - }); - - var expected = -@"{ - { - {0,1}, - {2,3}, - {4,5} - }, - { - {6,7}, - {8,9}, - {10,11} - }, - { - {12,13}, - {14,15}, - {16,17} - }, - { - {18,19}, - {20,21}, - {22,23} - } -}"; - - Assert.Equal(expected, tensor.GetArrayString(), ignoreLineEndingDifferences: !LineEndingsHelper.IsNewLineConsistent); - - var expectedNoSpace = expected.Replace(LineEndingsHelper.CompiledNewline, "").Replace(" ", ""); - Assert.Equal(expectedNoSpace, tensor.GetArrayString(false)); - } - - [Theory] - [MemberData(nameof(GetTensorAndResultConstructor))] - public void ToOtherTensor(TensorConstructor sourceConstructor, TensorConstructor resultConstructor) - { - var array = new[, ,] - { - { - {0, 1, 0, 0 }, - {0, 0, 0, 9 }, - {2, 0, 5, 0 } - }, - { - {3, 0, 0, 6 }, - {0, 0, 0, 0 }, - {0, 0, 4, 0 } - }, - { - {0, 2, 0, 0 }, - {8, 0, 0, 0 }, - {0, 0, 12, 0 } - }, - { - {5, 5, 5, 0 }, - {0, 0, 0, 15 }, - {0, 0, 42, 0 } - }, - { - {1, 0, 0, 4 }, - {0, 2, 0, 0 }, - {0, 0, 3, 0 } - } - }; - - var source = sourceConstructor.CreateFromArray(array); - - Tensor expected = resultConstructor.CreateFromArray(array); - - Tensor actual; - - switch (resultConstructor.TensorType) - { - case TensorType.Dense: - actual = source.ToDenseTensor(); - break; - case TensorType.Sparse: - var actualSparse = source.ToSparseTensor(); - actual = actualSparse; - var expectedSparse = expected as SparseTensor; - Assert.Equal(expectedSparse.NonZeroCount, actualSparse.NonZeroCount); - break; - case TensorType.CompressedSparse: - var actualCompressedSparse = source.ToCompressedSparseTensor(); - actual = actualCompressedSparse; - var expectedCompressedSparse = expected as CompressedSparseTensor; - Assert.Equal(expectedCompressedSparse.NonZeroCount, actualCompressedSparse.NonZeroCount); - if (sourceConstructor.TensorType != TensorType.Dense) - { - // expect packed values when going from sparse -> sparse - Assert.Equal(actualCompressedSparse.NonZeroCount, actualCompressedSparse.Values.Length); - } - break; - default: - throw new ArgumentException(nameof(resultConstructor.TensorType)); - } - - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, expected)); - Assert.True(StructuralComparisons.StructuralEqualityComparer.Equals(actual, source)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void TestICollectionMembers(TensorConstructor constructor) - { - var arr = new[,] - { - { 1, 2, 3 }, - { 4, 5, 6 } - }; - - var tensor = constructor.CreateFromArray(arr); - ICollection tensorCollection = tensor; - - Assert.Equal(6, tensorCollection.Count); - - Assert.False(tensorCollection.IsSynchronized); - - Assert.True(ReferenceEquals(tensorCollection, tensorCollection.SyncRoot)); - - var actual = Array.CreateInstance(typeof(int), tensor.Length); - tensorCollection.CopyTo(actual, 0); - var expected = constructor.IsReversedStride ? - new[] { 1, 4, 2, 5, 3, 6 } : - new[] { 1, 2, 3, 4, 5, 6 }; - Assert.Equal(expected, actual); - - actual = Array.CreateInstance(typeof(int), tensor.Length + 2); - tensorCollection.CopyTo(actual, 2); - expected = constructor.IsReversedStride ? - new[] { 0, 0, 1, 4, 2, 5, 3, 6 } : - new[] { 0, 0, 1, 2, 3, 4, 5, 6 }; - Assert.Equal(expected, actual); - - Assert.Throws(() => tensorCollection.CopyTo(null, 0)); - Assert.Throws(() => tensorCollection.CopyTo(new int[3, 4], 0)); - Assert.Throws(() => tensorCollection.CopyTo(new int[5], 0)); - Assert.Throws(() => tensorCollection.CopyTo(new int[6], 1)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void TestIListMembers(TensorConstructor constructor) - { - var arr = new[,] - { - { 1, 2, 3 }, - { 4, 5, 6 } - }; - - var tensor = constructor.CreateFromArray(arr); - IList tensorList = tensor; - - int expectedIndexValue = constructor.IsReversedStride ? 4 : 2; - Assert.Equal(expectedIndexValue, tensorList[1]); - - tensorList[1] = 7; - Assert.Equal(7, tensorList[1]); - var expected = constructor.IsReversedStride ? - new[] { 1, 7, 2, 5, 3, 6 } : - new[] { 1, 7, 3, 4, 5, 6 }; - Assert.Equal(expected, tensor); - - Assert.True(tensorList.IsFixedSize); - Assert.False(tensorList.IsReadOnly); - - Assert.Throws(() => (tensorList).Add(8)); - - Assert.True(tensorList.Contains(5)); - Assert.True(tensorList.Contains(6)); - Assert.False(tensorList.Contains(0)); - Assert.False(tensorList.Contains(42)); - Assert.False(tensorList.Contains("foo")); - - Assert.Equal(constructor.IsReversedStride ? 3 : 4, tensorList.IndexOf(5)); - Assert.Equal(5, tensorList.IndexOf(6)); - Assert.Equal(-1, tensorList.IndexOf(0)); - Assert.Equal(-1, tensorList.IndexOf(42)); - - Assert.Throws(() => (tensorList).Insert(2, 5)); - Assert.Throws(() => (tensorList).Remove(1)); - Assert.Throws(() => (tensorList).RemoveAt(0)); - - tensorList.Clear(); - Assert.Equal(new[] { 0, 0, 0, 0, 0, 0 }, tensor); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void TestICollectionTMembers(TensorConstructor constructor) - { - var arr = new[,] - { - { 1, 2, 3 }, - { 4, 5, 6 } - }; - - var tensor = constructor.CreateFromArray(arr); - ICollection tensorCollection = tensor; - - Assert.Equal(6, tensorCollection.Count); - Assert.False(tensorCollection.IsReadOnly); - - Assert.Throws(() => tensorCollection.Add(8)); - Assert.Throws(() => tensorCollection.Remove(1)); - - Assert.True(tensorCollection.Contains(5)); - Assert.True(tensorCollection.Contains(6)); - Assert.False(tensorCollection.Contains(0)); - Assert.False(tensorCollection.Contains(42)); - - var actual = new int[tensor.Length]; - tensorCollection.CopyTo(actual, 0); - var expected = constructor.IsReversedStride ? - new[] { 1, 4, 2, 5, 3, 6 } : - new[] { 1, 2, 3, 4, 5, 6 }; - Assert.Equal(expected, actual); - - actual = new int[tensor.Length + 2]; - tensorCollection.CopyTo(actual, 2); - expected = constructor.IsReversedStride ? - new[] { 0, 0, 1, 4, 2, 5, 3, 6 } : - new[] { 0, 0, 1, 2, 3, 4, 5, 6 }; - Assert.Equal(expected, actual); - - Assert.Throws(() => tensorCollection.CopyTo(null, 0)); - Assert.Throws(() => tensorCollection.CopyTo(new int[5], 0)); - Assert.Throws(() => tensorCollection.CopyTo(new int[6], 1)); - - tensorCollection.Clear(); - Assert.Equal(new[] { 0, 0, 0, 0, 0, 0 }, tensor); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void TestIListTMembers(TensorConstructor constructor) - { - var arr = new[,] - { - { 1, 2, 3 }, - { 4, 5, 6 } - }; - - var tensor = constructor.CreateFromArray(arr); - IList tensorList = tensor; - - int expectedIndexValue = constructor.IsReversedStride ? 4 : 2; - Assert.Equal(expectedIndexValue, tensorList[1]); - - tensorList[1] = 7; - Assert.Equal(7, tensorList[1]); - var expected = constructor.IsReversedStride ? - new[] { 1, 7, 2, 5, 3, 6 } : - new[] { 1, 7, 3, 4, 5, 6 }; - Assert.Equal(expected, tensor); - - Assert.Equal(constructor.IsReversedStride ? 3 : 4, tensorList.IndexOf(5)); - Assert.Equal(5, tensorList.IndexOf(6)); - Assert.Equal(-1, tensorList.IndexOf(0)); - Assert.Equal(-1, tensorList.IndexOf(42)); - - Assert.Throws(() => (tensorList).Insert(2, 5)); - Assert.Throws(() => (tensorList).RemoveAt(0)); - } - - [Theory] - [MemberData(nameof(GetSingleTensorConstructors))] - public void TestIReadOnlyTMembers(TensorConstructor constructor) - { - var arr = new[,] - { - { 1, 2, 3 }, - { 4, 5, 6 } - }; - - var tensor = constructor.CreateFromArray(arr); - - IReadOnlyCollection tensorCollection = tensor; - Assert.Equal(6, tensorCollection.Count); - - IReadOnlyList tensorList = tensor; - int expectedIndexValue = constructor.IsReversedStride ? 4 : 2; - Assert.Equal(expectedIndexValue, tensorList[1]); - } - - [Theory] - [MemberData(nameof(GetConstructedTensors))] - public void TestGetEnumerator(Tensor tensor) - { - static IEnumerable GetExpected(Tensor tensor) - { - for (int index = 0; index < tensor.Length; ++index) - yield return tensor.GetValue(index); - } - - Assert.Equal(GetExpected(tensor), tensor); - } - - [Theory] - [MemberData(nameof(GetConstructedTensors))] - public void TestEnumeratorReset(Tensor tensor) - { - static long AdvanceEnumerator(ref Tensor.Enumerator enumerator, long maxCount) - { - long count = 0; - while (count < maxCount && enumerator.MoveNext()) - count++; - - return count; - } - - static void TestStepCountIfInRange(Tensor tensor, long stepCount) - { - if (stepCount < 0 || stepCount > tensor.Length) - return; - - var enumerator = tensor.GetEnumerator(); - long actualStepCount = AdvanceEnumerator(ref enumerator, stepCount); - - Assert.Equal(stepCount, actualStepCount); - - enumerator.Reset(); - - var itemsPostReset = new List(); - while (enumerator.MoveNext()) - itemsPostReset.Add(enumerator.Current); - - Assert.Equal(tensor, itemsPostReset); - } - - TestStepCountIfInRange(tensor, 1); - TestStepCountIfInRange(tensor, tensor.Length - 1); - TestStepCountIfInRange(tensor, tensor.Length / 4); - TestStepCountIfInRange(tensor, tensor.Length - tensor.Length / 4); - TestStepCountIfInRange(tensor, tensor.Length / 2); - TestStepCountIfInRange(tensor, tensor.Length); - } - - [Theory] - [MemberData(nameof(GetConstructedTensors))] - public void TestEnumeratorDispose_DoesNotThrow(Tensor tensor) - { - var enumerator = tensor.GetEnumerator(); - - enumerator.Dispose(); - enumerator.Dispose(); - } - } -} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs deleted file mode 100644 index 9774dd22662e6a..00000000000000 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs +++ /dev/null @@ -1,187 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Linq; - -namespace System.Numerics.Tensors.Tests -{ - public class TensorTestsBase - { - public enum TensorType - { - Dense, - Sparse, - CompressedSparse - }; - - public class TensorConstructor - { - public TensorType TensorType { get; set; } - - public bool IsReversedStride { get; set; } - - public Tensor CreateFromArray(Array array) - { - switch (TensorType) - { - case TensorType.Dense: - return array.ToTensor(IsReversedStride); - case TensorType.Sparse: - return array.ToSparseTensor(IsReversedStride); - case TensorType.CompressedSparse: - return array.ToCompressedSparseTensor(IsReversedStride); - } - - throw new ArgumentException(nameof(TensorType)); - } - public Tensor CreateFromDimensions(ReadOnlySpan dimensions) - { - switch (TensorType) - { - case TensorType.Dense: - return new DenseTensor(dimensions, IsReversedStride); - case TensorType.Sparse: - return new SparseTensor(dimensions, IsReversedStride); - case TensorType.CompressedSparse: - return new CompressedSparseTensor(dimensions, IsReversedStride); - } - - throw new ArgumentException(nameof(TensorType)); - } - - public override string ToString() - { - return $"{TensorType}, {nameof(IsReversedStride)} = {IsReversedStride}"; - } - } - - private static TensorType[] s_tensorTypes = new[] - { - TensorType.Dense, - TensorType.Sparse, - TensorType.CompressedSparse - }; - - private static bool[] s_reverseStrideValues = new[] - { - false, - true - }; - - public static IEnumerable GetSingleTensorConstructors() - { - foreach (TensorType tensorType in s_tensorTypes) - { - foreach (bool isReversedStride in s_reverseStrideValues) - { - yield return new[] - { - new TensorConstructor() - { - TensorType = tensorType, - IsReversedStride = isReversedStride - } - }; - } - } - } - - public static IEnumerable GetDualTensorConstructors() - { - foreach (TensorType leftTensorType in s_tensorTypes) - { - foreach (TensorType rightTensorType in s_tensorTypes) - { - foreach (bool isLeftReversedStride in s_reverseStrideValues) - { - foreach (bool isRightReversedStride in s_reverseStrideValues) - { - yield return new[] - { - new TensorConstructor() - { - TensorType = leftTensorType, - IsReversedStride = isLeftReversedStride - }, - new TensorConstructor() - { - TensorType = rightTensorType, - IsReversedStride = isRightReversedStride - } - }; - } - } - } - } - } - - public static IEnumerable GetTensorAndResultConstructor() - { - foreach (TensorType leftTensorType in s_tensorTypes) - { - foreach (TensorType rightTensorType in s_tensorTypes) - { - foreach (bool isReversedStride in s_reverseStrideValues) - { - yield return new[] - { - new TensorConstructor() - { - TensorType = leftTensorType, - IsReversedStride = isReversedStride - }, - new TensorConstructor() - { - TensorType = rightTensorType, - IsReversedStride = isReversedStride - } - }; - } - } - } - } - - public static IEnumerable GetConstructedTensors() - { - foreach (var ctor in GetSingleTensorConstructors().Select(x => (TensorConstructor)x[0])) - { - yield return new object[] { ctor.CreateFromArray(Array.Empty()) }; - yield return new object[] { ctor.CreateFromArray(new[] { 7 }) }; - yield return new object[] { ctor.CreateFromArray(new[] { 7, 14 }) }; - yield return new object[] { ctor.CreateFromArray(new[] { 7, 14, 21 }) }; - yield return new object[] - { - ctor.CreateFromArray(new[,] - { - { 3, 6, 9 }, - { 5, 10, 15 }, - { 7, 14, 21 }, - { 11, 22, 33 } - }) - }; - } - } - - public static NativeMemory NativeMemoryFromArray(T[] array) - { - return NativeMemoryFromArray((Array)array); - } - - public static NativeMemory NativeMemoryFromArray(Array array) - { - // this silly method takes a managed array and copies it over to unmanaged memory, - // **only for test purposes** - - var memory = NativeMemory.Allocate(array.Length); - var span = memory.GetSpan(); - int index = 0; - foreach (T item in array) - { - span[index++] = item; - } - - return memory; - } - } -} diff --git a/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs b/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs index 3a8c3e8e7ea72d..efbe9e691625d7 100644 --- a/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs +++ b/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs @@ -65,6 +65,45 @@ internal static int GetInt32Config(string configName, int defaultValue, bool all } } + internal static int GetInt32Config(string configName, string envVariable, int defaultValue, bool allowNegative = true) + { + string? str = Environment.GetEnvironmentVariable(envVariable); + if (str != null) + { + try + { + int result; + if (str.StartsWith('0')) + { + if (str.Length >= 2 && str[1] == 'x') + { + result = Convert.ToInt32(str, 16); + } + else + { + result = Convert.ToInt32(str, 8); + } + } + else + { + result = int.Parse(str, NumberStyles.AllowLeadingSign, NumberFormatInfo.InvariantInfo); + } + + if (allowNegative || result >= 0) + { + return result; + } + } + catch (FormatException) + { + } + catch (OverflowException) + { + } + } + + return GetInt32Config(configName, defaultValue, allowNegative); + } internal static short GetInt16Config(string configName, short defaultValue, bool allowNegative = true) { @@ -112,5 +151,45 @@ internal static short GetInt16Config(string configName, short defaultValue, bool return defaultValue; } } + + internal static short GetInt16Config(string configName, string envVariable, short defaultValue, bool allowNegative = true) + { + string? str = Environment.GetEnvironmentVariable(envVariable); + if (str != null) + { + try + { + short result; + if (str.StartsWith('0')) + { + if (str.Length >= 2 && str[1] == 'x') + { + result = Convert.ToInt16(str, 16); + } + else + { + result = Convert.ToInt16(str, 8); + } + } + else + { + result = short.Parse(str, NumberStyles.AllowLeadingSign, NumberFormatInfo.InvariantInfo); + } + + if (allowNegative || result >= 0) + { + return result; + } + } + catch (FormatException) + { + } + catch (OverflowException) + { + } + } + + return GetInt16Config(configName, defaultValue, allowNegative); + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs b/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs index efd4b8cfb656dd..030560b2002145 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs @@ -103,8 +103,8 @@ private void CommitDispatchConfiguration() new EventPipeProviderConfiguration(NativeRuntimeEventSource.EventSourceName, (ulong)aggregatedKeywords, (uint)enableLevel, null) }; - m_sessionID = EventPipeInternal.Enable(null, EventPipeSerializationFormat.NetTrace, DefaultEventListenerCircularMBSize, providerConfiguration); - if (m_sessionID == 0) + ulong sessionID = EventPipeInternal.Enable(null, EventPipeSerializationFormat.NetTrace, DefaultEventListenerCircularMBSize, providerConfiguration); + if (sessionID == 0) { throw new EventSourceException(SR.EventSource_CouldNotEnableEventPipe); } @@ -113,7 +113,7 @@ private void CommitDispatchConfiguration() EventPipeSessionInfo sessionInfo; unsafe { - if (!EventPipeInternal.GetSessionInfo(m_sessionID, &sessionInfo)) + if (!EventPipeInternal.GetSessionInfo(sessionID, &sessionInfo)) { Debug.Fail("GetSessionInfo returned false."); } @@ -124,8 +124,11 @@ private void CommitDispatchConfiguration() long syncTimeQPC = sessionInfo.StartTimeStamp; long timeQPCFrequency = sessionInfo.TimeStampFrequency; + Debug.Assert(Volatile.Read(ref m_sessionID) == 0); + Volatile.Write(ref m_sessionID, sessionID); + // Start the dispatch task. - StartDispatchTask(m_sessionID, syncTimeUtc, syncTimeQPC, timeQPCFrequency); + StartDispatchTask(sessionID, syncTimeUtc, syncTimeQPC, timeQPCFrequency); } private void StartDispatchTask(ulong sessionID, DateTime syncTimeUtc, long syncTimeQPC, long timeQPCFrequency) @@ -142,12 +145,16 @@ private void SetStopDispatchTask() { Debug.Assert(Monitor.IsEntered(m_dispatchControlLock)); - if (m_dispatchTask != null) + if (m_dispatchTaskCancellationSource?.IsCancellationRequested ?? true) { - Debug.Assert(m_dispatchTaskCancellationSource != null); - m_dispatchTaskCancellationSource?.Cancel(); - EventPipeInternal.SignalSession(m_sessionID); + return; } + + ulong sessionID = Volatile.Read(ref m_sessionID); + Debug.Assert(sessionID != 0); + m_dispatchTaskCancellationSource.Cancel(); + EventPipeInternal.SignalSession(sessionID); + Volatile.Write(ref m_sessionID, 0); } private unsafe void DispatchEventsToEventListeners(ulong sessionID, DateTime syncTimeUtc, long syncTimeQPC, long timeQPCFrequency, Task? previousDispatchTask, CancellationToken token) @@ -187,7 +194,16 @@ private unsafe void DispatchEventsToEventListeners(ulong sessionID, DateTime syn } } - // Disable the old session. This can happen asynchronously since we aren't using the old session anymore + // Wait for SignalSession() to be called before we call disable, otherwise + // the SignalSession() call could be on a disabled session. + SpinWait sw = default; + while (Volatile.Read(ref m_sessionID) == sessionID) + { + sw.SpinOnce(); + } + + // Disable the old session. This can happen asynchronously since we aren't using the old session + // anymore. EventPipeInternal.Disable(sessionID); } diff --git a/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs b/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs index 0d4dad112249ca..12af791d392e24 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs @@ -420,7 +420,19 @@ private static CultureInfo[] IcuEnumCultures(CultureTypes types) return Array.Empty(); } - int bufferLength = Interop.Globalization.GetLocales(null, 0); + int bufferLength; +#if TARGET_MACCATALYST || TARGET_IOS || TARGET_TVOS + if (GlobalizationMode.Hybrid) + { + bufferLength = Interop.Globalization.GetLocalesNative(null, 0); + } + else + { + bufferLength = Interop.Globalization.GetLocales(null, 0); + } +#else + bufferLength = Interop.Globalization.GetLocales(null, 0); +#endif if (bufferLength <= 0) { return Array.Empty(); @@ -428,7 +440,18 @@ private static CultureInfo[] IcuEnumCultures(CultureTypes types) char [] chars = new char[bufferLength]; +#if TARGET_MACCATALYST || TARGET_IOS || TARGET_TVOS + if (GlobalizationMode.Hybrid) + { + bufferLength = Interop.Globalization.GetLocalesNative(chars, bufferLength); + } + else + { + bufferLength = Interop.Globalization.GetLocales(chars, bufferLength); + } +#else bufferLength = Interop.Globalization.GetLocales(chars, bufferLength); +#endif if (bufferLength <= 0) { return Array.Empty(); diff --git a/src/libraries/System.Private.CoreLib/src/System/Half.cs b/src/libraries/System.Private.CoreLib/src/System/Half.cs index 8daa37bbab576b..cd3e6ab3ed73c3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Half.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Half.cs @@ -1044,7 +1044,7 @@ public static explicit operator float(Half value) // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) const uint ExponentOffset = 0x3800_0000u; // Mask for sign bit in Single - const uint FloatSignMask = float.SignMask; + const uint SingleSignMask = float.SignMask; // Mask for exponent bits in Half const uint HalfExponentMask = BiasedExponentMask; // Mask for bits in Single converted from Half @@ -1052,7 +1052,7 @@ public static explicit operator float(Half value) // Extract the internal representation of value short valueInInt16Bits = BitConverter.HalfToInt16Bits(value); // Extract sign bit of value - uint sign = (uint)(int)valueInInt16Bits & FloatSignMask; + uint sign = (uint)(int)valueInInt16Bits & SingleSignMask; // Copy sign bit to upper bits uint bitValueInProcess = (uint)valueInInt16Bits; // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) diff --git a/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs b/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs index 5f049e69445381..b2a3134ae7501a 100644 --- a/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs +++ b/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs @@ -23,14 +23,9 @@ namespace System.IO * of the UnmanagedMemoryStream. * 3) You clean up the memory when appropriate. The UnmanagedMemoryStream * currently will do NOTHING to free this memory. - * 4) All calls to Write and WriteByte may not be threadsafe currently. - * - * It may become necessary to add in some sort of - * DeallocationMode enum, specifying whether we unmap a section of memory, - * call free, run a user-provided delegate to free the memory, etc. - * We'll suggest user write a subclass of UnmanagedMemoryStream that uses - * a SafeHandle subclass to hold onto the memory. - * + * 4) This type is not thread safe. However, the implementation should prevent buffer + * overruns or returning uninitialized memory when Reads and Writes are called + * concurrently in thread unsafe manner. */ /// @@ -40,10 +35,10 @@ public class UnmanagedMemoryStream : Stream { private SafeBuffer? _buffer; private unsafe byte* _mem; - private long _length; - private long _capacity; - private long _position; - private long _offset; + private nuint _capacity; + private nuint _offset; + private nuint _length; // nuint to guarantee atomic access on 32-bit platforms + private long _position; // long to allow seeking to any location beyond the length of the stream. private FileAccess _access; private bool _isOpen; private CachedCompletedInt32Task _lastReadTask; // The last successful task returned from ReadAsync @@ -123,10 +118,10 @@ protected void Initialize(SafeBuffer buffer, long offset, long length, FileAcces } } - _offset = offset; + _offset = (nuint)offset; _buffer = buffer; - _length = length; - _capacity = length; + _length = (nuint)length; + _capacity = (nuint)length; _access = access; _isOpen = true; } @@ -171,8 +166,8 @@ protected unsafe void Initialize(byte* pointer, long length, long capacity, File _mem = pointer; _offset = 0; - _length = length; - _capacity = capacity; + _length = (nuint)length; + _capacity = (nuint)capacity; _access = access; _isOpen = true; } @@ -259,7 +254,7 @@ public override long Length get { EnsureNotClosed(); - return Interlocked.Read(ref _length); + return (long)_length; } } @@ -271,7 +266,7 @@ public long Capacity get { EnsureNotClosed(); - return _capacity; + return (long)_capacity; } } @@ -283,14 +278,14 @@ public override long Position get { if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null); - return Interlocked.Read(ref _position); + return _position; } set { ArgumentOutOfRangeException.ThrowIfNegative(value); if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null); - Interlocked.Exchange(ref _position, value); + _position = value; } } @@ -308,11 +303,10 @@ public unsafe byte* PositionPointer EnsureNotClosed(); // Use a temp to avoid a race - long pos = Interlocked.Read(ref _position); - if (pos > _capacity) + long pos = _position; + if (pos > (long)_capacity) throw new IndexOutOfRangeException(SR.IndexOutOfRange_UMSPosition); - byte* ptr = _mem + pos; - return ptr; + return _mem + pos; } set { @@ -327,7 +321,7 @@ public unsafe byte* PositionPointer if (newPosition < 0) throw new ArgumentOutOfRangeException(nameof(value), SR.ArgumentOutOfRange_UnmanagedMemStreamLength); - Interlocked.Exchange(ref _position, newPosition); + _position = newPosition; } } @@ -367,8 +361,13 @@ internal int ReadCore(Span buffer) // Use a local variable to avoid a race where another thread // changes our position after we decide we can read some bytes. - long pos = Interlocked.Read(ref _position); - long len = Interlocked.Read(ref _length); + long pos = _position; + + // Use a volatile read to prevent reading of the uninitialized memory. This volatile read + // and matching volatile write that set _length avoids reordering of NativeMemory.Clear + // operations with reading of the buffer below. + long len = (long)Volatile.Read(ref _length); + long n = Math.Min(len - pos, buffer.Length); if (n <= 0) { @@ -407,7 +406,7 @@ internal int ReadCore(Span buffer) } } - Interlocked.Exchange(ref _position, pos + n); + _position = pos + n; return nInt; } @@ -484,11 +483,16 @@ public override int ReadByte() EnsureNotClosed(); EnsureReadable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + + // Use a volatile read to prevent reading of the uninitialized memory. This volatile read + // and matching volatile write that set _length avoids reordering of NativeMemory.Clear + // operations with reading of the buffer below. + long len = (long)Volatile.Read(ref _length); + if (pos >= len) return -1; - Interlocked.Exchange(ref _position, pos + 1); + _position = pos + 1; int result; if (_buffer != null) { @@ -529,35 +533,33 @@ public override long Seek(long offset, SeekOrigin loc) { EnsureNotClosed(); + long newPosition; switch (loc) { case SeekOrigin.Begin: - if (offset < 0) + newPosition = offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, offset); break; case SeekOrigin.Current: - long pos = Interlocked.Read(ref _position); - if (offset + pos < 0) + newPosition = _position + offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, offset + pos); break; case SeekOrigin.End: - long len = Interlocked.Read(ref _length); - if (len + offset < 0) + newPosition = (long)_length + offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, len + offset); break; default: throw new ArgumentException(SR.Argument_InvalidSeekOrigin); } - long finalPos = Interlocked.Read(ref _position); - Debug.Assert(finalPos >= 0, "_position >= 0"); - return finalPos; + _position = newPosition; + return newPosition; } /// @@ -573,11 +575,10 @@ public override void SetLength(long value) EnsureNotClosed(); EnsureWriteable(); - if (value > _capacity) + if (value > (long)_capacity) throw new IOException(SR.IO_FixedCapacity); - long pos = Interlocked.Read(ref _position); - long len = Interlocked.Read(ref _length); + long len = (long)_length; if (value > len) { unsafe @@ -585,10 +586,11 @@ public override void SetLength(long value) NativeMemory.Clear(_mem + len, (nuint)(value - len)); } } - Interlocked.Exchange(ref _length, value); - if (pos > value) + Volatile.Write(ref _length, (nuint)value); // volatile to prevent reading of uninitialized memory + + if (_position > value) { - Interlocked.Exchange(ref _position, value); + _position = value; } } @@ -625,8 +627,8 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) EnsureNotClosed(); EnsureWriteable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + long len = (long)_length; long n = pos + buffer.Length; // Check for overflow if (n < 0) @@ -634,7 +636,7 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) throw new IOException(SR.IO_StreamTooLong); } - if (n > _capacity) + if (n > (long)_capacity) { throw new NotSupportedException(SR.IO_FixedCapacity); } @@ -648,16 +650,16 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) NativeMemory.Clear(_mem + len, (nuint)(pos - len)); } - // set length after zeroing memory to avoid race condition of accessing unzeroed memory + // set length after zeroing memory to avoid race condition of accessing uninitialized memory if (n > len) { - Interlocked.Exchange(ref _length, n); + Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory } } if (_buffer != null) { - long bytesLeft = _capacity - pos; + long bytesLeft = (long)_capacity - pos; if (bytesLeft < buffer.Length) { throw new ArgumentException(SR.Arg_BufferTooSmall); @@ -682,8 +684,7 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) Buffer.Memmove(ref *(_mem + pos), ref MemoryMarshal.GetReference(buffer), (nuint)buffer.Length); } - Interlocked.Exchange(ref _position, n); - return; + _position = n; } /// @@ -754,8 +755,8 @@ public override void WriteByte(byte value) EnsureNotClosed(); EnsureWriteable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + long len = (long)_length; long n = pos + 1; if (pos >= len) { @@ -763,7 +764,7 @@ public override void WriteByte(byte value) if (n < 0) throw new IOException(SR.IO_StreamTooLong); - if (n > _capacity) + if (n > (long)_capacity) throw new NotSupportedException(SR.IO_FixedCapacity); // Check to see whether we are now expanding the stream and must @@ -779,8 +780,7 @@ public override void WriteByte(byte value) } } - // set length after zeroing memory to avoid race condition of accessing unzeroed memory - Interlocked.Exchange(ref _length, n); + Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory } } @@ -810,7 +810,7 @@ public override void WriteByte(byte value) _mem[pos] = value; } } - Interlocked.Exchange(ref _position, n); + _position = n; } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Math.cs b/src/libraries/System.Private.CoreLib/src/System/Math.cs index 7d20cc72202dd1..266e49fc39dd94 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Math.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Math.cs @@ -1033,7 +1033,7 @@ public static double Min(double val1, double val2) // // It propagates NaN inputs back to the caller and // otherwise returns the lesser of the inputs. It - // treats +0 as lesser than -0 as per the specification. + // treats +0 as greater than -0 as per the specification. if (val1 != val2) { @@ -1091,7 +1091,7 @@ public static float Min(float val1, float val2) // // It propagates NaN inputs back to the caller and // otherwise returns the lesser of the inputs. It - // treats +0 as lesser than -0 as per the specification. + // treats +0 as greater than -0 as per the specification. if (val1 != val2) { @@ -1145,7 +1145,7 @@ public static double MinMagnitude(double x, double y) // // It propagates NaN inputs back to the caller and // otherwise returns the input with a lesser magnitude. - // It treats +0 as lesser than -0 as per the specification. + // It treats +0 as greater than -0 as per the specification. double ax = Abs(x); double ay = Abs(y); diff --git a/src/libraries/System.Private.CoreLib/src/System/MathF.cs b/src/libraries/System.Private.CoreLib/src/System/MathF.cs index 2726d14492f6ab..de0efc14f0ac4f 100644 --- a/src/libraries/System.Private.CoreLib/src/System/MathF.cs +++ b/src/libraries/System.Private.CoreLib/src/System/MathF.cs @@ -285,7 +285,7 @@ public static float MinMagnitude(float x, float y) // // It propagates NaN inputs back to the caller and // otherwise returns the input with a lesser magnitude. - // It treats +0 as lesser than -0 as per the specification. + // It treats +0 as greater than -0 as per the specification. float ax = Abs(x); float ay = Abs(y); diff --git a/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs b/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs index 13397687e2a1dc..4368605183d2e9 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs @@ -27,7 +27,7 @@ public interface INumberBase ISubtractionOperators, IUnaryPlusOperators, IUnaryNegationOperators, - // IUtf8SpanFormattable, + IUtf8SpanFormattable, IUtf8SpanParsable where TSelf : INumberBase? { @@ -457,9 +457,7 @@ static virtual bool TryParse(ReadOnlySpan utf8Text, NumberStyles style, IF return succeeded; } - // Workaround devdiv/#1851707: C++/CLI fails to compile when encountering a Default Interface Method implemented in a derived interface - // bool IUtf8SpanFormattable.TryFormat(Span utf8Destination, out int bytesWritten, ReadOnlySpan format, IFormatProvider? provider) - bool TryFormat(Span utf8Destination, out int bytesWritten, ReadOnlySpan format, IFormatProvider? provider) + bool IUtf8SpanFormattable.TryFormat(Span utf8Destination, out int bytesWritten, ReadOnlySpan format, IFormatProvider? provider) { char[]? utf16DestinationArray; scoped Span utf16Destination; diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs index 67fc0d31919759..3fad9dc8dba85f 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/ConstructorInvoker.cs @@ -11,6 +11,17 @@ namespace System.Reflection { + /// + /// Invokes the method reflected by the provided . + /// + /// + /// Used for better performance than when compatibility with that method + /// is not necessary and when the caller can cache the ConstructorInvoker instance for additional invoke calls.
+ /// Unlike , the invoke methods do not look up default values for arguments when + /// is specified. In addition, the target constructor may be inlined for performance and not + /// appear in stack traces. + ///
+ /// public sealed partial class ConstructorInvoker { private InvokeFunc_ObjSpanArgs? _invokeFunc_ObjSpanArgs; @@ -24,6 +35,17 @@ public sealed partial class ConstructorInvoker private readonly RuntimeConstructorInfo _method; private readonly bool _needsByRefStrategy; + /// + /// Creates a new instance of ConstructorInvoker. + /// + /// + /// For performance, the resulting instance should be cached for additional calls. + /// + /// The constructor that will be invoked. + /// An instance of a ConstructorInvoker. + /// + /// The is not a runtime-based method. + /// public static ConstructorInvoker Create(ConstructorInfo constructor) { ArgumentNullException.ThrowIfNull(constructor, nameof(constructor)); @@ -46,6 +68,21 @@ private ConstructorInvoker(RuntimeConstructorInfo constructor, RuntimeType[] arg Initialize(argumentTypes, out _strategy, out _invokerArgFlags, out _needsByRefStrategy); } + /// + /// Invokes the constructor. + /// + /// + /// An instance of the class associated with the constructor. + /// + /// + /// The type that declares the method is an open generic type. + /// + /// + /// The correct number of arguments were not provided. + /// + /// + /// The calling convention or signature is not supported. + /// public object Invoke() { if (_argCount != 0) @@ -56,6 +93,14 @@ public object Invoke() return InvokeImpl(null, null, null, null); } + /// + /// Invokes the constructor using the specified parameters. + /// + /// + /// The first argument for the invoked method. + /// + /// The arguments do not match the signature of the invoked constructor. + /// public object Invoke(object? arg1) { if (_argCount != 1) @@ -66,6 +111,9 @@ public object Invoke(object? arg1) return InvokeImpl(arg1, null, null, null); } + /// + /// The first argument for the invoked method. + /// The second argument for the invoked method. public object Invoke(object? arg1, object? arg2) { if (_argCount != 2) @@ -76,6 +124,10 @@ public object Invoke(object? arg1, object? arg2) return InvokeImpl(arg1, arg2, null, null); } + /// + /// The first argument for the invoked method. + /// The second argument for the invoked method. + /// The third argument for the invoked method. public object Invoke(object? arg1, object? arg2, object? arg3) { if (_argCount !=3) @@ -86,6 +138,11 @@ public object Invoke(object? arg1, object? arg2, object? arg3) return InvokeImpl(arg1, arg2, arg3, null); } + /// + /// The first argument for the invoked method. + /// The second argument for the invoked method. + /// The third argument for the invoked method. + /// The fourth argument for the invoked method. public object Invoke(object? arg1, object? arg2, object? arg3, object? arg4) { if (_argCount != 4) @@ -98,7 +155,7 @@ public object Invoke(object? arg1, object? arg2, object? arg3, object? arg4) private object InvokeImpl(object? arg1, object? arg2, object? arg3, object? arg4) { - if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers)) != 0) + if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers | InvocationFlags.NoConstructorInvoke)) != 0) { _method.ThrowNoInvokeException(); } @@ -137,6 +194,11 @@ private object InvokeImpl(object? arg1, object? arg2, object? arg3, object? arg4 return InvokeDirectByRef(arg1, arg2, arg3, arg4); } + /// + /// The arguments for the invoked constructor. + /// + /// The arguments do not match the signature of the invoked constructor. + /// public object Invoke(Span arguments) { int argLen = arguments.Length; diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs index b5496c37c0cc84..0c8d9c59580703 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvoker.cs @@ -12,6 +12,17 @@ namespace System.Reflection { + /// + /// Invokes the method reflected by the provided . + /// + /// + /// Used for better performance than when compatibility with that method + /// is not necessary and when the caller can cache the MethodInvoker instance for additional invoke calls.
+ /// Unlike , the invoke methods do not look up default values for arguments when + /// is specified. In addition, the target method may be inlined for performance and not + /// appear in stack traces. + ///
+ /// public sealed partial class MethodInvoker { private InvokeFunc_ObjSpanArgs? _invokeFunc_ObjSpanArgs; @@ -26,6 +37,17 @@ public sealed partial class MethodInvoker private readonly bool _needsByRefStrategy; private readonly bool _isStatic; + /// + /// Creates a new instance of MethodInvoker. + /// + /// + /// For performance, the resulting instance should be cached for additional calls. + /// + /// The method that will be invoked. + /// An instance of a MethodInvoker. + /// + /// The is not a runtime-based method. + /// public static MethodInvoker Create(MethodBase method) { ArgumentNullException.ThrowIfNull(method, nameof(method)); @@ -44,7 +66,12 @@ public static MethodInvoker Create(MethodBase method) { // This is useful for calling a constructor on an already-initialized object // such as created from RuntimeHelpers.GetUninitializedObject(Type). - return new MethodInvoker(rci); + MethodInvoker invoker = new MethodInvoker(rci); + + // Use the interpreted version to avoid having to generate a new method that doesn't allocate. + invoker._strategy = GetStrategyForUsingInterpreted(); + + return invoker; } throw new ArgumentException(SR.Argument_MustBeRuntimeMethod, nameof(method)); @@ -60,6 +87,32 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) Initialize(argumentTypes, out _strategy, out _invokerArgFlags, out _needsByRefStrategy); } + /// + /// Invokes the method using the specified parameters. + /// + /// + /// The object on which to invoke the method. If the method is static, this argument is ignored. + /// + /// + /// An object containing the return value of the invoked method, + /// or null if the invoked method does not have a return value. + /// + /// + /// The obj parameter is null and the method is not static. + /// + /// -or- + /// + /// The method is not declared or inherited by the class of obj. + /// + /// + /// The type that declares the method is an open generic type. + /// + /// + /// The correct number of arguments were not provided. + /// + /// + /// The calling convention or signature is not supported. + /// public object? Invoke(object? obj) { if (_argCount != 0) @@ -70,6 +123,12 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) return InvokeImpl(obj, null, null, null, null); } + /// + /// The object on which to invoke the method. If the method is static, this argument is ignored. + /// The first argument for the invoked method. + /// + /// The arguments do not match the signature of the invoked method. + /// public object? Invoke(object? obj, object? arg1) { if (_argCount != 1) @@ -80,6 +139,10 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) return InvokeImpl(obj, arg1, null, null, null); } + /// + /// The object on which to invoke the method. If the method is static, this argument is ignored. + /// The first argument for the invoked method. + /// The second argument for the invoked method. public object? Invoke(object? obj, object? arg1, object? arg2) { if (_argCount != 2) @@ -90,6 +153,11 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) return InvokeImpl(obj, arg1, arg2, null, null); } + /// + /// The object on which to invoke the method. If the method is static, this argument is ignored. + /// The first argument for the invoked method. + /// The second argument for the invoked method. + /// The third argument for the invoked method. public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3) { if (_argCount != 3) @@ -100,6 +168,12 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) return InvokeImpl(obj, arg1, arg2, arg3, null); } + /// + /// The object on which to invoke the method. If the method is static, this argument is ignored. + /// The first argument for the invoked method. + /// The second argument for the invoked method. + /// The third argument for the invoked method. + /// The fourth argument for the invoked method. public object? Invoke(object? obj, object? arg1, object? arg2, object? arg3, object? arg4) { if (_argCount != 4) @@ -112,7 +186,7 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) private object? InvokeImpl(object? obj, object? arg1, object? arg2, object? arg3, object? arg4) { - if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers)) != 0) + if ((_invocationFlags & (InvocationFlags.NoInvoke | InvocationFlags.ContainsStackPointers | InvocationFlags.NoConstructorInvoke)) != 0) { ThrowForBadInvocationFlags(); } @@ -156,6 +230,12 @@ private MethodInvoker(MethodBase method, RuntimeType[] argumentTypes) return InvokeDirectByRef(obj, arg1, arg2, arg3, arg4); } + /// + /// The object on which to invoke the method. If the method is static, this argument is ignored. + /// The arguments for the invoked method. + /// + /// The arguments do not match the signature of the invoked method. + /// public object? Invoke(object? obj, Span arguments) { int argLen = arguments.Length; diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvokerCommon.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvokerCommon.cs index 191228b688faa1..813e89f80f8349 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvokerCommon.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/MethodInvokerCommon.cs @@ -18,13 +18,14 @@ internal static void Initialize( { if (LocalAppContextSwitches.ForceInterpretedInvoke && !LocalAppContextSwitches.ForceEmitInvoke) { - // Always use the native invoke; useful for testing. - strategy = InvokerStrategy.StrategyDetermined_Obj4Args | InvokerStrategy.StrategyDetermined_ObjSpanArgs | InvokerStrategy.StrategyDetermined_RefArgs; + // Always use the native interpreted invoke. + // Useful for testing, to avoid startup overhead of emit, or for calling a ctor on already initialized object. + strategy = GetStrategyForUsingInterpreted(); } else if (LocalAppContextSwitches.ForceEmitInvoke && !LocalAppContextSwitches.ForceInterpretedInvoke) { // Always use emit invoke (if IsDynamicCodeSupported == true); useful for testing. - strategy = InvokerStrategy.HasBeenInvoked_Obj4Args | InvokerStrategy.HasBeenInvoked_ObjSpanArgs | InvokerStrategy.HasBeenInvoked_RefArgs; + strategy = GetStrategyForUsingEmit(); } else { @@ -69,6 +70,18 @@ internal static void Initialize( } } + internal static InvokerStrategy GetStrategyForUsingInterpreted() + { + // This causes the default strategy, which is interpreted, to always be used. + return InvokerStrategy.StrategyDetermined_Obj4Args | InvokerStrategy.StrategyDetermined_ObjSpanArgs | InvokerStrategy.StrategyDetermined_RefArgs; + } + + private static InvokerStrategy GetStrategyForUsingEmit() + { + // This causes the emit strategy, if supported, to be used on the first call as well as subsequent calls. + return InvokerStrategy.HasBeenInvoked_Obj4Args | InvokerStrategy.HasBeenInvoked_ObjSpanArgs | InvokerStrategy.HasBeenInvoked_RefArgs; + } + /// /// Confirm member invocation has an instance and is of the correct type /// diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncTaskMethodBuilderT.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncTaskMethodBuilderT.cs index 02a60f50c3867f..c3064ad22114bd 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncTaskMethodBuilderT.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncTaskMethodBuilderT.cs @@ -93,6 +93,9 @@ internal static void AwaitUnsafeOnCompleted( AwaitUnsafeOnCompleted(ref awaiter, box); } + // Tier0 codegen for this function may still allocate (while FullOpts won't). + // TODO: remove once https://github.com/dotnet/runtime/issues/90965 is implemented + [MethodImpl(MethodImplOptions.AggressiveOptimization)] internal static void AwaitUnsafeOnCompleted( ref TAwaiter awaiter, IAsyncStateMachineBox box) where TAwaiter : ICriticalNotifyCompletion diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs index 2d2a812cb41a07..37c25f851538a7 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs @@ -11,12 +11,14 @@ public abstract partial class ComWrappers { public static unsafe bool TryGetComInstance(object obj, out IntPtr unknown) { - throw new PlatformNotSupportedException(); + unknown = default; + return false; } public static unsafe bool TryGetObject(IntPtr unknown, [NotNullWhen(true)] out object? obj) { - throw new PlatformNotSupportedException(); + obj = default; + return false; } public partial struct ComInterfaceDispatch diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs index 98dde70ae362db..11a3319196ad94 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs @@ -130,7 +130,7 @@ public static int SizeOf() return SizeOfHelper(t, throwIfNotMarshalable: true); } - public static unsafe int QueryInterface(IntPtr pUnk, in Guid iid, out IntPtr ppv) + public static unsafe int QueryInterface(IntPtr pUnk, ref readonly Guid iid, out IntPtr ppv) { ArgumentNullException.ThrowIfNull(pUnk); diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/Loader/AssemblyLoadContext.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/Loader/AssemblyLoadContext.cs index b8b4ba086ad69b..59123e42fb52a2 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/Loader/AssemblyLoadContext.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/Loader/AssemblyLoadContext.cs @@ -135,25 +135,32 @@ private void InitiateUnload() { RaiseUnloadEvent(); + InternalState previousState; + // When in Unloading state, we are not supposed to be called on the finalizer // as the native side is holding a strong reference after calling Unload lock (_unloadLock) { - Debug.Assert(_state == InternalState.Alive); - - var thisStrongHandle = GCHandle.Alloc(this, GCHandleType.Normal); - var thisStrongHandlePtr = GCHandle.ToIntPtr(thisStrongHandle); - // The underlying code will transform the original weak handle - // created by InitializeLoadContext to a strong handle - PrepareForAssemblyLoadContextRelease(_nativeAssemblyLoadContext, thisStrongHandlePtr); + previousState = _state; + if (previousState == InternalState.Alive) + { + var thisStrongHandle = GCHandle.Alloc(this, GCHandleType.Normal); + var thisStrongHandlePtr = GCHandle.ToIntPtr(thisStrongHandle); + // The underlying code will transform the original weak handle + // created by InitializeLoadContext to a strong handle + PrepareForAssemblyLoadContextRelease(_nativeAssemblyLoadContext, thisStrongHandlePtr); - _state = InternalState.Unloading; + _state = InternalState.Unloading; + } } - Dictionary> allContexts = AllContexts; - lock (allContexts) + if (previousState == InternalState.Alive) { - allContexts.Remove(_id); + Dictionary> allContexts = AllContexts; + lock (allContexts) + { + allContexts.Remove(_id); + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs index 2e634fd469d9d8..c3b278019f6d92 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs @@ -7,11 +7,29 @@ namespace System.Threading { internal sealed partial class PortableThreadPool { + private int _numThreadsBeingKeptAlive; + /// /// The worker thread infastructure for the CLR thread pool. /// private static partial class WorkerThread { + private static readonly short ThreadsToKeepAlive = DetermineThreadsToKeepAlive(); + + private static short DetermineThreadsToKeepAlive() + { + const short DefaultThreadsToKeepAlive = 0; + + // The number of worker threads to keep alive after they are created. Set to -1 to keep all created worker + // threads alive. When the ThreadTimeoutMs config value is also set, for worker threads the timeout applies to + // worker threads that are in excess of the number configured for ThreadsToKeepAlive. + short threadsToKeepAlive = + AppContextConfigHelper.GetInt16Config( + "System.Threading.ThreadPool.ThreadsToKeepAlive", + "DOTNET_ThreadPool_ThreadsToKeepAlive", + DefaultThreadsToKeepAlive); + return threadsToKeepAlive >= -1 ? threadsToKeepAlive : DefaultThreadsToKeepAlive; + } /// /// Semaphore for controlling how many threads are currently working. @@ -50,10 +68,36 @@ private static void WorkerThreadStart() LowLevelLock threadAdjustmentLock = threadPoolInstance._threadAdjustmentLock; LowLevelLifoSemaphore semaphore = s_semaphore; + // Determine the idle timeout to use for this thread. Some threads may always be kept alive based on config. + int timeoutMs = ThreadPoolThreadTimeoutMs; + if (ThreadsToKeepAlive != 0) + { + if (ThreadsToKeepAlive < 0) + { + timeoutMs = Timeout.Infinite; + } + else + { + int count = threadPoolInstance._numThreadsBeingKeptAlive; + while (count < ThreadsToKeepAlive) + { + int countBeforeUpdate = + Interlocked.CompareExchange(ref threadPoolInstance._numThreadsBeingKeptAlive, count + 1, count); + if (countBeforeUpdate == count) + { + timeoutMs = Timeout.Infinite; + break; + } + + count = countBeforeUpdate; + } + } + } + while (true) { bool spinWait = true; - while (semaphore.Wait(ThreadPoolThreadTimeoutMs, spinWait)) + while (semaphore.Wait(timeoutMs, spinWait)) { WorkerDoWork(threadPoolInstance, ref spinWait); } @@ -65,7 +109,6 @@ private static void WorkerThreadStart() } } - private static void CreateWorkerThread() { // Thread pool threads must start in the default execution context without transferring the context, so diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs index 9ada201e134c5a..db49ea2c5a5a92 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs @@ -13,7 +13,6 @@ namespace System.Threading /// internal sealed partial class PortableThreadPool { - private const int ThreadPoolThreadTimeoutMs = 20 * 1000; // If you change this make sure to change the timeout times in the tests. private const int SmallStackSizeBytes = 256 * 1024; private const short MaxPossibleThreadCount = short.MaxValue; @@ -40,6 +39,23 @@ internal sealed partial class PortableThreadPool private static readonly short ForcedMaxWorkerThreads = AppContextConfigHelper.GetInt16Config("System.Threading.ThreadPool.MaxThreads", 0, false); + private static readonly int ThreadPoolThreadTimeoutMs = DetermineThreadPoolThreadTimeoutMs(); + + private static int DetermineThreadPoolThreadTimeoutMs() + { + const int DefaultThreadPoolThreadTimeoutMs = 20 * 1000; // If you change this make sure to change the timeout times in the tests. + + // The amount of time in milliseconds a thread pool thread waits without having done any work before timing out and + // exiting. Set to -1 to disable the timeout. Applies to worker threads and wait threads. Also see the + // ThreadsToKeepAlive config value for relevant information. + int timeoutMs = + AppContextConfigHelper.GetInt32Config( + "System.Threading.ThreadPool.ThreadTimeoutMs", + "DOTNET_ThreadPool_ThreadTimeoutMs", + DefaultThreadPoolThreadTimeoutMs); + return timeoutMs >= -1 ? timeoutMs : DefaultThreadPoolThreadTimeoutMs; + } + [ThreadStatic] private static object? t_completionCountObject; diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadPool.Windows.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadPool.Windows.cs index 6882b0482c017a..0da875498afc18 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadPool.Windows.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadPool.Windows.cs @@ -13,6 +13,12 @@ public static partial class ThreadPool internal static bool UseWindowsThreadPool { get; } = AppContextConfigHelper.GetBooleanConfig("System.Threading.ThreadPool.UseWindowsThreadPool", "DOTNET_ThreadPool_UseWindowsThreadPool"); +#pragma warning disable CA1823 + // The field should reflect what the property returns because the property can be stubbed by trimming, + // such that sos reflects the actual state of what thread pool is being used and not just the config value. + private static readonly bool s_useWindowsThreadPool = UseWindowsThreadPool; // Name relied on by sos +#pragma warning restore CA1823 + #if NATIVEAOT private const bool IsWorkerTrackingEnabledInConfig = false; #else diff --git a/src/libraries/System.Private.CoreLib/src/System/ValueTuple.cs b/src/libraries/System.Private.CoreLib/src/System/ValueTuple.cs index 240ad862ffcf24..aa919ea294f78d 100644 --- a/src/libraries/System.Private.CoreLib/src/System/ValueTuple.cs +++ b/src/libraries/System.Private.CoreLib/src/System/ValueTuple.cs @@ -1125,6 +1125,7 @@ other is ValueTuple vt && comparer.Equals(Item1, vt.Item1) && comparer.Equals(Item2, vt.Item2) && comparer.Equals(Item3, vt.Item3) && + comparer.Equals(Item4, vt.Item4) && comparer.Equals(Item5, vt.Item5); int IComparable.CompareTo(object? other) @@ -1366,6 +1367,7 @@ other is ValueTuple vt && comparer.Equals(Item1, vt.Item1) && comparer.Equals(Item2, vt.Item2) && comparer.Equals(Item3, vt.Item3) && + comparer.Equals(Item4, vt.Item4) && comparer.Equals(Item5, vt.Item5) && comparer.Equals(Item6, vt.Item6); @@ -1625,6 +1627,7 @@ other is ValueTuple vt && comparer.Equals(Item1, vt.Item1) && comparer.Equals(Item2, vt.Item2) && comparer.Equals(Item3, vt.Item3) && + comparer.Equals(Item4, vt.Item4) && comparer.Equals(Item5, vt.Item5) && comparer.Equals(Item6, vt.Item6) && comparer.Equals(Item7, vt.Item7); @@ -1908,6 +1911,7 @@ other is ValueTuple vt && comparer.Equals(Item1, vt.Item1) && comparer.Equals(Item2, vt.Item2) && comparer.Equals(Item3, vt.Item3) && + comparer.Equals(Item4, vt.Item4) && comparer.Equals(Item5, vt.Item5) && comparer.Equals(Item6, vt.Item6) && comparer.Equals(Item7, vt.Item7) && diff --git a/src/libraries/System.Reflection.Metadata/src/PACKAGE.md b/src/libraries/System.Reflection.Metadata/src/PACKAGE.md index 43543a703a9123..921506138ee1b4 100644 --- a/src/libraries/System.Reflection.Metadata/src/PACKAGE.md +++ b/src/libraries/System.Reflection.Metadata/src/PACKAGE.md @@ -1,18 +1,14 @@ ## About + + This package provides a low-level .NET (ECMA-335) metadata reader and writer. It's geared for performance and is the ideal choice for building higher-level libraries that intend to provide their own object model, such as compilers. The metadata format is defined by the [ECMA-335 - Common Language Infrastructure (CLI)](http://www.ecma-international.org/publications/standards/Ecma-335.htm) specification and [its amendments](https://github.com/dotnet/runtime/blob/main/docs/design/specs/Ecma-335-Augments.md). The `System.Reflection.Metadata` library is included in the .NET Runtime shared framework. The package can be installed when you need to use it in other target frameworks. -For more information, see the documentation: - -- [System.Reflection.Metadata.MetadataReader](https://docs.microsoft.com/dotnet/api/system.reflection.metadata.metadatareader) -- [System.Reflection.PortableExecutable.PEReader](https://docs.microsoft.com/dotnet/api/system.reflection.portableexecutable.pereader) -- [System.Reflection.Metadata.Ecma335.MetadataBuilder](https://docs.microsoft.com/dotnet/api/system.reflection.metadata.ecma335.metadatabuilder) -- [System.Reflection.PortableExecutable.PEBuilder](https://docs.microsoft.com/dotnet/api/system.reflection.portableexecutable.pebuilder) -- [System.Reflection.PortableExecutable.ManagedPEBuilder](https://docs.microsoft.com/dotnet/api/system.reflection.portableexecutable.managedpebuilder) +## How to Use -## Example + The following example shows how to read assembly information using PEReader and MetadataReader. @@ -80,3 +76,31 @@ class Program } ``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Reflection.Metadata.MetadataReader` +* `System.Reflection.PortableExecutable.PEReader` +* `System.Reflection.Metadata.Ecma335.MetadataBuilder` +* `System.Reflection.PortableExecutable.PEBuilder` +* `System.Reflection.PortableExecutable.ManagedPEBuilder` + +## Additional Documentation + + + +* [System.Reflection.Metadata.MetadataReader](https://docs.microsoft.com/dotnet/api/system.reflection.metadata.metadatareader) +* [System.Reflection.PortableExecutable.PEReader](https://docs.microsoft.com/dotnet/api/system.reflection.portableexecutable.pereader) +* [System.Reflection.Metadata.Ecma335.MetadataBuilder](https://docs.microsoft.com/dotnet/api/system.reflection.metadata.ecma335.metadatabuilder) +* [System.Reflection.PortableExecutable.PEBuilder](https://docs.microsoft.com/dotnet/api/system.reflection.portableexecutable.pebuilder) +* [System.Reflection.PortableExecutable.ManagedPEBuilder](https://docs.microsoft.com/dotnet/api/system.reflection.portableexecutable.managedpebuilder) + +## Feedback & Contributing + + + +System.Reflection.Metadata is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Reflection.MetadataLoadContext/src/PACKAGE.md b/src/libraries/System.Reflection.MetadataLoadContext/src/PACKAGE.md index 7e351acc94fe75..f16805d7d3f14a 100644 --- a/src/libraries/System.Reflection.MetadataLoadContext/src/PACKAGE.md +++ b/src/libraries/System.Reflection.MetadataLoadContext/src/PACKAGE.md @@ -1,14 +1,12 @@ ## About -Provides read-only reflection on assemblies in an isolated context with support for assemblies that target different processor architectures and runtimes. Using MetadataLoadContext enables you to inspect assemblies without loading them into the main execution context. Assemblies in MetadataLoadContext are treated only as metadata, that is, you can read information about their members, but cannot execute any code contained in them. + -For more information, see the documentation: +Provides read-only reflection on assemblies in an isolated context with support for assemblies that target different processor architectures and runtimes. Using MetadataLoadContext enables you to inspect assemblies without loading them into the main execution context. Assemblies in MetadataLoadContext are treated only as metadata, that is, you can read information about their members, but cannot execute any code contained in them. -- [How to: Inspect assembly contents using MetadataLoadContext](https://docs.microsoft.com/dotnet/standard/assembly/inspect-contents-using-metadataloadcontext) -- [System.Reflection.MetadataLoadContext](https://docs.microsoft.com/dotnet/api/system.reflection.metadataloadcontext) -- [System.Reflection.MetadataAssemblyResolver](https://docs.microsoft.com/dotnet/api/system.reflection.metadataassemblyresolver) +## How to Use -## Example + The following example shows how to print the list of types defined in an assembly. @@ -38,3 +36,26 @@ class Program } } ``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Reflection.MetadataLoadContext` +* `System.Reflection.MetadataAssemblyResolver` + +## Additional Documentation + + + +* [How to: Inspect assembly contents using MetadataLoadContext](https://docs.microsoft.com/dotnet/standard/assembly/inspect-contents-using-metadataloadcontext) +* [System.Reflection.MetadataLoadContext](https://docs.microsoft.com/dotnet/api/system.reflection.metadataloadcontext) +* [System.Reflection.MetadataAssemblyResolver](https://docs.microsoft.com/dotnet/api/system.reflection.metadataassemblyresolver) + +## Feedback & Contributing + + + +System.Reflection.MetadataLoadContext is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Reflection.MetadataLoadContext/src/System/Reflection/TypeLoading/Types/RoFunctionPointerType.cs b/src/libraries/System.Reflection.MetadataLoadContext/src/System/Reflection/TypeLoading/Types/RoFunctionPointerType.cs index 73cd87437551f6..426ab7085d9fdd 100644 --- a/src/libraries/System.Reflection.MetadataLoadContext/src/System/Reflection/TypeLoading/Types/RoFunctionPointerType.cs +++ b/src/libraries/System.Reflection.MetadataLoadContext/src/System/Reflection/TypeLoading/Types/RoFunctionPointerType.cs @@ -147,7 +147,23 @@ public sealed override bool Equals([NotNullWhen(true)] object? obj) public sealed override bool IsGenericParameter => false; public sealed override bool IsGenericTypeParameter => false; public sealed override bool IsGenericMethodParameter => false; - public sealed override bool ContainsGenericParameters => IsGenericTypeDefinition; + + public sealed override bool ContainsGenericParameters + { + get + { + if (_returnType.ContainsGenericParameters) + return true; + + foreach (Type parameterType in _parameterTypes) + { + if (parameterType.ContainsGenericParameters) + return true; + } + + return false; + } + } protected sealed override TypeCode GetTypeCodeImpl() => TypeCode.Object; diff --git a/src/libraries/System.Reflection/tests/ConstructorCommonTests.cs b/src/libraries/System.Reflection/tests/ConstructorCommonTests.cs new file mode 100644 index 00000000000000..5c71f79da014c1 --- /dev/null +++ b/src/libraries/System.Reflection/tests/ConstructorCommonTests.cs @@ -0,0 +1,167 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Xunit; + +namespace System.Reflection.Tests +{ + /// + /// These tests are shared with ConstructorInfo.Invoke and ConstructorInvoker.Invoke by using + /// the abstract Invoke(...) methods below. + /// + public abstract class ConstructorCommonTests + { + public abstract object Invoke(ConstructorInfo constructorInfo, object?[]? parameters); + + protected abstract bool IsExceptionWrapped { get; } + + /// + /// Invoke constructor on an existing instance. Should return null. + /// + public abstract object? Invoke(ConstructorInfo constructorInfo, object obj, object?[]? parameters); + + public static ConstructorInfo[] GetConstructors(Type type) + { + return type.GetTypeInfo().DeclaredConstructors.ToArray(); + } + + [Fact] + public void SimpleInvoke() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + Assert.Equal(3, constructors.Length); + ClassWith3Constructors obj = (ClassWith3Constructors)Invoke(constructors[0], null); + Assert.NotNull(obj); + } + + [Fact] + [ActiveIssue("https://github.com/mono/mono/issues/15024", TestRuntimes.Mono)] + public void Invoke_StaticConstructor_ThrowsMemberAccessException() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWithStaticConstructor)); + Assert.Equal(1, constructors.Length); + Assert.Throws(() => Invoke(constructors[0], new object[0])); + } + + [Fact] + public void Invoke_OneDimensionalArray() + { + ConstructorInfo[] constructors = GetConstructors(typeof(object[])); + int[] arraylength = { 1, 2, 99, 65535 }; + + // Try to invoke Array ctors with different lengths + foreach (int length in arraylength) + { + // Create big Array with elements + object[] arr = (object[])Invoke(constructors[0], new object[] { length }); + Assert.Equal(arr.Length, length); + } + } + + [Fact] + public void Invoke_OneDimensionalArray_NegativeLengths_ThrowsOverflowException() + { + ConstructorInfo[] constructors = GetConstructors(typeof(object[])); + int[] arraylength = new int[] { -1, -2, -99 }; + // Try to invoke Array ctors with different lengths + foreach (int length in arraylength) + { + // Create big Array with elements + if (IsExceptionWrapped) + { + Exception ex = Assert.Throws(() => Invoke(constructors[0], new object[] { length })); + Assert.IsType(ex.InnerException); + } + else + { + Assert.Throws(() => Invoke(constructors[0], new object[] { length })); + } + } + } + + [Fact] + public void Invoke_OneParameter() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + ClassWith3Constructors obj = (ClassWith3Constructors)Invoke(constructors[1], new object[] { 100 }); + Assert.Equal(100, obj.intValue); + } + + [Fact] + public void Invoke_TwoParameters() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + ClassWith3Constructors obj = (ClassWith3Constructors)Invoke(constructors[2], new object[] { 101, "hello" }); + Assert.Equal(101, obj.intValue); + Assert.Equal("hello", obj.stringValue); + } + + [Fact] + public void Invoke_NoParameters_ThowsTargetParameterCountException() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + Assert.Throws(() => Invoke(constructors[2], new object[0])); + } + + [Fact] + public void Invoke_ParameterMismatch_ThrowsTargetParameterCountException() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + Assert.Throws(() => (ClassWith3Constructors)Invoke(constructors[2], new object[] { 121 })); + } + + [Fact] + public void Invoke_ParameterWrongType_ThrowsArgumentException() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + AssertExtensions.Throws(null, () => (ClassWith3Constructors)Invoke(constructors[1], new object[] { "hello" })); + } + + [Fact] + public void Invoke_ExistingInstance() + { + // Should not produce a second object. + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + ClassWith3Constructors obj1 = new ClassWith3Constructors(100, "hello"); + ClassWith3Constructors obj2 = (ClassWith3Constructors)Invoke(constructors[2], obj1, new object[] { 999, "initialized" }); + Assert.Null(obj2); + Assert.Equal(999, obj1.intValue); + Assert.Equal("initialized", obj1.stringValue); + } + + [Fact] + public void Invoke_NullForObj() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); + Assert.Throws(() => Invoke(constructors[2], obj: null, new object[] { 999, "initialized" })); + } + + [Fact] + [ActiveIssue("https://github.com/mono/mono/issues/15026", TestRuntimes.Mono)] + public void Invoke_AbstractClass_ThrowsMemberAccessException() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ConstructorInfoAbstractBase)); + Assert.Throws(() => (ConstructorInfoAbstractBase)Invoke(constructors[0], new object[0])); + } + + [Fact] + public void Invoke_SubClass() + { + ConstructorInfo[] constructors = GetConstructors(typeof(ConstructorInfoDerived)); + ConstructorInfoDerived obj = null; + obj = (ConstructorInfoDerived)Invoke(constructors[0], new object[] { }); + Assert.NotNull(obj); + } + + [Fact] + public void Invoke_Struct() + { + ConstructorInfo[] constructors = GetConstructors(typeof(StructWith1Constructor)); + StructWith1Constructor obj; + obj = (StructWith1Constructor)Invoke(constructors[0], new object[] { 1, 2 }); + Assert.Equal(1, obj.x); + Assert.Equal(2, obj.y); + } + } +} diff --git a/src/libraries/System.Reflection/tests/ConstructorInfoTests.cs b/src/libraries/System.Reflection/tests/ConstructorInfoTests.cs index 34d04906a8055b..37498347ac48fe 100644 --- a/src/libraries/System.Reflection/tests/ConstructorInfoTests.cs +++ b/src/libraries/System.Reflection/tests/ConstructorInfoTests.cs @@ -2,15 +2,29 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Linq; using Xunit; #pragma warning disable 0414 namespace System.Reflection.Tests { - public class ConstructorInfoTests + /// + /// These tests use the shared tests from the base class with ConstructorInfo.Invoke. + /// + public sealed class ConstructorInfoTests : ConstructorCommonTests { + public override object Invoke(ConstructorInfo constructorInfo, object?[]? parameters) + { + return constructorInfo.Invoke(parameters); + } + + public override object? Invoke(ConstructorInfo constructorInfo, object obj, object?[]? parameters) + { + return constructorInfo.Invoke(obj, parameters); + } + + protected override bool IsExceptionWrapped => true; + [Fact] public void ConstructorName() { @@ -50,15 +64,6 @@ public void GetHashCodeTest() } } - [Fact] - public void Invoke() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); - Assert.Equal(3, constructors.Length); - ClassWith3Constructors obj = (ClassWith3Constructors)constructors[0].Invoke(null); - Assert.NotNull(obj); - } - [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsInvokingStaticConstructorsSupported))] public void Invoke_StaticConstructor_NullObject_NullParameters() { @@ -88,44 +93,6 @@ public void Invoke_StaticConstructorMultipleTimes() Assert.Equal(1, ClassWithStaticConstructorThatIsCalledMultipleTimesViaReflection.VisibleStatics.s_cctorCallCount); } - [Fact] - [ActiveIssue("https://github.com/mono/mono/issues/15024", TestRuntimes.Mono)] - public void Invoke_StaticConstructor_ThrowsMemberAccessException() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWithStaticConstructor)); - Assert.Equal(1, constructors.Length); - Assert.Throws(() => constructors[0].Invoke(new object[0])); - } - - [Fact] - public void Invoke_OneDimensionalArray() - { - ConstructorInfo[] constructors = GetConstructors(typeof(object[])); - int[] arraylength = { 1, 2, 99, 65535 }; - - // Try to invoke Array ctors with different lengths - foreach (int length in arraylength) - { - // Create big Array with elements - object[] arr = (object[])constructors[0].Invoke(new object[] { length }); - Assert.Equal(arr.Length, length); - } - } - - [Fact] - public void Invoke_OneDimensionalArray_NegativeLengths_ThrowsOverflowException() - { - ConstructorInfo[] constructors = GetConstructors(typeof(object[])); - int[] arraylength = new int[] { -1, -2, -99 }; - // Try to invoke Array ctors with different lengths - foreach (int length in arraylength) - { - // Create big Array with elements - Exception ex = Assert.Throws(() => constructors[0].Invoke(new object[] { length })); - Assert.IsType(ex.InnerException); - } - } - [Fact] [ActiveIssue("https://github.com/dotnet/runtime/issues/67531", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))] public void Invoke_TwoDimensionalArray_CustomBinder_IncorrectTypeArguments() @@ -138,23 +105,6 @@ public void Invoke_TwoDimensionalArray_CustomBinder_IncorrectTypeArguments() Assert.True(args[1] is int); } - [Fact] - public void Invoke_OneParameter() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); - ClassWith3Constructors obj = (ClassWith3Constructors)constructors[1].Invoke(new object[] { 100 }); - Assert.Equal(100, obj.intValue); - } - - [Fact] - public void Invoke_TwoParameters() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); - ClassWith3Constructors obj = (ClassWith3Constructors)constructors[2].Invoke(new object[] { 101, "hello" }); - Assert.Equal(101, obj.intValue); - Assert.Equal("hello", obj.stringValue); - } - [Fact] [ActiveIssue("https://github.com/dotnet/runtime/issues/67531", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))] public void Invoke_TwoParameters_CustomBinder_IncorrectTypeArgument() @@ -169,66 +119,6 @@ public void Invoke_TwoParameters_CustomBinder_IncorrectTypeArgument() Assert.True(args[1] is string); } - [Fact] - public void Invoke_NoParameters_ThowsTargetParameterCountException() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); - Assert.Throws(() => constructors[2].Invoke(new object[0])); - } - - [Fact] - public void Invoke_ParameterMismatch_ThrowsTargetParameterCountException() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); - Assert.Throws(() => (ClassWith3Constructors)constructors[2].Invoke(new object[] { 121 })); - } - - [Fact] - public void Invoke_ParameterWrongType_ThrowsArgumentException() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); - AssertExtensions.Throws(null, () => (ClassWith3Constructors)constructors[1].Invoke(new object[] { "hello" })); - } - - [Fact] - public void Invoke_ExistingInstance() - { - // Should not produce a second object. - ConstructorInfo[] constructors = GetConstructors(typeof(ClassWith3Constructors)); - ClassWith3Constructors obj1 = new ClassWith3Constructors(100, "hello"); - ClassWith3Constructors obj2 = (ClassWith3Constructors)constructors[2].Invoke(obj1, new object[] { 999, "initialized" }); - Assert.Null(obj2); - Assert.Equal(999, obj1.intValue); - Assert.Equal("initialized", obj1.stringValue); - } - - [Fact] - [ActiveIssue("https://github.com/mono/mono/issues/15026", TestRuntimes.Mono)] - public void Invoke_AbstractClass_ThrowsMemberAccessException() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ConstructorInfoAbstractBase)); - Assert.Throws(() => (ConstructorInfoAbstractBase)constructors[0].Invoke(new object[0])); - } - - [Fact] - public void Invoke_SubClass() - { - ConstructorInfo[] constructors = GetConstructors(typeof(ConstructorInfoDerived)); - ConstructorInfoDerived obj = null; - obj = (ConstructorInfoDerived)constructors[0].Invoke(new object[] { }); - Assert.NotNull(obj); - } - - [Fact] - public void Invoke_Struct() - { - ConstructorInfo[] constructors = GetConstructors(typeof(StructWith1Constructor)); - StructWith1Constructor obj; - obj = (StructWith1Constructor)constructors[0].Invoke(new object[] { 1, 2 }); - Assert.Equal(1, obj.x); - Assert.Equal(2, obj.y); - } - [Fact] public void IsConstructor_ReturnsTrue() { @@ -243,9 +133,18 @@ public void IsPublic() Assert.True(constructors[0].IsPublic); } - public static ConstructorInfo[] GetConstructors(Type type) + // Use this class only from the Invoke_StaticConstructorMultipleTimes method + public static class ClassWithStaticConstructorThatIsCalledMultipleTimesViaReflection { - return type.GetTypeInfo().DeclaredConstructors.ToArray(); + public static class VisibleStatics + { + public static int s_cctorCallCount; + } + + static ClassWithStaticConstructorThatIsCalledMultipleTimesViaReflection() + { + VisibleStatics.s_cctorCallCount++; + } } } @@ -281,20 +180,6 @@ public static class ClassWithStaticConstructor static ClassWithStaticConstructor() { } } - // Use this class only from the Invoke_StaticConstructorMultipleTimes method - public static class ClassWithStaticConstructorThatIsCalledMultipleTimesViaReflection - { - public static class VisibleStatics - { - public static int s_cctorCallCount; - } - - static ClassWithStaticConstructorThatIsCalledMultipleTimesViaReflection() - { - VisibleStatics.s_cctorCallCount++; - } - } - public struct StructWith1Constructor { public int x; diff --git a/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs b/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs index 86fa1c4012c270..f5313bde579844 100644 --- a/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs +++ b/src/libraries/System.Reflection/tests/ConstructorInvokerTests.cs @@ -1,13 +1,26 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Runtime.CompilerServices; using Xunit; namespace System.Reflection.Tests { - public class ConstructorInvokerTests + /// + /// These tests use the shared tests from the base class with ConstructorInvoker.Invoke. + /// + public sealed class ConstructorInvokerTests : ConstructorCommonTests { + public override object Invoke(ConstructorInfo constructorInfo, object?[]? parameters) + { + return ConstructorInvoker.Create(constructorInfo).Invoke(new Span(parameters)); + } + + public override object? Invoke(ConstructorInfo constructorInfo, object obj, object?[]? parameters) + { + return MethodInvoker.Create(constructorInfo).Invoke(obj, new Span(parameters)); + } + + protected override bool IsExceptionWrapped => false; [Fact] public void Args_0() @@ -162,16 +175,13 @@ public void ThrowsNonWrappedException_5() } [Fact] - public void ExistingInstance() + public void Invoke_StaticConstructor_NullObject_NullParameters() { - ConstructorInfo ci = typeof(TestClass).GetConstructor(BindingFlags.Public | BindingFlags.Instance, Type.EmptyTypes); - TestClass tc = (TestClass)RuntimeHelpers.GetUninitializedObject(typeof(TestClass)); - Assert.Null(tc._args); + ConstructorInfo[] constructors = GetConstructors(typeof(ClassWithStaticConstructor)); + Assert.Equal(1, constructors.Length); - MethodInvoker invoker = MethodInvoker.Create(ci); - object? obj = invoker.Invoke(tc); - Assert.Equal("0", tc._args); - Assert.Null(obj); + // Invoker classes do not support calling class constructors; use standard reflection for that. + Assert.Throws(() => Invoke(constructors[0], null, new object[] { })); } private class TestClass diff --git a/src/libraries/System.Reflection/tests/InvokeEmit/System.Reflection.InvokeEmit.Tests.csproj b/src/libraries/System.Reflection/tests/InvokeEmit/System.Reflection.InvokeEmit.Tests.csproj index 7b1e61038ccb5d..757b631140f2d6 100644 --- a/src/libraries/System.Reflection/tests/InvokeEmit/System.Reflection.InvokeEmit.Tests.csproj +++ b/src/libraries/System.Reflection/tests/InvokeEmit/System.Reflection.InvokeEmit.Tests.csproj @@ -7,8 +7,12 @@ + + + + diff --git a/src/libraries/System.Reflection/tests/InvokeInterpreted/System.Reflection.InvokeInterpreted.Tests.csproj b/src/libraries/System.Reflection/tests/InvokeInterpreted/System.Reflection.InvokeInterpreted.Tests.csproj index 64dc87d71c0864..5d71379e2d5ca9 100644 --- a/src/libraries/System.Reflection/tests/InvokeInterpreted/System.Reflection.InvokeInterpreted.Tests.csproj +++ b/src/libraries/System.Reflection/tests/InvokeInterpreted/System.Reflection.InvokeInterpreted.Tests.csproj @@ -7,8 +7,12 @@ + + + + diff --git a/src/libraries/System.Reflection/tests/MethodCommonTests.cs b/src/libraries/System.Reflection/tests/MethodCommonTests.cs new file mode 100644 index 00000000000000..b2e8e2f84fbc27 --- /dev/null +++ b/src/libraries/System.Reflection/tests/MethodCommonTests.cs @@ -0,0 +1,297 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Xunit; + +namespace System.Reflection.Tests +{ + /// + /// These tests are shared with MethodInfo.Invoke and MethodInvoker.Invoke by using + /// the abstract Invoke(...) method below. + /// + public abstract class MethodCommonTests + { + public abstract object? Invoke(MethodInfo methodInfo, object? obj, object?[]? parameters); + + protected abstract bool SupportsMissing { get; } + + protected static MethodInfo GetMethod(Type type, string name) + { + return type.GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance).First(method => method.Name.Equals(name)); + } + + [Fact] + public void InvokeNullableRefs() + { + object?[] args; + + int? iNull = null; + args = new object[] { iNull }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.Null)), null, args)); + Assert.Null(args[0]); + Assert.False(((int?)args[0]).HasValue); + + args = new object[] { iNull }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.NullBoxed)), null, args)); + Assert.Null(args[0]); + + args = new object[] { iNull, 10 }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.NullToValue)), null, args)); + Assert.IsType(args[0]); + Assert.Equal(10, (int)args[0]); + + iNull = 42; + args = new object[] { iNull, 42 }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.ValueToNull)), null, args)); + Assert.Null(args[0]); + + iNull = null; + args = new object[] { iNull, 10 }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.NullToValueBoxed)), null, args)); + Assert.IsType(args[0]); + Assert.Equal(10, (int)args[0]); + + static MethodInfo GetMethod(string name) => typeof(NullableRefMethods).GetMethod( + name, BindingFlags.Public | BindingFlags.Static)!; + } + + [Fact] + public void InvokeBoxedNullableRefs() + { + object?[] args; + + object? iNull = null; + args = new object[] { iNull }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.Null)), null, args)); + Assert.Null(args[0]); + + args = new object[] { iNull }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.NullBoxed)), null, args)); + Assert.Null(args[0]); + + args = new object[] { iNull, 10 }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.NullToValue)), null, args)); + Assert.IsType(args[0]); + Assert.Equal(10, (int)args[0]); + + iNull = 42; + args = new object[] { iNull, 42 }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.ValueToNull)), null, args)); + Assert.Null(args[0]); + + iNull = null; + args = new object[] { iNull, 10 }; + Assert.True((bool)Invoke(GetMethod(nameof(NullableRefMethods.NullToValueBoxed)), null, args)); + Assert.IsType(args[0]); + Assert.Equal(10, (int)args[0]); + + static MethodInfo GetMethod(string name) => typeof(NullableRefMethods).GetMethod( + name, BindingFlags.Public | BindingFlags.Static)!; + } + + [Fact] + public void InvokeEnum() + { + // Enums only need to match by primitive type. + Assert.True((bool)GetMethod(nameof(EnumMethods.PassColorsInt)). + Invoke(null, new object[] { OtherColorsInt.Red })); + + // Widening allowed + Assert.True((bool)GetMethod(nameof(EnumMethods.PassColorsInt)). + Invoke(null, new object[] { ColorsShort.Red })); + + // Narrowing not allowed + Assert.Throws(() => GetMethod(nameof(EnumMethods.PassColorsShort)). + Invoke(null, new object[] { OtherColorsInt.Red })); + + static MethodInfo GetMethod(string name) => typeof(EnumMethods).GetMethod( + name, BindingFlags.Public | BindingFlags.Static)!; + } + + [Fact] + public void InvokeNullableEnumParameterDefaultNo() + { + MethodInfo method = typeof(EnumMethods).GetMethod("NullableEnumDefaultNo", BindingFlags.Static | BindingFlags.NonPublic); + + Assert.Null(Invoke(method, null, new object?[] { default(object) })); + Assert.Equal(YesNo.No, Invoke(method, null, new object?[] { YesNo.No })); + Assert.Equal(YesNo.Yes, Invoke(method, null, new object?[] { YesNo.Yes })); + + if (SupportsMissing) + { + Assert.Equal(YesNo.No, Invoke(method, null, new object?[] { Type.Missing })); + } + } + + [Fact] + public void InvokeNullableEnumParameterDefaultYes() + { + MethodInfo method = typeof(EnumMethods).GetMethod("NullableEnumDefaultYes", BindingFlags.Static | BindingFlags.NonPublic); + + Assert.Null(Invoke(method, null, new object?[] { default(object) })); + Assert.Equal(YesNo.No, Invoke(method, null, new object?[] { YesNo.No })); + Assert.Equal(YesNo.Yes, Invoke(method, null, new object?[] { YesNo.Yes })); + + if (SupportsMissing) + { + Assert.Equal(YesNo.Yes, Invoke(method, null, new object?[] { Type.Missing })); + } + } + + [Fact] + public void InvokeNonNullableEnumParameterDefaultYes() + { + MethodInfo method = typeof(EnumMethods).GetMethod("NonNullableEnumDefaultYes", BindingFlags.Static | BindingFlags.NonPublic); + + Assert.Equal(YesNo.No, Invoke(method, null, new object[] { default(object) })); + Assert.Equal(YesNo.No, Invoke(method, null, new object[] { YesNo.No })); + Assert.Equal(YesNo.Yes, Invoke(method, null, new object[] { YesNo.Yes })); + + if (SupportsMissing) + { + Assert.Equal(YesNo.Yes, Invoke(method, null, new object[] { Type.Missing })); + } + } + + [Fact] + public void InvokeNullableEnumParameterDefaultNull() + { + MethodInfo method = typeof(EnumMethods).GetMethod("NullableEnumDefaultNull", BindingFlags.Static | BindingFlags.NonPublic); + + Assert.Null(Invoke(method, null, new object?[] { default(object) })); + Assert.Equal(YesNo.No, Invoke(method, null, new object?[] { YesNo.No })); + Assert.Equal(YesNo.Yes, Invoke(method, null, new object?[] { YesNo.Yes })); + + if (SupportsMissing) + { + Assert.Null(Invoke(method, null, new object?[] { Type.Missing })); + } + } + + [Fact] + public void ValueTypeMembers_WithOverrides() + { + ValueTypeWithOverrides obj = new() { Id = 1 }; + + // ToString is overridden. + Assert.Equal("Hello", (string)Invoke(GetMethod(typeof(ValueTypeWithOverrides), nameof(ValueTypeWithOverrides.ToString)), + obj, null)); + + // Ensure a normal method works. + Assert.Equal(1, (int)Invoke(GetMethod(typeof(ValueTypeWithOverrides), nameof(ValueTypeWithOverrides.GetId)), + obj, null)); + } + + [Fact] + public void ValueTypeMembers_WithoutOverrides() + { + ValueTypeWithoutOverrides obj = new() { Id = 1 }; + + // ToString is not overridden. + Assert.Equal(typeof(ValueTypeWithoutOverrides).ToString(), (string) Invoke(GetMethod(typeof(ValueTypeWithoutOverrides), nameof(ValueTypeWithoutOverrides.ToString)), + obj, null)); + + // Ensure a normal method works. + Assert.Equal(1, (int)Invoke(GetMethod(typeof(ValueTypeWithoutOverrides), nameof(ValueTypeWithoutOverrides.GetId)), + obj, null)); + } + + [Fact] + public void NullableOfTMembers() + { + // Ensure calling a method on Nullable works. + MethodInfo mi = GetMethod(typeof(int?), nameof(Nullable.GetValueOrDefault)); + Assert.Equal(42, Invoke(mi, 42, null)); + } + + [Fact] + public void CopyBackWithByRefArgs() + { + object i = 42; + object[] args = new object[] { i }; + Invoke(GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.IncrementByRef)), null, args); + Assert.Equal(43, (int)args[0]); + Assert.NotSame(i, args[0]); // A copy should be made; a boxed instance should never be directly updated. + + i = 42; + args = new object[] { i }; + Invoke(GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.IncrementByNullableRef)), null, args); + Assert.Equal(43, (int)args[0]); + Assert.NotSame(i, args[0]); + + object o = null; + args = new object[] { o }; + Invoke(GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.SetToNonNullByRef)), null, args); + Assert.NotNull(args[0]); + + o = new object(); + args = new object[] { o }; + Invoke(GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.SetToNullByRef)), null, args); + Assert.Null(args[0]); + } + + [Fact] + public unsafe void TestFunctionPointerDirect() + { + // Sanity checks for direct invocation. + void* fn = FunctionPointerMethods.GetFunctionPointer(); + Assert.True(FunctionPointerMethods.GetFunctionPointer()(42)); + Assert.True(FunctionPointerMethods.CallFcnPtr_IntPtr((IntPtr)fn, 42)); + Assert.True(FunctionPointerMethods.CallFcnPtr_Void(fn, 42)); + Assert.False(FunctionPointerMethods.GetFunctionPointer()(41)); + Assert.False(FunctionPointerMethods.CallFcnPtr_IntPtr((IntPtr)fn, 41)); + Assert.False(FunctionPointerMethods.CallFcnPtr_Void(fn, 41)); + } + + [Fact] + public unsafe void TestFunctionPointerAsIntPtrArgType() + { + void* fn = FunctionPointerMethods.GetFunctionPointer(); + + MethodInfo m; + + m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_IntPtr)); + Assert.True((bool)Invoke(m, null, new object[] { (IntPtr)fn, 42 })); + Assert.False((bool)Invoke(m, null, new object[] { (IntPtr)fn, 41 })); + + m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_Void)); + Assert.True((bool)Invoke(m, null, new object[] { (IntPtr)fn, 42 })); + Assert.False((bool)Invoke(m, null, new object[] { (IntPtr)fn, 41 })); + } + + [Fact] + public unsafe void TestFunctionPointerAsUIntPtrArgType() + { + void* fn = FunctionPointerMethods.GetFunctionPointer(); + + MethodInfo m; + + m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_UIntPtr)); + Assert.True((bool)Invoke(m, null, new object[] { (UIntPtr)fn, 42 })); + Assert.False((bool)Invoke(m, null, new object[] { (UIntPtr)fn, 41 })); + + m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_Void)); + Assert.True((bool)Invoke(m, null, new object[] { (UIntPtr)fn, 42 })); + Assert.False((bool)Invoke(m, null, new object[] { (UIntPtr)fn, 41 })); + } + + [Fact] + public unsafe void TestFunctionPointerAsArgType() + { + void* fn = FunctionPointerMethods.GetFunctionPointer(); + MethodInfo m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_FP)); + Assert.True((bool)Invoke(m, null, new object[] { (IntPtr)fn, 42 })); + Assert.False((bool)Invoke(m, null, new object[] { (IntPtr)fn, 41 })); + } + + [Fact] + public unsafe void TestFunctionPointerAsReturnType() + { + MethodInfo m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.GetFunctionPointer)); + object ret = Invoke(m, null, null); + Assert.IsType(ret); + Assert.True((IntPtr)ret != 0); + } + } +} diff --git a/src/libraries/System.Reflection/tests/MethodInfoTests.cs b/src/libraries/System.Reflection/tests/MethodInfoTests.cs index 769fc94266a083..341fb7fcb9de11 100644 --- a/src/libraries/System.Reflection/tests/MethodInfoTests.cs +++ b/src/libraries/System.Reflection/tests/MethodInfoTests.cs @@ -6,12 +6,21 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using Xunit; -using Xunit.Sdk; namespace System.Reflection.Tests { - public class MethodInfoTests + /// + /// These tests use the shared tests from the base class with MethodInfo.Invoke. + /// + public sealed class MethodInfoTests : MethodCommonTests { + public override object? Invoke(MethodInfo methodInfo, object? obj, object?[]? parameters) + { + return methodInfo.Invoke(obj, parameters); + } + + protected override bool SupportsMissing => false; + [Fact] public void CreateDelegate_PublicMethod() { @@ -361,7 +370,7 @@ public static IEnumerable Invoke_TestData() [Theory] [MemberData(nameof(Invoke_TestData))] - public void Invoke(Type methodDeclaringType, string methodName, object obj, object[] parameters, object result) + public void InvokeWithTestData(Type methodDeclaringType, string methodName, object obj, object[] parameters, object result) { MethodInfo method = GetMethod(methodDeclaringType, methodName); Assert.Equal(result, method.Invoke(obj, parameters)); @@ -370,8 +379,8 @@ public void Invoke(Type methodDeclaringType, string methodName, object obj, obje [Fact] public void Invoke_ParameterSpecification_ArrayOfMissing() { - Invoke(typeof(MethodInfoDefaultParameters), "OptionalObjectParameter", new MethodInfoDefaultParameters(), new object[] { Type.Missing }, Type.Missing); - Invoke(typeof(MethodInfoDefaultParameters), "OptionalObjectParameter", new MethodInfoDefaultParameters(), new Missing[] { Missing.Value }, Missing.Value); + InvokeWithTestData(typeof(MethodInfoDefaultParameters), "OptionalObjectParameter", new MethodInfoDefaultParameters(), new object[] { Type.Missing }, Type.Missing); + InvokeWithTestData(typeof(MethodInfoDefaultParameters), "OptionalObjectParameter", new MethodInfoDefaultParameters(), new Missing[] { Missing.Value }, Missing.Value); } [Fact] @@ -620,149 +629,6 @@ public void ToStringTest_ByMethodInfo(MethodInfo methodInfo, string expected) Assert.Equal(expected, methodInfo.ToString()); } - [Fact] - public void InvokeNullableRefs() - { - object?[] args; - - int? iNull = null; - args = new object[] { iNull }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.Null)).Invoke(null, args)); - Assert.Null(args[0]); - Assert.False(((int?)args[0]).HasValue); - - args = new object[] { iNull }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.NullBoxed)).Invoke(null, args)); - Assert.Null(args[0]); - - args = new object[] { iNull, 10 }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.NullToValue)).Invoke(null, args)); - Assert.IsType(args[0]); - Assert.Equal(10, (int)args[0]); - - iNull = 42; - args = new object[] { iNull, 42 }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.ValueToNull)).Invoke(null, args)); - Assert.Null(args[0]); - - iNull = null; - args = new object[] { iNull, 10 }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.NullToValueBoxed)).Invoke(null, args)); - Assert.IsType(args[0]); - Assert.Equal(10, (int)args[0]); - - static MethodInfo GetMethod(string name) => typeof(NullableRefMethods).GetMethod( - name, BindingFlags.Public | BindingFlags.Static)!; - } - - [Fact] - public void InvokeBoxedNullableRefs() - { - object?[] args; - - object? iNull = null; - args = new object[] { iNull }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.Null)).Invoke(null, args)); - Assert.Null(args[0]); - - args = new object[] { iNull }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.NullBoxed)).Invoke(null, args)); - Assert.Null(args[0]); - - args = new object[] { iNull, 10 }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.NullToValue)).Invoke(null, args)); - Assert.IsType(args[0]); - Assert.Equal(10, (int)args[0]); - - iNull = 42; - args = new object[] { iNull, 42 }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.ValueToNull)).Invoke(null, args)); - Assert.Null(args[0]); - - iNull = null; - args = new object[] { iNull, 10 }; - Assert.True((bool)GetMethod(nameof(NullableRefMethods.NullToValueBoxed)).Invoke(null, args)); - Assert.IsType(args[0]); - Assert.Equal(10, (int)args[0]); - - static MethodInfo GetMethod(string name) => typeof(NullableRefMethods).GetMethod( - name, BindingFlags.Public | BindingFlags.Static)!; - } - - [Fact] - public void InvokeEnum() - { - // Enums only need to match by primitive type. - Assert.True((bool)GetMethod(nameof(EnumMethods.PassColorsInt)). - Invoke(null, new object[] { OtherColorsInt.Red })); - - // Widening allowed - Assert.True((bool)GetMethod(nameof(EnumMethods.PassColorsInt)). - Invoke(null, new object[] { ColorsShort.Red })); - - // Narrowing not allowed - Assert.Throws(() => GetMethod(nameof(EnumMethods.PassColorsShort)). - Invoke(null, new object[] { OtherColorsInt.Red })); - - static MethodInfo GetMethod(string name) => typeof(EnumMethods).GetMethod( - name, BindingFlags.Public | BindingFlags.Static)!; - } - - [Fact] - public static void InvokeNullableEnumParameterDefaultNo() - { - MethodInfo method = typeof(EnumMethods).GetMethod("NullableEnumDefaultNo", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.Null(method.Invoke(null, new object?[] { default(object) })); - Assert.Equal(YesNo.No, method.Invoke(null, new object?[] { YesNo.No })); - Assert.Equal(YesNo.Yes, method.Invoke(null, new object?[] { YesNo.Yes })); - Assert.Equal(YesNo.No, method.Invoke(null, new object?[] { Type.Missing })); - } - - [Fact] - public static void InvokeNullableEnumParameterDefaultYes() - { - MethodInfo method = typeof(EnumMethods).GetMethod("NullableEnumDefaultYes", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.Null(method.Invoke(null, new object?[] { default(object) })); - Assert.Equal(YesNo.No, method.Invoke(null, new object?[] { YesNo.No })); - Assert.Equal(YesNo.Yes, method.Invoke(null, new object?[] { YesNo.Yes })); - Assert.Equal(YesNo.Yes, method.Invoke(null, new object?[] { Type.Missing })); - } - - [Fact] - public static void InvokeNonNullableEnumParameterDefaultYes() - { - MethodInfo method = typeof(EnumMethods).GetMethod("NonNullableEnumDefaultYes", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.Equal(YesNo.No, method.Invoke(null, new object[] { default(object) })); - Assert.Equal(YesNo.No, method.Invoke(null, new object[] { YesNo.No })); - Assert.Equal(YesNo.Yes, method.Invoke(null, new object[] { YesNo.Yes })); - Assert.Equal(YesNo.Yes, method.Invoke(null, new object[] { Type.Missing })); - } - - [Fact] - public static void InvokeNullableEnumParameterDefaultNull() - { - MethodInfo method = typeof(EnumMethods).GetMethod("NullableEnumDefaultNull", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.Null(method.Invoke(null, new object?[] { default(object) })); - Assert.Equal(YesNo.No, method.Invoke(null, new object?[] { YesNo.No })); - Assert.Equal(YesNo.Yes, method.Invoke(null, new object?[] { YesNo.Yes })); - Assert.Null(method.Invoke(null, new object?[] { Type.Missing })); - } - - [Fact] - public static void InvokeNullableEnumParameterNoDefault() - { - MethodInfo method = typeof(EnumMethods).GetMethod("NullableEnumNoDefault", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.Null(method.Invoke(null, new object?[] { default(object) })); - Assert.Equal(YesNo.No, method.Invoke(null, new object?[] { YesNo.No })); - Assert.Equal(YesNo.Yes, method.Invoke(null, new object?[] { YesNo.Yes })); - Assert.Throws(() => method.Invoke(null, new object?[] { Type.Missing })); - } - public static IEnumerable MethodNameAndArguments() { yield return new object[] { nameof(Sample.DefaultString), "Hello", "Hi" }; @@ -805,68 +671,6 @@ public static void InvokeCopiesBackMissingParameterAndArgument() Assert.Null(args[0]); } - [Fact] - public void ValueTypeMembers_WithOverrides() - { - ValueTypeWithOverrides obj = new() { Id = 1 }; - - // ToString is overridden. - Assert.Equal("Hello", (string)GetMethod(typeof(ValueTypeWithOverrides), nameof(ValueTypeWithOverrides.ToString)). - Invoke(obj, null)); - - // Ensure a normal method works. - Assert.Equal(1, (int)GetMethod(typeof(ValueTypeWithOverrides), nameof(ValueTypeWithOverrides.GetId)). - Invoke(obj, null)); - } - - [Fact] - public void ValueTypeMembers_WithoutOverrides() - { - ValueTypeWithoutOverrides obj = new() { Id = 1 }; - - // ToString is not overridden. - Assert.Equal(typeof(ValueTypeWithoutOverrides).ToString(), (string)GetMethod(typeof(ValueTypeWithoutOverrides), nameof(ValueTypeWithoutOverrides.ToString)). - Invoke(obj, null)); - - // Ensure a normal method works. - Assert.Equal(1, (int)GetMethod(typeof(ValueTypeWithoutOverrides), nameof(ValueTypeWithoutOverrides.GetId)). - Invoke(obj, null)); - } - - [Fact] - public void NullableOfTMembers() - { - // Ensure calling a method on Nullable works. - MethodInfo mi = GetMethod(typeof(int?), nameof(Nullable.GetValueOrDefault)); - Assert.Equal(42, mi.Invoke(42, null)); - } - - [Fact] - public void CopyBackWithByRefArgs() - { - object i = 42; - object[] args = new object[] { i }; - GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.IncrementByRef)).Invoke(null, args); - Assert.Equal(43, (int)args[0]); - Assert.NotSame(i, args[0]); // A copy should be made; a boxed instance should never be directly updated. - - i = 42; - args = new object[] { i }; - GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.IncrementByNullableRef)).Invoke(null, args); - Assert.Equal(43, (int)args[0]); - Assert.NotSame(i, args[0]); - - object o = null; - args = new object[] { o }; - GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.SetToNonNullByRef)).Invoke(null, args); - Assert.NotNull(args[0]); - - o = new object(); - args = new object[] { o }; - GetMethod(typeof(CopyBackMethods), nameof(CopyBackMethods.SetToNullByRef)).Invoke(null, args); - Assert.Null(args[0]); - } - [Fact] [ActiveIssue("https://github.com/dotnet/runtime/issues/50957", typeof(PlatformDetection), nameof(PlatformDetection.IsMonoInterpreter))] [ActiveIssue("https://github.com/dotnet/runtime/issues/69919", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))] @@ -894,69 +698,6 @@ private static void SecondCall(MethodInfo mi) Assert.Contains("TestAssembly", asm.ToString()); } - [Fact] - private static unsafe void TestFunctionPointerDirect() - { - // Sanity checks for direct invocation. - void* fn = FunctionPointerMethods.GetFunctionPointer(); - Assert.True(FunctionPointerMethods.GetFunctionPointer()(42)); - Assert.True(FunctionPointerMethods.CallFcnPtr_IntPtr((IntPtr)fn, 42)); - Assert.True(FunctionPointerMethods.CallFcnPtr_Void(fn, 42)); - Assert.False(FunctionPointerMethods.GetFunctionPointer()(41)); - Assert.False(FunctionPointerMethods.CallFcnPtr_IntPtr((IntPtr)fn, 41)); - Assert.False(FunctionPointerMethods.CallFcnPtr_Void(fn, 41)); - } - - [Fact] - private static unsafe void TestFunctionPointerAsIntPtrArgType() - { - void* fn = FunctionPointerMethods.GetFunctionPointer(); - - MethodInfo m; - - m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_IntPtr)); - Assert.True((bool)m.Invoke(null, new object[] { (IntPtr)fn, 42 })); - Assert.False((bool)m.Invoke(null, new object[] { (IntPtr)fn, 41 })); - - m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_Void)); - Assert.True((bool)m.Invoke(null, new object[] { (IntPtr)fn, 42 })); - Assert.False((bool)m.Invoke(null, new object[] { (IntPtr)fn, 41 })); - } - - [Fact] - private static unsafe void TestFunctionPointerAsUIntPtrArgType() - { - void* fn = FunctionPointerMethods.GetFunctionPointer(); - - MethodInfo m; - - m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_UIntPtr)); - Assert.True((bool)m.Invoke(null, new object[] { (UIntPtr)fn, 42 })); - Assert.False((bool)m.Invoke(null, new object[] { (UIntPtr)fn, 41 })); - - m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_Void)); - Assert.True((bool)m.Invoke(null, new object[] { (UIntPtr)fn, 42 })); - Assert.False((bool)m.Invoke(null, new object[] { (UIntPtr)fn, 41 })); - } - - [Fact] - private static unsafe void TestFunctionPointerAsArgType() - { - void* fn = FunctionPointerMethods.GetFunctionPointer(); - MethodInfo m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.CallFcnPtr_FP)); - Assert.True((bool)m.Invoke(null, new object[] { (IntPtr)fn, 42 })); - Assert.False((bool)m.Invoke(null, new object[] { (IntPtr)fn, 41 })); - } - - [Fact] - private static unsafe void TestFunctionPointerAsReturnType() - { - MethodInfo m = GetMethod(typeof(FunctionPointerMethods), nameof(FunctionPointerMethods.GetFunctionPointer)); - object ret = m.Invoke(null, null); - Assert.IsType(ret); - Assert.True((IntPtr)ret != 0); - } - //Methods for Reflection Metadata private void DummyMethod1(string str, int iValue, long lValue) { @@ -965,11 +706,6 @@ private void DummyMethod1(string str, int iValue, long lValue) private void DummyMethod2() { } - - private static MethodInfo GetMethod(Type type, string name) - { - return type.GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance).First(method => method.Name.Equals(name)); - } } #pragma warning disable 0414 diff --git a/src/libraries/System.Reflection/tests/MethodInvokerTests.cs b/src/libraries/System.Reflection/tests/MethodInvokerTests.cs index 97c9865a64fdaf..94d701507eac1d 100644 --- a/src/libraries/System.Reflection/tests/MethodInvokerTests.cs +++ b/src/libraries/System.Reflection/tests/MethodInvokerTests.cs @@ -2,13 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; -using System.Linq; using Xunit; namespace System.Reflection.Tests { - public class MethodInvokerTests + /// + /// These tests use the shared tests from the base class with MethodInvoker.Invoke. + /// + public class MethodInvokerTests : MethodCommonTests { + public override object? Invoke(MethodInfo methodInfo, object? obj, object?[]? parameters) + { + return MethodInvoker.Create(methodInfo).Invoke(obj, new Span(parameters)); + } + + protected override bool SupportsMissing => false; + [Fact] public void NullTypeValidation() { @@ -286,11 +295,6 @@ public void VerifyThisObj_Null() Assert.Throws(() => invoker.Invoke(obj: null)); } - private static MethodInfo GetMethod(Type type, string name) - { - return type.GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance).First(method => method.Name.Equals(name)); - } - public static IEnumerable Invoke_TestData() => MethodInfoTests.Invoke_TestData(); private class TestClass diff --git a/src/libraries/System.Reflection/tests/System.Reflection.Tests.csproj b/src/libraries/System.Reflection/tests/System.Reflection.Tests.csproj index e6adafcbaac1d5..9553a7dacc4828 100644 --- a/src/libraries/System.Reflection/tests/System.Reflection.Tests.csproj +++ b/src/libraries/System.Reflection/tests/System.Reflection.Tests.csproj @@ -14,14 +14,12 @@ - - - + + + + @@ -31,6 +29,7 @@ + @@ -42,8 +41,7 @@ - + diff --git a/src/libraries/System.Runtime.Caching/src/PACKAGE.md b/src/libraries/System.Runtime.Caching/src/PACKAGE.md new file mode 100644 index 00000000000000..5b79f3f7cfb01c --- /dev/null +++ b/src/libraries/System.Runtime.Caching/src/PACKAGE.md @@ -0,0 +1,46 @@ +## About + + + +Packaged set of simple caching API's derived from those of the same namespace available in .NET Framework since 4.0. This package is intended for use as a bridge when porting .NET Framework applications to .NET. + +[Microsoft.Extensions.Caching.Memory](https://www.nuget.org/packages/Microsoft.Extensions.Caching.Memory/)/[IMemoryCache](https://learn.microsoft.com/aspnet/core/performance/caching/memory?view=aspnetcore-7.0) is recommended over `System.Runtime.Caching`/`MemoryCache` because it's better integrated into ASP.NET Core. For example, `IMemoryCache` works natively with ASP.NET Core [dependency injection](https://learn.microsoft.com/aspnet/core/fundamentals/dependency-injection?view=aspnetcore-7.0). + +**Use `System.Runtime.Caching`/`MemoryCache` as a compatibility bridge when porting code from .NET 4.x to .NET Core.** + +## Key Features + + + +* Use caching facilities like in ASP.NET, but without a dependency on the System.Web assembly. +* Extensible caching mechanism +* Possible to create custom caching providers + +## Main Types + + + +The main types provided by this library are: + +* `System.Runtime.Caching.MemoryCache` + +## Additional Documentation + + + +[MemoryCache.PhysicalMemoryLimit](https://learn.microsoft.com/dotnet/api/system.runtime.caching.memorycache.physicalmemorylimit?view=dotnet-plat-ext-7.0) property is only supported on windows. + +* [Caching in .NET](https://learn.microsoft.com/dotnet/core/extensions/caching) +* [Cache in-memory in ASP.NET Core](https://learn.microsoft.com/aspnet/core/performance/caching/memory?view=aspnetcore-7.0 ) + +## Related Packages + + + +* [Microsoft.Extensions.Caching.Memory](https://www.nuget.org/packages/Microsoft.Extensions.Caching.Memory/) + +## Feedback & Contributing + + + +System.Runtime.Caching is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs index 0a18241ef7c767..a00fc9e4024f76 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs @@ -59,7 +59,7 @@ public JSExportCodeGenerator( public BlockSyntax GenerateJSExportBody() { - StatementSyntax invoke = InvokeSyntax(); + List invoke = InvokeSyntax(); GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context); bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty; VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForUnmanagedToManaged(_marshallers, _context, shouldInitializeVariables); @@ -79,7 +79,7 @@ public BlockSyntax GenerateJSExportBody() var tryStatements = new List(); tryStatements.AddRange(statements.Unmarshal); - tryStatements.Add(invoke); + tryStatements.AddRange(invoke); if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { @@ -93,6 +93,18 @@ public BlockSyntax GenerateJSExportBody() tryStatements.AddRange(statements.Marshal); List allStatements = setupStatements; + + // Wrap unmarshall, invocation and return value marshalling in try-catch. + // In case of exception, marshal exception instead of return value. + var tryInvokeAndMarshal = TryStatement(SingletonList(CatchClause() + .WithDeclaration(CatchDeclaration(IdentifierName(Constants.ExceptionGlobal)).WithIdentifier(Identifier("ex"))) + .WithBlock(Block(SingletonList( + ExpressionStatement(InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(Constants.ArgumentException), IdentifierName(Constants.ToJSMethod))) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("ex"))))))))))) + .WithBlock(Block(tryStatements)); + List finallyStatements = new List(); if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { @@ -100,16 +112,14 @@ public BlockSyntax GenerateJSExportBody() } finallyStatements.AddRange(statements.CleanupCallerAllocated); + if (finallyStatements.Count > 0) { - allStatements.Add( - TryStatement(Block(tryStatements), default, FinallyClause(Block(finallyStatements)))); - } - else - { - allStatements.AddRange(tryStatements); + tryInvokeAndMarshal = TryStatement(Block(tryInvokeAndMarshal), default, FinallyClause(Block(finallyStatements))); } + allStatements.Add(tryInvokeAndMarshal); + return Block(allStatements); } @@ -175,7 +185,7 @@ private void SetupSyntax(List statementsToUpdate) Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1))))))))))))); } - private TryStatementSyntax InvokeSyntax() + private List InvokeSyntax() { var statements = new List(); var arguments = new List(); @@ -205,16 +215,8 @@ private TryStatementSyntax InvokeSyntax() IdentifierName(nativeIdentifier), invocation)); statements.Add(statement); - statements.AddRange(_marshallers.ManagedReturnMarshaller.Generator.Generate(_marshallers.ManagedReturnMarshaller.TypeInfo, _context with { CurrentStage = StubCodeContext.Stage.Marshal })); } - return TryStatement(SingletonList(CatchClause() - .WithDeclaration(CatchDeclaration(IdentifierName(Constants.ExceptionGlobal)).WithIdentifier(Identifier("ex"))) - .WithBlock(Block(SingletonList( - ExpressionStatement(InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(Constants.ArgumentException), IdentifierName(Constants.ToJSMethod))) - .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("ex"))))))))))) - .WithBlock(Block(statements)); + return statements; } diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Resources/Strings.resx b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Resources/Strings.resx index b06ebce2260316..1c6e47ef214f15 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Resources/Strings.resx +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/Resources/Strings.resx @@ -136,7 +136,7 @@ Specified type is not supported by source-generated JavaScript interop. - {0} The generated source will not handle marshalling of the return value of method '{1}'. + {0} The generated source will not handle marshalling of the return value of method '{1}'. For more information see https://aka.ms/dotnet-wasm-jsinterop {0} is a message containing additional details about what is not supported {1} is the name of the method @@ -144,21 +144,21 @@ Type is not supported by source-generated JavaScript interop. - The type '{0}' is not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter '{1}'. + The type '{0}' is not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter '{1}'. For more information see https://aka.ms/dotnet-wasm-jsinterop - The type '{0}' is not supported by source-generated JavaScript interop. The generated source will not handle marshalling of the return value of method '{1}'. + The type '{0}' is not supported by source-generated JavaScript interop. The generated source will not handle marshalling of the return value of method '{1}'. For more information see https://aka.ms/dotnet-wasm-jsinterop - {0} The generated source will not handle marshalling of parameter '{1}'. + {0} The generated source will not handle marshalling of parameter '{1}'. For more information see https://aka.ms/dotnet-wasm-jsinterop {0} is a message containing additional details about what is not supported {1} is the name of the parameter - The specified '{0}' configuration for the return value of method '{1}' is not supported by source-generated JavaScript interop. + The specified '{0}' configuration for the return value of method '{1}' is not supported by source-generated JavaScript interop. For more information see https://aka.ms/dotnet-wasm-jsinterop - The specified '{0}' configuration for parameter '{1}' is not supported by source-generated JavaScript interop. + The specified '{0}' configuration for parameter '{1}' is not supported by source-generated JavaScript interop. For more information see https://aka.ms/dotnet-wasm-jsinterop Invalid 'JSImportAttribute' usage @@ -167,10 +167,10 @@ Invalid 'JSExportAttribute' usage - Method '{0}' should be 'static', 'partial', and non-generic when marked with 'JSImportAttribute'. JavaScript interop source generation will ignore method '{0}'. + Method '{0}' should be 'static', 'partial', and non-generic when marked with 'JSImportAttribute'. JavaScript interop source generation will ignore method '{0}'. For more information see https://aka.ms/dotnet-wasm-jsinterop - Method '{0}' should be 'static', non-partial and non-generic when marked with 'JSExportAttribute'. JavaScript interop source generation will ignore method '{0}'. + Method '{0}' should be 'static', non-partial and non-generic when marked with 'JSExportAttribute'. JavaScript interop source generation will ignore method '{0}'. For more information see https://aka.ms/dotnet-wasm-jsinterop Methods marked with 'JSImportAttribute' should be 'static', 'partial', and non-generic. JavaScript interop source generation will ignore methods that are non-'static', non-'partial', or generic. @@ -179,7 +179,7 @@ Methods marked with 'JSImportAttribute' should be 'static', non-partial, and non-generic. JavaScript interop source generation will ignore methods that are non-'static', 'partial', or generic. - Method '{0}' is contained in a type '{1}' that is not marked 'partial'. JavaScript interop source generation will ignore method '{0}'. + Method '{0}' is contained in a type '{1}' that is not marked 'partial'. JavaScript interop source generation will ignore method '{0}'. For more information see https://aka.ms/dotnet-wasm-jsinterop 'JSType.Discard' could be only used with void return argument. @@ -212,18 +212,18 @@ JSImportAttribute requires unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. - JSImportAttribute requires unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. + JSImportAttribute requires unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. For more information see https://aka.ms/dotnet-wasm-jsinterop - JSImportAttribute requires unsafe code. + JSImportAttribute requires unsafe code. For more information see https://aka.ms/dotnet-wasm-jsinterop JSExportAttribute requires unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. - JSExportAttribute requires unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. + JSExportAttribute requires unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. For more information see https://aka.ms/dotnet-wasm-jsinterop - JSExportAttribute requires unsafe code. + JSExportAttribute requires unsafe code. For more information see https://aka.ms/dotnet-wasm-jsinterop \ No newline at end of file diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/tests/JSImportGenerator.UnitTest/Fails.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/tests/JSImportGenerator.UnitTest/Fails.cs index 513af4747a4e92..f5ffe77afb7949 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/tests/JSImportGenerator.UnitTest/Fails.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/tests/JSImportGenerator.UnitTest/Fails.cs @@ -15,39 +15,39 @@ public class Fails public static IEnumerable CodeSnippetsToFail() { yield return new object?[] { CodeSnippets.DefaultReturnMarshaler(), new string[] { - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of the return value of method 'Import1'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of the return value of method 'Export1'.", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of the return value of method 'Import1'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of the return value of method 'Export1'. For more information see https://aka.ms/dotnet-wasm-jsinterop", },null }; yield return new object?[] { CodeSnippets.DefaultReturnMarshaler(), null, null }; yield return new object?[] { CodeSnippets.DefaultReturnMarshaler("System.Func"), null, null }; yield return new object?[] { CodeSnippets.DefaultReturnMarshaler("System.Action"), new string[] { - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Action. The generated source will not handle marshalling of the return value of method 'Import1'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Action. The generated source will not handle marshalling of the return value of method 'Export1'.", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Action. The generated source will not handle marshalling of the return value of method 'Import1'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Action. The generated source will not handle marshalling of the return value of method 'Export1'. For more information see https://aka.ms/dotnet-wasm-jsinterop", },null }; yield return new object?[] { CodeSnippets.DefaultReturnMarshaler("System.Span"), null, null }; yield return new object?[] { CodeSnippets.DefaultReturnMarshaler("System.Span"), null, null }; yield return new object?[] { CodeSnippets.DefaultReturnMarshaler("System.ArraySegment"), null, null }; yield return new object?[] { CodeSnippets.AllMissing, new string[] { - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of object. The generated source will not handle marshalling of parameter 'a1'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of parameter 'a2'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of parameter 'a3'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Action. The generated source will not handle marshalling of parameter 'a4'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Func. The generated source will not handle marshalling of parameter 'a5'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Span. The generated source will not handle marshalling of parameter 'a6'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.ArraySegment. The generated source will not handle marshalling of parameter 'a7'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a8'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of object[]. The generated source will not handle marshalling of parameter 'a9'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.DateTime. The generated source will not handle marshalling of parameter 'a10'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.DateTimeOffset. The generated source will not handle marshalling of parameter 'a11'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a12'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a13'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a14'.", - "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a15'.", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of object. The generated source will not handle marshalling of parameter 'a1'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of parameter 'a2'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of long. The generated source will not handle marshalling of parameter 'a3'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Action. The generated source will not handle marshalling of parameter 'a4'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Func. The generated source will not handle marshalling of parameter 'a5'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Span. The generated source will not handle marshalling of parameter 'a6'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.ArraySegment. The generated source will not handle marshalling of parameter 'a7'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a8'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of object[]. The generated source will not handle marshalling of parameter 'a9'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.DateTime. The generated source will not handle marshalling of parameter 'a10'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.DateTimeOffset. The generated source will not handle marshalling of parameter 'a11'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a12'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a13'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a14'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Please annotate the argument with 'JSMarshalAsAttribute' to specify marshaling of global::System.Threading.Tasks.Task. The generated source will not handle marshalling of parameter 'a15'. For more information see https://aka.ms/dotnet-wasm-jsinterop", },null }; yield return new object?[] { CodeSnippets.InOutRef, new string[] { - "Parameters with 'in', 'out' and 'ref' modifiers are not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter 'a1'.", - "Parameters with 'in', 'out' and 'ref' modifiers are not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter 'a2'.", - "Parameters with 'in', 'out' and 'ref' modifiers are not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter 'a3'.", + "Parameters with 'in', 'out' and 'ref' modifiers are not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter 'a1'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Parameters with 'in', 'out' and 'ref' modifiers are not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter 'a2'. For more information see https://aka.ms/dotnet-wasm-jsinterop", + "Parameters with 'in', 'out' and 'ref' modifiers are not supported by source-generated JavaScript interop. The generated source will not handle marshalling of parameter 'a3'. For more information see https://aka.ms/dotnet-wasm-jsinterop", }, null }; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/DiagnosticDescriptorProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/DiagnosticDescriptorProvider.cs index c676f27cb9f401..fab2d39182c8b7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/DiagnosticDescriptorProvider.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/DiagnosticDescriptorProvider.cs @@ -26,6 +26,7 @@ internal sealed class DiagnosticDescriptorProvider : IDiagnosticDescriptorProvid GeneratorDiagnostic.NotSupported { NotSupportedDetails: not null, TypePositionInfo.IsManagedReturnPosition: false } => GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails, GeneratorDiagnostic.UnnecessaryData { TypePositionInfo.IsManagedReturnPosition: false } => GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo, GeneratorDiagnostic.UnnecessaryData { TypePositionInfo.IsManagedReturnPosition: true } => GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo, + GeneratorDiagnostic.NotRecommended => GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices, { IsFatal: false } => null, { TypePositionInfo.IsManagedReturnPosition: true } => GeneratorDiagnostics.ReturnTypeNotSupported, { TypePositionInfo.IsManagedReturnPosition: false } => GeneratorDiagnostics.ParameterTypeNotSupported, diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs index f0b98c3c535ac3..1e1849592f7ba4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs @@ -477,7 +477,7 @@ public class Ids /// public static readonly DiagnosticDescriptor HResultTypeWillBeTreatedAsStruct = - new DiagnosticDescriptor( + DiagnosticDescriptorHelper.Create( Ids.NotRecommendedGeneratedComInterfaceUsage, GetResourceString(nameof(SR.HResultTypeWillBeTreatedAsStructTitle)), GetResourceString(nameof(SR.HResultTypeWillBeTreatedAsStructMessage)), @@ -485,6 +485,17 @@ public class Ids DiagnosticSeverity.Info, isEnabledByDefault: true); + /// + public static readonly DiagnosticDescriptor GeneratedComInterfaceUsageDoesNotFollowBestPractices = + new DiagnosticDescriptor( + Ids.NotRecommendedGeneratedComInterfaceUsage, + GetResourceString(nameof(SR.ComInterfaceUsageDoesNotFollowBestPracticesTitle)), + GetResourceString(nameof(SR.ComInterfaceUsageDoesNotFollowBestPracticesMessageWithDetails)), + Category, + DiagnosticSeverity.Info, + isEnabledByDefault: true, + helpLinkUri: "aka.ms/GeneratedComInterfaceUsage"); + /// /// Report diagnostic for invalid configuration for string marshalling. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/Strings.resx b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/Strings.resx index ebf0170e328752..2da298e8a7b0df 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/Strings.resx +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/Strings.resx @@ -875,7 +875,13 @@ [In] and [Out] attributes - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. @@ -889,4 +895,19 @@ This type will be treated as a struct in the native signature, not as a native HRESULT - \ No newline at end of file + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + + diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.cs.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.cs.xlf index aba6380486a3b7..e55ae1b901475e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.cs.xlf @@ -147,6 +147,16 @@ Hostování .NET COM s EnableComHosting nepodporuje rozhraní s generatedComInterfaceAttribute + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + Použití „GeneratedComInterfaceAttribute“ není v souladu s doporučeními. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + Použití „GeneratedComInterfaceAttribute“ není v souladu s doporučeními. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. Vrácená hodnota ve spravované definici se při volání nespravované metody COM převede na parametr out. Pokud má být návratovou hodnotou kód HRESULT vrácený nespravovanou metodou COM, použijte u metody [PreserveSig]. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + Typ {0} bude v nativním podpisu považován za strukturu, nikoli za nativní HRESULT. Pokud jej chcete považovat za HRESULT, přidejte do metody [return:MarshalAs(UnmanagedType.Error)]. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Tento typ bude v nativním podpisu považován za strukturu, nikoli za nativní HRESULT. + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + Atribut „[In]“ se podporuje pouze u parametrů pole. Parametry podle hodnoty jsou ve výchozím nastavení považovány za parametry jen pro čtení. @@ -467,6 +482,11 @@ Poskytnuté atributy „[In]“ a „[Out]“ u tohoto parametru se na tomto parametru nepodporují. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Atributy „[In]“ a „[Out]“ jsou podporovány pouze u parametrů pole. Zvažte použití klíčového slova „ref“ k nastavení měnitelného parametru. + + [In] and [Out] attributes atributy [In] a [Out] @@ -702,6 +722,16 @@ Neplatné použití atributu VirtualMethodIndexAttribute + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + Použití „LibraryImportAttribute“ není v souladu s doporučeními. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + Použití LibraryImportAttribute není v souladu s doporučeními. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. Typ prvku ReadOnlySpan vrácený GetManagedValuesSource musí být stejný, jako typ prvku vrácený GetManagedValuesDestination. @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + Atribut „[Out]“ se podporuje jen u parametrů pole. Zvažte použití klíčových slov „out“ nebo „ref“, aby se parametr dalo měnit. @@ -917,6 +947,11 @@ Typ {0}určuje, že podporuje zařazování ve směru „Out“, ale neposkytuje metodu ToManaged, která vrací spravovaný typ + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + U parametrů pole se doporučuje použít explicitní atributy „[In]“ a „[Out]“. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. GeneratedComInterfaceAttribute a GeneratedComClassAttribute vyžadují nebezpečný kód. Projekt se musí aktualizovat na <AllowUnsafeBlocks>true</AllowUnsafeBlocks>. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.de.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.de.xlf index 6b978f9aff5a9f..72df69b6fbfe8e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.de.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.de.xlf @@ -147,6 +147,16 @@ Das .NET COM-Hosting mit "EnableComHosting" unterstützt keine Schnittstellen mit "GeneratedComInterfaceAttribute". + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + Die Verwendung von "GeneratedComInterfaceAttribute" entspricht nicht den Empfehlungen. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + Die Verwendung von "GeneratedComInterfaceAttribute" entspricht nicht den Empfehlungen. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. Der Rückgabewert in der verwalteten Definition wird beim Aufrufen der nicht verwalteten COM-Methode in einen out-Parameter konvertiert. Wenn als Rückgabewert der von der nicht verwalteten COM-Methode zurückgegebene HRESULT-Code eingesetzt werden soll, verwenden Sie "[PreserveSig]" für die Methode. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + Der Typ '{0}' wird als Struktur in der nativen Signatur und nicht als natives HRESULT behandelt. Um ihn als HRESULT zu behandeln, fügen Sie der Methode "[return:MarshalAs(UnmanagedType.Error)]" hinzu. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Dieser Typ wird als Struktur in der nativen Signatur und nicht als natives HRESULT behandelt + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + Das [In]-Attribut wird nur für Arrayparameter unterstützt. Wertbezogene Parameter werden standardmäßig als schreibgeschützt betrachtet. @@ -467,6 +482,11 @@ Die angegebenen Attribute \"[In]\" und \"[Out]\" für diesen Parameter werden für diesen Parameter nicht unterstützt. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Die [In]- und [Out]-Attribute werden nur für Arrayparameter unterstützt. Erwägen Sie die Verwendung des Schlüsselworts "ref", damit der Parameter geändert werden kann. + + [In] and [Out] attributes [In]- und [Out]-Attribute @@ -702,6 +722,16 @@ Ungültige Verwendung von "VirtualMethodIndexAttribute" + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + Die Verwendung von "LibraryImportAttribute" entspricht nicht den Empfehlungen. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + Die Verwendung von "LibraryImportAttribute" entspricht nicht den Empfehlungen. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. Der von \"GetManagedValuesSource\" zurückgegebene Elementtyp \"ReadOnlySpan\" muss mit dem Elementtyp identisch sein, der von \"GetManagedValuesDestination\" zurückgegeben wird. @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + Das [Out]-Attribut wird nur für Arrayparameter unterstützt. Erwägen Sie die Verwendung der Schlüsselwörter "out" oder "ref", damit der Parameter geändert werden kann. @@ -917,6 +947,11 @@ Der Typ \"{0}\" gibt an, dass das Marshalling in der Out-Richtung unterstützt wird. Er stellt jedoch keine ToManaged-Methode bereit, die den verwalteten Typ zurückgibt. + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + Es wird empfohlen, explizite [In]- und [Out]-Attribute für Arrayparameter zu verwenden. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. 'GeneratedComInterfaceAttribute' und 'GeneratedComClassAttribute' erfordern unsicheren Code. Das Projekt muss mit '<AllowUnsafeBlocks>wahr</AllowUnsafeBlocks>' aktualisiert werden. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.es.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.es.xlf index 077827dbf7ef9b..6e7027caded806 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.es.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.es.xlf @@ -147,6 +147,16 @@ El hospedaje COM de .NET con “EnableComHosting” no admite interfaces con “GeneratedComInterfaceAttribute” + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + El uso de "GeneratedComInterfaceAttribute" no sigue las recomendaciones. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + El uso de "GeneratedComInterfaceAttribute" no sigue las recomendaciones. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. El valor devuelto en la definición administrada se convertirá en un parámetro “out” al llamar al método COM no administrado. Si el valor devuelto debe ser el código HRESULT devuelto por el método COM no administrado, use “[PreserveSig]” en el método. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + El tipo '{0}' se tratará como un struct en la firma nativa, no como un HRESULT nativo. Para tratarlo como un HRESULT, añade '[return:MarshalAs(UnmanagedType.Error)]' al método. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Este tipo se tratará como una estructura en la firma nativa, no como un HRESULT nativo. + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + El atributo "[In]" solo se admite en parámetros de matriz. Los parámetros por valor se consideran de solo lectura de forma predeterminada. @@ -467,6 +482,11 @@ En este parámetro, los atributos “[In]” y “[Out]” proporcionados no se admiten. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Los atributos "[In]" y "[Out]" solo se admiten en parámetros de matriz. Considere la posibilidad de usar la palabra clave "ref" para hacer que el parámetro sea mutable. + + [In] and [Out] attributes Atributos [In] y [Out] @@ -702,6 +722,16 @@ Uso de ”VirtualMethodIndexAttribute” no válido + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + El uso de "LibraryImportAttribute" no sigue las recomendaciones. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + El uso de "LibraryImportAttribute" no sigue las recomendaciones. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. El tipo de elemento del “ReadOnlySpan” devuelto por “GetManagedValuesSource” debe ser el mismo que el tipo de elemento devuelto por “GetManagedValuesDestination”. @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + El atributo "[Out]" solo se admite en parámetros de matriz. Considere la posibilidad de usar palabras clave "out" o "ref" para hacer que el parámetro sea mutable. @@ -917,6 +947,11 @@ El tipo “{0}” especifica que admite la serialización en la dirección “Out”, pero no proporciona un método “ToManaged” que devuelva el tipo administrado + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + Se recomienda usar los atributos explícitos "[In]" y "[Out]" en los parámetros de matriz. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. "GeneratedComInterfaceAttribute" y "GeneratedComClassAttribute" requieren código no seguro. El proyecto debe actualizarse con "<AllowUnsafeBlocks>true</AllowUnsafeBlocks>". diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.fr.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.fr.xlf index 7b06fbb9966ce1..50bffbbb83c75f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.fr.xlf @@ -147,6 +147,16 @@ L'hébergement .NET COM avec 'EnableComHosting' ne prend pas en charge les interfaces avec 'GeneratedComInterfaceAttribute' + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + L'utilisation de « GeneratedComInterfaceAttribute » ne suit pas les recommandations. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + L'utilisation de « GeneratedComInterfaceAttribute » ne suit pas les recommandations. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. La valeur de retour dans la définition managée est convertie en paramètre 'out' lors de l’appel de la méthode COM non managée. Si la valeur de retour doit être le code HRESULT retourné par la méthode COM non managée, utilisez '[PreserveSig]' sur la méthode. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + Le type « {0} » sera traité en tant que struct dans la signature native, et non en tant que HRESULT natif. Pour le traiter en tant que HRESULT, ajoutez « [return:MarshalAs(UnmanagedType.Error)] » à la méthode. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Ce type sera traité en tant que struct dans la signature native, et non en tant que HRESULT natif + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + L'attribut '[In]' n'est pris en charge que sur les paramètres de tableau. Les paramètres par valeur sont considérés comme en lecture seule par défaut. @@ -467,6 +482,11 @@ Les attributs « [In] » et « [Out] » fournis sur ce paramètre ne sont pas pris en charge sur ce paramètre. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Les attributs '[In]' et '[Out]' ne sont pris en charge que sur les paramètres de tableau. Pensez à utiliser le mot-clé 'ref' pour rendre le paramètre mutable. + + [In] and [Out] attributes Attributs [In] et [Out] @@ -702,6 +722,16 @@ Utilisation de « VirtualMethodIndexAttribute » non valide + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + L'utilisation de « LibraryImport Attribute » ne suit pas les recommandations. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + L'utilisation de « LibraryImport Attribute » ne suit pas les recommandations. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. Le type d’élément de « ReadOnlySpan » retourné par « GetManagedValuesSource » doit être identique au type d’élément retourné par « GetManagedValuesDestination ». @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - L’attribut '[Out]' est uniquement pris en charge sur les paramètres de tableau. Envisagez d’utiliser des mots clés 'out' ou 'ref' pour rendre le paramètre mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + L'attribut '[Out]' n'est pris en charge que sur les paramètres de tableau. Pensez à utiliser les mots-clés « out » ou « ref » pour rendre le paramètre mutable. @@ -917,6 +947,11 @@ Le type « {0} » spécifie qu’il prend en charge le marshaling dans la direction « Out », mais il ne fournit pas de méthode « ToManaged » qui retourne le type managé + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + Il est recommandé d'utiliser les attributs explicites '[In]' et '[Out]' sur les paramètres du tableau. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. « GeneratedComInterfaceAttribute » et « GeneratedComClassAttribute » nécessitent du code non sécurisé. Le projet doit être mis à jour avec « <AllowUnsafeBlocks>true</AllowUnsafeBlocks> ». diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.it.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.it.xlf index ee4ead36b3e114..a0129cfbe9495e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.it.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.it.xlf @@ -147,6 +147,16 @@ L'hosting COM .NET con 'EnableComHosting' non supporta le interfacce con 'GeneratedComInterfaceAttribute'. + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + L'utilizzo di 'GeneratedComInterfaceAttribute' non segue le raccomandazioni.{0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + L'utilizzo di 'GeneratedComInterfaceAttribute' non segue le raccomandazioni. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. Il valore restituito nella definizione gestita verrà convertito in un parametro 'out' quando si chiama il metodo COM non gestito. Se il valore restituito deve essere il codice HRESULT restituito dal metodo COM non gestito, utilizzare '[PreserveSig]' sul metodo. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + Il tipo '{0}' verrà considerato come uno struct nella firma nativa, non come HRESULT nativo. Per considerare questo valore come HRESULT, aggiungere '[return:MarshalAs(UnmanagedType.Error)]' al metodo. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Questo tipo verrà considerato come uno struct nella firma nativa, non come HRESULT nativo + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + L'attributo '[In]' è supportato solo sui parametri di matrice. I parametri per valore sono considerati di sola lettura per impostazione predefinita. @@ -467,6 +482,11 @@ Gli attributi '[In]' e '[Out]' specificati per questo parametro non sono supportati in questo parametro. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Gli attributi [In]' e '[Out]' sono supportati solo nei parametri di matrice. Provare a usare le parole chiave 'ref' per rendere modificabile il parametro. + + [In] and [Out] attributes Attributi [In] e [Out] @@ -702,6 +722,16 @@ Utilizzo di 'VirtualMethodIndexAttribute' non valido + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + L'utilizzo di 'LibraryImportAttribute' non segue le raccomandazioni. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + L'utilizzo di 'LibraryImportAttribute' non segue le raccomandazioni. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. Il tipo di elemento di 'ReadOnlySpan' restituito da 'GetManagedValuesSource' deve essere uguale al tipo di elemento restituito da 'GetManagedValuesDestination'. @@ -903,7 +933,7 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. L'attributo '[Out]' è supportato solo nei parametri di matrice. Provare a usare le parole chiave 'out' o 'ref' per rendere modificabile il parametro. @@ -917,6 +947,11 @@ Il tipo '{0}' specifica che supporta il marshalling nella direzione 'Out', ma non fornisce un metodo 'ToManaged' che restituisce il tipo gestito + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + È consigliabile usare attributi '[In]' e '[Out]' espliciti nei parametri di matrice. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. GeneratedComInterfaceAttribute e 'GeneratedComClassAttribute' richiedono codice non gestito. Il progetto deve essere aggiornato con '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ja.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ja.xlf index 9c12755cb80b44..7fa5dab29eafcf 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ja.xlf @@ -147,6 +147,16 @@ 'EnableComHosting' を使用した .NET COM ホスティングでは、'GeneratedComInterfaceAttribute' のインターフェイスはサポートされていません + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + 'GeneratedComInterfaceAttribute' の使用は推奨事項に従っていません。{0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + 'GeneratedComInterfaceAttribute' の使用は推奨事項に従っていません。 + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. マネージド定義の戻り値は、アンマネージド COM メソッドを呼び出すときに 'out' パラメーターに変換されます。戻り値を、アンマネージド COM メソッドによって返される HRESULT コードにする場合は、メソッドで '[PreserveSig]' を使用してください。 @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + 型 '{0}' は、ネイティブ HRESULT としてではなく、ネイティブ シグネチャ内の構造体として扱われます。これを HRESULT として扱うには、'[return:MarshalAs(UnmanagedType.Error)]' をメソッドに追加します。 This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + この型はネイティブの HRESULT ではなく、ネイティブ シグネチャの構造体として扱われます + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + '[In]' 属性は配列パラメーターでのみサポートされています。既定では、値によるパラメーターは読み取り専用と見なされます。 @@ -467,6 +482,11 @@ このパラメーターに指定された '[In]' 属性と '[Out]' 属性は、このパラメーターではサポートされていません。 + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + '[In]' および '[Out]' 属性は、配列パラメーターでのみサポートされます。パラメーターを変更可能にするには、'ref' キーワードを使用することを検討してください。 + + [In] and [Out] attributes 属性の[In]と[Out] @@ -702,6 +722,16 @@ 'VirtualMethodIndexAttribute' の使用法が無効です + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + 'LibraryImportAttribute' の使用は推奨事項に従っていません。{0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + 'LibraryImportAttribute' の使用は推奨事項に従っていません。 + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. 'GetManagedValuesSource' によって返される 'ReadOnlySpan' の要素型は、'GetManagedValuesDestination' によって返される要素型と同じである必要があります。 @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + '[Out]' 属性は、配列パラメーターでのみサポートされます。パラメーターを変更可能にするには、'out' または 'ref' キーワードを使用することを検討してください。 @@ -917,6 +947,11 @@ 型 '{0}' は、'Out' 方向のマーシャリングをサポートしますが、マネージド型を返す 'ToManaged' メソッドは指定されません + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + 配列パラメーターに明示的な '[In]' および '[Out]' 属性を使用することをお勧めします。 + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. 'GeneratedComInterfaceAttribute' および 'GeneratedComClassAttribute' にはアンセーフ コードが必要です。プロジェクトは '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>' で更新する必要があります。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ko.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ko.xlf index 53b72d47f60fa3..d3d130445bab26 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ko.xlf @@ -147,6 +147,16 @@ 'EnableComHosting'을 사용한 .NET COM 호스팅은 'GeneratedComInterfaceAttribute'를 사용한 인터페이스를 지원하지 않습니다. + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + 'GeneratedComInterfaceAttribute' 사용법이 권장 사항을 따르지 않습니다. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + 'GeneratedComInterfaceAttribute' 사용법이 권장 사항을 따르지 않습니다. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. 관리 정의의 반환 값은 관리되지 않는 COM 메서드를 호출할 때 'out' 매개 변수로 변환됩니다. 반환 값이 관리되지 않는 COM 메서드에서 반환된 HRESULT 코드인 경우 메서드에서 '[PreserveSig]'를 사용하세요. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + '{0}' 형식은 네이티브 HRESULT가 아니라 네이티브 서명에서 구조체로 처리됩니다. HRESULT로 처리하려면 메서드에 '[return:MarshalAs(UnmanagedType.Error)]'를 추가하세요. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + 이 형식은 네이티브 HRESULT가 아니라 네이티브 서명의 구조체로 처리됩니다. + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + '[In]' 특성은 배열 매개 변수에서만 지원됩니다. 값별 매개 변수는 기본적으로 읽기 전용으로 간주됩니다. @@ -467,6 +482,11 @@ 이 매개 변수에 제공된 '[In]' 및 '[Out]' 특성은 이 매개 변수에서 지원되지 않습니다. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + '[In]' 및 '[Out]' 특성은 배열 매개 변수에서만 지원됩니다. 매개 변수를 변경 가능하게 만들려면 'ref' 키워드를 사용하는 것이 좋습니다. + + [In] and [Out] attributes [In] 및 [Out] 속성 @@ -702,6 +722,16 @@ 잘못된 'VirtualMethodIndexAttribute' 사용 + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + 'LibraryImportAttribute' 사용법이 권장 사항을 따르지 않습니다. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + 'LibraryImportAttribute' 사용법이 권장 사항을 따르지 않습니다. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. 'GetManagedValuesSource'에서 반환된 'ReadOnlySpan'의 요소 형식은 'GetManagedValuesDestination'에서 반환된 요소 형식과 동일해야 합니다. @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + '[Out]' 특성은 배열 매개 변수에서만 지원됩니다. 매개 변수를 변경할 수 있도록 'out' 또는 'ref' 키워드를 사용하는 것이 좋습니다. @@ -917,6 +947,11 @@ 형식 '{0}'은(는) 'Out' 방향으로 마샬링을 지원하도록 지정하지만 관리 형식을 반환하는 'ToManaged' 메서드를 제공하지 않습니다. + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + 배열 매개 변수에는 명시적인 '[In]' 및 '[Out]' 특성을 사용하는 것이 좋습니다. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. 'GeneratedComInterfaceAttribute' 및 'GeneratedComClassAttribute'에는 안전하지 않은 코드가 필요합니다. 프로젝트를 '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'로 업데이트해야 합니다. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pl.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pl.xlf index fb6cf9f172c21b..c5dbc075109417 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pl.xlf @@ -147,6 +147,16 @@ Hosting modelu COM platformy .NET z elementem „EnableComHosting” nie obsługuje interfejsów z atrybutem „GeneratedComInterfaceAttribute” + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + Użycie atrybutu „GeneratedComInterfaceAttribute” nie jest zgodne z zaleceniami. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + Użycie atrybutu „GeneratedComInterfaceAttribute” nie jest zgodne z zaleceniami. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. Wartość zwracana w definicji zarządzanej zostanie przekonwertowana na parametr „out” podczas wywoływania niezarządzanej metody COM. Jeśli wartość zwracana ma być kodem HRESULT zwracanym przez niezarządzaną metodę COM, należy użyć „[PreserveSig]” w metodzie. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + Typ „{0}” będzie traktowany jako struktura w podpisie natywnym, a nie jako natywny wynik HRESULT. Aby traktować to jako HRESULT, dodaj element „[return:MarshalAs(UnmanagedType.Error)]” do metody. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Ten typ będzie traktowany jako struktura w podpisie natywnym, a nie jako natywny wynik HRESULT + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + Atrybut „[In]” jest obsługiwany tylko w przypadku parametrów tablicy. Parametry według wartości są domyślnie uznawane za tylko do odczytu. @@ -467,6 +482,11 @@ Podane atrybuty „[In]” i „[Out]” w tym parametrze nie są obsługiwane w tym parametrze. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Atrybuty „[In]” i „[Out]” są obsługiwane tylko w przypadku parametrów tablicy. Rozważ użycie słowa kluczowego „ref”, aby umożliwić modyfikowanie parametru. + + [In] and [Out] attributes Atrybuty [In] i [Out] @@ -702,6 +722,16 @@ Nieprawidłowe użycie atrybutu „VirtualMethodIndexAttribute” + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + Użycie atrybutu „LibraryImportAttribute” nie jest zgodne z zaleceniami. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + Użycie atrybutu „LibraryImportAttribute” nie jest zgodne z zaleceniami. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. Typ elementu „ReadOnlySpan” zwracany przez element „GetManagedValuesSource” musi być taki sam jak typ elementu zwracany przez element „GetManagedValuesDestination”. @@ -903,7 +933,7 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. Atrybut „[Out]” jest obsługiwany tylko w przypadku parametrów tablicy. Rozważ użycie słów kluczowych „out” lub „ref”, aby umożliwić modyfikowanie parametru. @@ -917,6 +947,11 @@ Typ „{0}” określa, że obsługuje skierowanie w kierunku „Out”, ale nie zapewnia metody „ToManaged”, która zwraca typ zarządzany + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + Zaleca się używanie jawnych atrybutów „[In]” i „[Out]” w parametrach tablicy. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. Atrybut „GeneratedComInterfaceAttribute” i „GeneratedComClassAttribute” wymagają niebezpiecznego kodu. Projekt musi zostać zaktualizowany za pomocą polecenia „<AllowUnsafeBlocks>true</AllowUnsafeBlocks>”. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pt-BR.xlf index 8d2e606d2d0307..80bc83669b4547 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.pt-BR.xlf @@ -147,6 +147,16 @@ A hospedagem .NET COM com 'EnableComHosting' não dá suporte a interfaces com 'GeneratedComInterfaceAttribute' + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + O uso de 'GeneratedComInterfaceAttribute' não segue as recomendações. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + O uso de 'GeneratedComInterfaceAttribute' não segue as recomendações. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. O valor de retorno na definição gerenciada será convertido em um parâmetro 'out' ao chamar o método COM não gerenciado. Se o valor de retorno for o código HRESULT retornado pelo método COM não gerenciado, use '[PreserveSig]' no método. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + O tipo '{0}' será tratado como uma estrutura na assinatura nativa, não como um HRESULT nativo. Para tratar isso como um HRESULT, adicione '[return:MarshalAs(UnmanagedType.Error)]' ao método. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Este tipo será tratado como uma estrutura na assinatura nativa, não como um HRESULT nativo + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + O '[In]' atributo só tem suporte nos parâmetros de matriz. Parâmetros por valor são considerados somente leitura por padrão. @@ -467,6 +482,11 @@ Os atributos '[In]' e '[Out]' neste parâmetro não têm suporte neste parâmetro. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Os atributos '[In]' e '[Out]' só têm suporte nos parâmetros de matriz. Considere usar a palavra-chave 'ref' para tornar o parâmetro mutável. + + [In] and [Out] attributes Atributos [In] e [Out] @@ -702,6 +722,16 @@ Uso de 'VirtualMethodIndexAttribute' inválido + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + O uso de 'LibraryImportAttribute' não segue as recomendações. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + O uso de 'LibraryImportAttribute' não segue as recomendações. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. O tipo de elemento de 'ReadOnlySpan' retornado por 'GetManagedValuesSource' deve ser igual ao tipo de elemento retornado por 'GetManagedValuesDestination'. @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + O atributo "[Out]" só tem suporte nos parâmetros de matriz. Considere usar palavras-chave "out" ou "ref" para tornar o parâmetro mutável. @@ -917,6 +947,11 @@ O tipo '{0}' especifica que ele dá suporte a marshalling na direção 'Out', mas não fornece um método 'ToManaged' que retorna o tipo gerenciado + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + É recomendável usar atributos '[In]' e '[Out]' explícitos nos parâmetros de matriz. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. "GeneratedComInterfaceAttribute" e "GeneratedComClassAttribute" exigem código não seguro. O projeto deve ser atualizado com "<AllowUnsafeBlocks>true</AllowUnsafeBlocks>". diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ru.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ru.xlf index fca13d2b25e7f6..d820e6ee7f3e76 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.ru.xlf @@ -147,6 +147,16 @@ Размещение .NET COM с "EnableComHosting" не поддерживает интерфейсы с "GeneratedComInterfaceAttribute" + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + Использование "GeneratedComInterfaceAttribute" не соответствует рекомендациям. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + Использование "GeneratedComInterfaceAttribute" не соответствует рекомендациям. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. Возвращаемое значение в управляемом определении будет преобразовано в параметр "out" при вызове неуправляемого метода COM. Если возвращаемое значение должно быть кодом HRESULT, возвращаемым неуправляемым COM-методом, используйте "[PreserveSig]" в методе. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + Тип "{0}" будет рассматриваться как структура в собственной подписи, а не как собственный HRESULT. Чтобы обработать это как HRESULT, добавьте в метод "[return:MarshalAs(UnmanagedType.Error)]". This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Этот тип будет рассматриваться как структура в собственной подписи, а не как собственный HRESULT. + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + Атрибут "[In]" поддерживается только для параметров массива. Параметры по значению по умолчанию считаются доступными только для чтения. @@ -467,6 +482,11 @@ Указанные атрибуты \"[In]\" и \"[Out]\" для этого параметра не поддерживаются. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + Атрибуты "[In]" и "[Out]" поддерживаются только для параметров массива. Рассмотрите возможность использования ключевого слова "ref", чтобы сделать параметр изменяемым. + + [In] and [Out] attributes Атрибуты [In] и [Out] @@ -702,6 +722,16 @@ Недопустимое использование VirtualMethodIndexAttribute + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + Использование "LibraryImportAttribute" не соответствует рекомендациям. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + Использование "LibraryImportAttribute" не соответствует рекомендациям. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. Тип элемента \"ReadOnlySpan\", возвращенный методом \"GetManagedValuesSource\", должен совпадать с типом элемента, возвращаемым методом \"GetManagedValuesDestination\". @@ -903,7 +933,7 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. Атрибут "[Out]" поддерживается только для параметров массива. Рассмотрите возможность использования ключевых слов "out" или "ref", чтобы сделать параметр изменяемым. @@ -917,6 +947,11 @@ Тип \"{0}\" указывает, что поддерживает маршализацию в направлении \"наружу\", но не предоставляет метод \"ToManaged\", который возвращает управляемый тип + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + Рекомендуется использовать явные атрибуты "[In]" и "[Out]" для параметров массива. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. Для "GeneratedComInterfaceAttribute" и "GeneratedComClassAttribute" требуется небезопасный код. Проект необходимо обновить с использованием значения "<AllowUnsafeBlocks>true</AllowUnsafeBlocks>". diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.tr.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.tr.xlf index 44d2a4a4cbadd3..614ee9fd9a31b8 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.tr.xlf @@ -147,6 +147,16 @@ 'EnableComHosting' ile barındırma .NET COM, 'GeneratedComInterfaceAttribute' ile arabirimleri desteklemez + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + 'GeneratedComInterfaceAttribute' kullanımı önerilere uygun değil. {0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + 'GeneratedComInterfaceAttribute' kullanımı önerilere uygun değil. + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. Yönetilen tanımdaki dönüş değeri, yönetilmeyen COM yöntemi çağrılırken 'out' parametresine dönüştürülür. Dönüş değerinin yönetilmeyen COM yöntemi tarafından döndürülen HRESULT kodu olması amaçlanmışsa, yöntemde '[PreserveSig]' kullanın. @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + '{0}' türü, yerel HRESULT olarak değil, yerel imzada bir yapı olarak değerlendirilir. Bunu bir HRESULT olarak değerlendirmek için yönteme '[return:MarshalAs(UnmanagedType.Error)]' ekleyin. This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + Bu tür, yerel HRESULT olarak değil, yerel imzada bir yapı olarak değerlendirilir + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + '[In]' özniteliği yalnızca dizi parametrelerinde desteklenir. Değere göre parametreleri, varsayılan olarak salt okunur kabul edilir. @@ -467,6 +482,11 @@ Bu parametrede sağlanan '[In]' ve '[Out]' öznitelikleri bu parametrede desteklenmiyor. + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + '[In]' ve '[Out]' öznitelikleri yalnızca dizi parametrelerinde desteklenir. Parametreyi değiştirilebilir yapmak için 'ref' anahtar sözcüğünü kullanmayı düşünün. + + [In] and [Out] attributes [In] ve [Out] öznitelikleri @@ -702,6 +722,16 @@ Geçersiz 'VirtualMethodIndexAttribute' kullanımı + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + 'LibraryImportAttribute' kullanımı önerilere uygun değil. {0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + 'LibraryImportAttribute' kullanımı önerilere uygun değil. + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. 'GetManagedValuesSource' tarafından döndürülen 'ReadOnlySpan' öğe türü, 'GetManagedValuesDestination' tarafından döndürülen öğe türüyle aynı olmalıdır. @@ -903,7 +933,7 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. '[Out]' özniteliği yalnızca dizi parametrelerinde desteklenir. Parametreyi değiştirilebilir yapmak için 'out' veya 'ref' anahtar sözcükleri kullanmayı düşünün. @@ -917,6 +947,11 @@ '{0}' türü, 'Out' yönünde sıralamayı desteklediğini belirtiyor, ancak yönetilen türü döndüren bir 'ToManaged' metodu sağlamıyor + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + Dizi parametrelerinde açık '[In]' ve '[Out]' özniteliklerinin kullanılması önerilir. + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. 'GeneratedComInterfaceAttribute' ve 'GeneratedComClassAttribute' güvenli olmayan kod gerektiriyor. Projenin '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>' ile güncelleştirilmiş olması gerekiyor. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hans.xlf index e6fbf9115aec2d..626d372a9d7fcf 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hans.xlf @@ -147,6 +147,16 @@ 具有“EnableComHosting”的 .NET COM 托管不支持具有“GeneratedComInterfaceAttribute”的接口 + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + “GeneratedComInterfaceAttribute”的使用未遵循建议。{0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + “GeneratedComInterfaceAttribute”的使用未遵循建议。 + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. 调用非托管 COM 方法时,托管定义中的返回值将转换为 "out" 参数。如果返回值是非托管 COM 方法返回的 HRESULT 代码,请对方法使用 "[PreserveSig]"。 @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + 类型“{0}”将被视为本机签名中的结构,而不是本机 HRESULT。若要将其视为 HRESULT,请将“[return:MarshalAs(UnmanagedType.Error)]”添加到方法。 This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + 此类型将被视为本机签名中的结构,而不是本机 HRESULT + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + “[In]”特性仅在数组参数上受支持。默认情况下,按值参数视为只读。 @@ -467,6 +482,11 @@ 此参数上提供的 “[In]” 和 “[Out]” 属性在此参数上不受支持。 + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + “[In]”和“[Out]”属性仅在数组参数上受支持。请考虑使用“ref”关键字使参数可变。 + + [In] and [Out] attributes [In] 和 [Out] 属性 @@ -702,6 +722,16 @@ “VirtualMethodIndexAttribute” 使用情况无效 + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + “LibraryImportAttribute”的使用未遵循建议。{0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + “LibraryImportAttribute”的使用未遵循建议。 + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. “GetManagedValuesSource” 返回的 “ReadOnlySpan” 的元素类型必须与 “GetManagedValuesDestination” 返回的元素类型相同。 @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + "[Out]" 属性仅在数组参数上受支持。请考虑使用“out”或“ref”关键字使参数可变。 @@ -917,6 +947,11 @@ 类型“{0}”指定它支持按 “Out” 方向进行封送,但不提供返回托管类型的 “ToManaged” 方法 + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + 建议对数组参数使用显式“[In]”和“[Out]”属性。 + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. “GeneratedComInterfaceAttribute”和“GeneratedComClassAttribute”需要不安全代码。必须将项目更新为“<AllowUnsafeBlocks>true</AllowUnsafeBlocks>”。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hant.xlf index 528339678cefdb..5ba36b10ca276d 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/System.Runtime.InteropServices/gen/Common/Resources/xlf/Strings.zh-Hant.xlf @@ -147,6 +147,16 @@ 以 'EnableComHosting' 裝載的 .NET COM 不支援具有 'GeneratedComInterfaceAttribute' 的介面 + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. {0} + 'GeneratedComInterfaceAttribute' 的使用方式未遵循建議。{0} + + + + The usage of 'GeneratedComInterfaceAttribute' does not follow recommendations. + 'GeneratedComInterfaceAttribute' 的使用方式未遵循建議。 + + The return value in the managed definition will be converted to an 'out' parameter when calling the unmanaged COM method. If the return value is intended to be the HRESULT code returned by the unmanaged COM method, use '[PreserveSig]' on the method. 呼叫未受控 COM 方法時,受控定義中的傳回值將轉換為 'out' 參數。如果傳回值預期是未受控 COM 方法傳回的 HRESULT 代碼,請在方法上使用 '[PreserveSig]'。 @@ -439,12 +449,17 @@ The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. - The type '{0}' will be treated as a struct in the native signature, not as a native HRESULT. To treat this as an HRESULT, add '[return:MarshalAs(UnmanagedType.Error)]' to the method. + 類型 '{0}' 會被視為原生簽章中的結構,而非原生 HRESULT。若要將此視為 HRESULT,請將 '[return:MarshalAs(UnmanagedType.Error)]' 新增至方法。 This type will be treated as a struct in the native signature, not as a native HRESULT - This type will be treated as a struct in the native signature, not as a native HRESULT + 此類型會被視為原生簽章中的結構,而非原生 HRESULT + + + + The '[In]' attribute is only supported on array parameters. By-value parameters are considered read-only by default. + 只有在陣列參數上才支援 '[In]' 屬性。預設會將 By-value 參數視為唯讀。 @@ -467,6 +482,11 @@ 此參數不支援在此參數上提供的 '[In]' 和 '[Out]' 屬性。 + + The '[In]' and '[Out]' attributes are only supported on array parameters. Consider using the 'ref' keyword to make the parameter mutable. + 只有在陣列參數上才支援 '[In]' 和 '[Out]' 屬性。請考慮使用 'ref' 關鍵字,讓參數成為可變動。 + + [In] and [Out] attributes [In] 與 [Out] 屬性 @@ -702,6 +722,16 @@ 'VirtualMethodIndexAttribute' 使用方式無效 + + The usage of 'LibraryImportAttribute' does not follow recommendations. {0} + 'LibraryImportAttribute' 的使用方式未遵循建議。{0} + + + + The usage of 'LibraryImportAttribute' does not follow recommendations. + 'LibraryImportAttribute' 的使用方式未遵循建議。 + + The element type of the 'ReadOnlySpan' returned by 'GetManagedValuesSource' must be the same as the element type returned by 'GetManagedValuesDestination'. 'GetManagedValuesSource' 傳回的 'ReadOnlySpan' 元素類型必須與 'GetManagedValuesDestination' 傳回的元素類型相同。 @@ -903,8 +933,8 @@ - The `[Out]` attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. - 只有陣列參數才支援 '[Out]' 屬性。請考慮使用 'out' 或 'ref' 關鍵字將參數設為可變。 + The '[Out]' attribute is only supported on array parameters. Consider using 'out' or 'ref' keywords to make the parameter mutable. + 只有在陣列參數上才支援 '[Out]' 屬性。請考慮使用 'out' 或 'ref' 關鍵字,讓參數成為可變動。 @@ -917,6 +947,11 @@ 類型 '{0}' 指定它支援以 'Out' 方向排列,但未提供傳回受管理類型的 'ToManaged' 方法 + + It is recommended to use explicit '[In]' and '[Out]' attributes on array parameters. + 建議在陣列參數上使用明確的 '[In]' 和 '[Out]' 屬性。 + + 'GeneratedComInterfaceAttribute' and 'GeneratedComClassAttribute' require unsafe code. Project must be updated with '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>'. 'GeneratedComInterfaceAttribute' 和 'GeneratedComClassAttribute' 需要不安全的程式碼。專案必須以 '<AllowUnsafeBlocks>true</AllowUnsafeBlocks>' 更新。 diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/DiagnosticDescriptorProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/DiagnosticDescriptorProvider.cs index 2bef64490f9a4e..eb9a0237307bc3 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/DiagnosticDescriptorProvider.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/DiagnosticDescriptorProvider.cs @@ -24,6 +24,7 @@ internal sealed class DiagnosticDescriptorProvider : IDiagnosticDescriptorProvid GeneratorDiagnostic.NotSupported { NotSupportedDetails: not null, TypePositionInfo.IsManagedReturnPosition: false } => GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails, GeneratorDiagnostic.UnnecessaryData { TypePositionInfo.IsManagedReturnPosition: false } => GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo, GeneratorDiagnostic.UnnecessaryData { TypePositionInfo.IsManagedReturnPosition: true } => GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo, + GeneratorDiagnostic.NotRecommended => GeneratorDiagnostics.LibraryImportUsageDoesNotFollowBestPractices, { IsFatal: false } => null, { TypePositionInfo.IsManagedReturnPosition: true } => GeneratorDiagnostics.ReturnTypeNotSupported, { TypePositionInfo.IsManagedReturnPosition: false } => GeneratorDiagnostics.ParameterTypeNotSupported, diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/GeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/GeneratorDiagnostics.cs index 2af2abadcdd607..f5fd78d48e587c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/GeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/GeneratorDiagnostics.cs @@ -254,6 +254,16 @@ public class Ids DiagnosticSeverity.Warning, isEnabledByDefault: true); + /// + public static readonly DiagnosticDescriptor LibraryImportUsageDoesNotFollowBestPractices = + new DiagnosticDescriptor( + Ids.NotRecommendedGeneratedComInterfaceUsage, + GetResourceString(nameof(SR.LibraryImportUsageDoesNotFollowBestPracticesTitle)), + GetResourceString(nameof(SR.LibraryImportUsageDoesNotFollowBestPracticesMessageWithDetails)), + Category, + DiagnosticSeverity.Info, + isEnabledByDefault: true); + /// /// Report diagnostic for invalid configuration for string marshalling. /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/AnalyzerConfigOptionsExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/AnalyzerConfigOptionsExtensions.cs index c7cfc69972d337..65f188ae1f7175 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/AnalyzerConfigOptionsExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/AnalyzerConfigOptionsExtensions.cs @@ -20,8 +20,8 @@ public static class AnalyzerConfigOptionsExtensions // Parse from the informational version as that is the only version that always matches the TFM version // even in debug builds. private static readonly Version ThisAssemblyVersion = Version.Parse( - typeof(IncrementalGeneratorInitializationContextExtensions).Assembly - .GetCustomAttribute().InformationalVersion.Split('-')[0]); + typeof(AnalyzerConfigOptionsExtensions).Assembly + .GetCustomAttribute().InformationalVersion.Split('-', '+')[0]); public static TargetFrameworkSettings GetTargetFrameworkSettings(this AnalyzerConfigOptions options) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalValuesProviderExtensions.cs index 123d649c913cd5..5455510ac356c0 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/IncrementalValuesProviderExtensions.cs @@ -47,7 +47,7 @@ public static IncrementalValuesProvider SelectNormalized(this Incr return provider.Select((node, ct) => node.NormalizeWhitespace()); } - public static (IncrementalValuesProvider, IncrementalValuesProvider) Split(this IncrementalValuesProvider<(T, T2)> provider) + public static (IncrementalValuesProvider, IncrementalValuesProvider) Split(this IncrementalValuesProvider<(T, T2)> provider) { return (provider.Select(static (data, ct) => data.Item1), provider.Select(static (data, ct) => data.Item2)); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 73b8cd8eb901a9..69180eca11c190 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -431,17 +431,10 @@ private ResolvedGenerator CreateNativeCollectionMarshaller( { byValueMarshalKindSupport = ByValueMarshalKindSupportDescriptor.Default; } - else if (!elementIsBlittable || ElementTypeIsSometimesNonBlittable(elementInfo)) - { - // If the type is not blittable or is sometimes not blittable, we will generate different code when the attributes are provided. - byValueMarshalKindSupport = ByValueMarshalKindSupportDescriptor.ArrayParameter; - } else { - // If the type is always blittable, we'll generate the same code regardless of the attributes, - // but we'll allow them to make it easier to transition to source-generated code and allow users to be clear about expectations - // for values in pre-allocated buffers. - byValueMarshalKindSupport = ByValueMarshalKindSupportDescriptor.PinnedParameter; + // If we have an array, we will use the Array [In, Out] support descriptor + byValueMarshalKindSupport = ByValueMarshalKindSupportDescriptor.ArrayParameter; } // Elements in the collection must be blittable to use the pinnable marshaller. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueContentsMarshalKindValidator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueContentsMarshalKindValidator.cs index 5cc2a8f9a9b261..dd7a40fb8c90e2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueContentsMarshalKindValidator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueContentsMarshalKindValidator.cs @@ -27,7 +27,7 @@ public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context) private static ResolvedGenerator ValidateByValueMarshalKind(TypePositionInfo info, StubCodeContext context, ResolvedGenerator generator) { - if (generator.Generator is Forwarder || info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Default) + if (generator.Generator is Forwarder) { // Forwarder allows everything since it just forwards to a P/Invoke. // The Default marshal kind is always valid. @@ -41,6 +41,7 @@ private static ResolvedGenerator ValidateByValueMarshalKind(TypePositionInfo inf ByValueMarshalKindSupport.Supported => generator, ByValueMarshalKindSupport.NotSupported => ResolvedGenerator.ResolvedWithDiagnostics(s_forwarder, generator.Diagnostics.Add(diagnostic!)), ByValueMarshalKindSupport.Unnecessary => generator with { Diagnostics = generator.Diagnostics.Add(diagnostic!) }, + ByValueMarshalKindSupport.NotRecommended => generator with { Diagnostics = generator.Diagnostics.Add(diagnostic!) }, _ => throw new UnreachableException() }; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueMarshalKindSupportDescriptor.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueMarshalKindSupportDescriptor.cs index 1f859d34ff919b..7754e177d66063 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueMarshalKindSupportDescriptor.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ByValueMarshalKindSupportDescriptor.cs @@ -3,127 +3,105 @@ using System; using System.Collections.Immutable; +using System.Diagnostics; namespace Microsoft.Interop { + public record struct ByValueMarshalKindSupportInfo(ByValueMarshalKindSupport Support, string? details) + { + public ByValueMarshalKindSupport GetSupport(TypePositionInfo info, StubCodeContext context, out GeneratorDiagnostic? diagnostic) + { + diagnostic = Support switch + { + ByValueMarshalKindSupport.Supported => null, + ByValueMarshalKindSupport.NotRecommended => + new GeneratorDiagnostic.NotRecommended(info, context) + { + Details = details + }, + ByValueMarshalKindSupport.Unnecessary => + new GeneratorDiagnostic.UnnecessaryData( + info, + context, + ImmutableArray.Create(info.ByValueMarshalAttributeLocations.OutLocation)) + { + UnnecessaryDataName = SR.InOutAttributes, + UnnecessaryDataDetails = details + }, + ByValueMarshalKindSupport.NotSupported => + new GeneratorDiagnostic.NotSupported(info, context) + { + NotSupportedDetails = details + }, + _ => throw new UnreachableException() + }; + return Support; + } + } + /// /// Provides an implementation of through /// public record ByValueMarshalKindSupportDescriptor( - ByValueMarshalKindSupport InSupport, string? InSupportDetails, - ByValueMarshalKindSupport OutSupport, string? OutSupportDetails, - ByValueMarshalKindSupport InOutSupport, string? InOutSupportDetails) + ByValueMarshalKindSupportInfo DefaultSupport, + ByValueMarshalKindSupportInfo InSupport, + ByValueMarshalKindSupportInfo OutSupport, + ByValueMarshalKindSupportInfo InOutSupport) { /// /// A default for by value parameters. [In] is allowed, but unnecessary. Out is not allowed. /// public static readonly ByValueMarshalKindSupportDescriptor Default = new ByValueMarshalKindSupportDescriptor( - InSupport: ByValueMarshalKindSupport.Unnecessary, InSupportDetails: SR.InAttributeOnlyIsDefault, - OutSupport: ByValueMarshalKindSupport.NotSupported, OutSupportDetails: SR.OutAttributeNotSupportedOnByValueParameters, - InOutSupport: ByValueMarshalKindSupport.NotSupported, InOutSupportDetails: SR.OutAttributeNotSupportedOnByValueParameters); + DefaultSupport: new(ByValueMarshalKindSupport.Supported, null), + InSupport: new(ByValueMarshalKindSupport.NotSupported, SR.InAttributeNotSupportedOnByValueParameters), + OutSupport: new(ByValueMarshalKindSupport.NotSupported, SR.OutAttributeNotSupportedOnByValueParameters), + InOutSupport: new(ByValueMarshalKindSupport.NotSupported, SR.InOutAttributeNotSupportedOnByValueParameters)); /// - /// A default for by value array parameters. [In] is allowed, but unnecessary. Out is allowed. + /// A default for by value array parameters. Default is allowed, but Not Recommended. [In], [Out], and [In, Out] are allowed /// public static readonly ByValueMarshalKindSupportDescriptor ArrayParameter = new ByValueMarshalKindSupportDescriptor( - InSupport: ByValueMarshalKindSupport.Unnecessary, InSupportDetails: SR.InAttributeOnlyIsDefault, - OutSupport: ByValueMarshalKindSupport.Supported, OutSupportDetails: null, - InOutSupport: ByValueMarshalKindSupport.Supported, InOutSupportDetails: null); - - /// - /// A default for pinned parameters. [In] is allowed, but unnecessary. Out is allowed. - /// - public static readonly ByValueMarshalKindSupportDescriptor PinnedParameter = new ByValueMarshalKindSupportDescriptor( - InSupport: ByValueMarshalKindSupport.Unnecessary, InSupportDetails: SR.InAttributeOnlyIsDefault, - OutSupport: ByValueMarshalKindSupport.Supported, OutSupportDetails: null, - InOutSupport: ByValueMarshalKindSupport.Supported, InOutSupportDetails: null); + DefaultSupport: new(ByValueMarshalKindSupport.NotRecommended, SR.PreferExplicitInOutAttributesOnArrays), + InSupport: new(ByValueMarshalKindSupport.Supported, null), + OutSupport: new(ByValueMarshalKindSupport.Supported, null), + InOutSupport: new(ByValueMarshalKindSupport.Supported, null)); /// /// Returns the support for the ByValueContentsMarshalKind, and if it is not , diagnostic is not null /// public ByValueMarshalKindSupport GetSupport(ByValueContentsMarshalKind marshalKind, TypePositionInfo info, StubCodeContext context, out GeneratorDiagnostic? diagnostic) { - if (info.IsByRef && marshalKind != ByValueContentsMarshalKind.Default) + if (info.IsByRef) { - diagnostic = new GeneratorDiagnostic.NotSupported(info, context) + // ByRef with ByValue attributes is not allowed + if (marshalKind != ByValueContentsMarshalKind.Default) { - NotSupportedDetails = SR.InOutAttributeByRefNotSupported - }; - return ByValueMarshalKindSupport.NotSupported; - } - switch (marshalKind) - { - case ByValueContentsMarshalKind.Default: - diagnostic = null; - return ByValueMarshalKindSupport.Supported; - case ByValueContentsMarshalKind.Out: - diagnostic = OutSupport switch + diagnostic = new GeneratorDiagnostic.NotSupported(info, context) { - ByValueMarshalKindSupport.Supported => null, - ByValueMarshalKindSupport.Unnecessary - => new GeneratorDiagnostic.UnnecessaryData( - info, - context, - ImmutableArray.Create(info.ByValueMarshalAttributeLocations.OutLocation)) - { - UnnecessaryDataName = SR.InOutAttributes, - UnnecessaryDataDetails = OutSupportDetails - }, - ByValueMarshalKindSupport.NotSupported - => new GeneratorDiagnostic.NotSupported( - info, - context) - { NotSupportedDetails = OutSupportDetails }, - _ => throw new UnreachableException($"Unexpected {nameof(ByValueMarshalKindSupport)} Variant: {InOutSupport}") + NotSupportedDetails = SR.InOutAttributeByRefNotSupported }; - return OutSupport; - case ByValueContentsMarshalKind.In: - diagnostic = InSupport switch - { - ByValueMarshalKindSupport.Supported => null, - ByValueMarshalKindSupport.Unnecessary - => new GeneratorDiagnostic.UnnecessaryData( - info, - context, - ImmutableArray.Create(info.ByValueMarshalAttributeLocations.InLocation)) - { - UnnecessaryDataName = SR.InOutAttributes, - UnnecessaryDataDetails = InSupportDetails - }, - ByValueMarshalKindSupport.NotSupported - => new GeneratorDiagnostic.NotSupported( - info, - context) - { NotSupportedDetails = InSupportDetails }, - _ => throw new UnreachableException($"Unexpected {nameof(ByValueMarshalKindSupport)} Variant: {InOutSupport}") - }; - return InSupport; - case ByValueContentsMarshalKind.InOut: - diagnostic = InOutSupport switch - { - ByValueMarshalKindSupport.Supported => null, - ByValueMarshalKindSupport.Unnecessary - => new GeneratorDiagnostic.UnnecessaryData( - info, - context, - ImmutableArray.Create( - info.ByValueMarshalAttributeLocations.InLocation, - info.ByValueMarshalAttributeLocations.OutLocation)) - { - UnnecessaryDataName = SR.InOutAttributes, - UnnecessaryDataDetails = InOutSupportDetails - }, - ByValueMarshalKindSupport.NotSupported - => new GeneratorDiagnostic.NotSupported( - info, - context) - { NotSupportedDetails = InOutSupportDetails }, - _ => throw new UnreachableException($"Unexpected {nameof(ByValueMarshalKindSupport)} Variant: {InOutSupport}") - }; - return InOutSupport; - default: - throw new UnreachableException($"Unexpected {nameof(ByValueContentsMarshalKind)} variant: {marshalKind}"); + return ByValueMarshalKindSupport.NotSupported; + } + // ByRef with no ByValue attributes is supported + diagnostic = null; + return ByValueMarshalKindSupport.Supported; + } + // Return can never have In or Out attributes, so can assume valid ByValue attributes + if (info.ManagedIndex < 0) + { + Debug.Assert(marshalKind is ByValueContentsMarshalKind.Default); + diagnostic = null; + return ByValueMarshalKindSupport.Supported; } + + return marshalKind switch + { + ByValueContentsMarshalKind.Default => DefaultSupport.GetSupport(info, context, out diagnostic), + ByValueContentsMarshalKind.In => InSupport.GetSupport(info, context, out diagnostic), + ByValueContentsMarshalKind.Out => OutSupport.GetSupport(info, context, out diagnostic), + ByValueContentsMarshalKind.InOut => InOutSupport.GetSupport(info, context, out diagnostic), + _ => throw new UnreachableException() + }; } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/GeneratorDiagnostic.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/GeneratorDiagnostic.cs index 22444e79824c9b..02820335ec176e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/GeneratorDiagnostic.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/GeneratorDiagnostic.cs @@ -62,5 +62,18 @@ public override DiagnosticInfo ToDiagnosticInfo(DiagnosticDescriptor descriptor, UnnecessaryDataDetails ?? ""); } } + + public sealed record NotRecommended(TypePositionInfo TypePositionInfo, StubCodeContext StubCodeContext) : GeneratorDiagnostic(TypePositionInfo, StubCodeContext, isFatal: false) + { + public string? Details { get; init; } + + public override DiagnosticInfo ToDiagnosticInfo(DiagnosticDescriptor descriptor, Location location, string elementName) + { + return DiagnosticInfo.Create( + descriptor, + location, + Details); + } + } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallingGenerator.cs index f5e8aacdaa3c15..598ba380028946 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallingGenerator.cs @@ -83,6 +83,10 @@ public enum ByValueMarshalKindSupport /// The provided is supported but does not change behavior from the default in this scenario. /// Unnecessary, + /// + /// The provided is supported but does not follow best practices. + /// + NotRecommended, } /// diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StaticPinnableManagedValueMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StaticPinnableManagedValueMarshaller.cs index 45ef61f85dbc58..039d78ed350068 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StaticPinnableManagedValueMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StaticPinnableManagedValueMarshaller.cs @@ -105,7 +105,7 @@ private IEnumerable GeneratePinningPath(TypePositionInfo info, public ByValueMarshalKindSupport SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, TypePositionInfo info, StubCodeContext context, out GeneratorDiagnostic? diagnostic) { - return ByValueMarshalKindSupportDescriptor.PinnedParameter.GetSupport(marshalKind, info, context, out diagnostic); + return _innerMarshallingGenerator.SupportsByValueMarshalKind(marshalKind, info, context, out diagnostic); } } } diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index 8fe3d526ab6e96..59dfcf2713620d 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -1104,7 +1104,7 @@ public static void PtrToStructure(System.IntPtr ptr, object structure) { } public static object? PtrToStructure(System.IntPtr ptr, [System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicConstructors| System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] System.Type structureType) { throw null; } public static T? PtrToStructure<[System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicConstructors | System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)]T>(System.IntPtr ptr) { throw null; } public static void PtrToStructure(System.IntPtr ptr, [System.Diagnostics.CodeAnalysis.DisallowNullAttribute] T structure) { } - public static int QueryInterface(System.IntPtr pUnk, in System.Guid iid, out System.IntPtr ppv) { throw null; } + public static int QueryInterface(System.IntPtr pUnk, ref readonly System.Guid iid, out System.IntPtr ppv) { throw null; } public static byte ReadByte(System.IntPtr ptr) { throw null; } public static byte ReadByte(System.IntPtr ptr, int ofs) { throw null; } [System.Diagnostics.CodeAnalysis.RequiresDynamicCode("Marshalling code for the object might not be available")] diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/IIUnknownStrategy.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/IIUnknownStrategy.cs index 771b88b04e7f4d..aa9742526f3631 100644 --- a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/IIUnknownStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/IIUnknownStrategy.cs @@ -32,7 +32,7 @@ public unsafe interface IIUnknownStrategy /// The IID (Interface ID) to query for. /// The resulting interface. /// Returns an HRESULT represents the success of the operation. - /// + /// public int QueryInterface(void* instancePtr, in Guid iid, out void* ppObj); /// diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs index 14ced03d3a7fd0..9741b482633c24 100644 --- a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs +++ b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs @@ -76,7 +76,7 @@ static IIUnknownInterfaceDetailsStrategy GetInteropStrategy() /// protected sealed override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) { - if (obj.GetType().GetCustomAttribute(typeof(ComExposedClassAttribute<>)) is IComExposedDetails details) + if (GetOrCreateInterfaceDetailsStrategy().GetComExposedTypeDetails(obj.GetType().TypeHandle) is { } details) { return details.GetComInterfaceEntries(out count); } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs new file mode 100644 index 00000000000000..fdd2d33d6e7d27 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ByValueContentsMarshalling.cs @@ -0,0 +1,301 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.Testing; +using Microsoft.Interop; +using Xunit; +using static Microsoft.Interop.UnitTests.TestUtils; +using StringMarshalling = System.Runtime.InteropServices.StringMarshalling; +using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; + +namespace ComInterfaceGenerator.Unit.Tests +{ + public class ByValueContentsMarshalling + { + private static IComInterfaceAttributeProvider GetAttributeProvider(GeneratorKind generator) + => generator switch + { + GeneratorKind.VTableIndexStubGenerator => new VirtualMethodIndexAttributeProvider(), + GeneratorKind.ComInterfaceGeneratorManagedObjectWrapper => new GeneratedComInterfaceAttributeProvider(System.Runtime.InteropServices.Marshalling.ComInterfaceOptions.ManagedObjectWrapper), + GeneratorKind.ComInterfaceGeneratorComObjectWrapper => new GeneratedComInterfaceAttributeProvider(System.Runtime.InteropServices.Marshalling.ComInterfaceOptions.ComObjectWrapper), + GeneratorKind.ComInterfaceGenerator => new GeneratedComInterfaceAttributeProvider(), + _ => throw new UnreachableException(), + }; + + public static IEnumerable ByValueMarshalAttributeOnValueTypes() + { + var codeSnippets = new CodeSnippets(GetAttributeProvider(GeneratorKind.ComInterfaceGenerator)); + const string In = "[{|#1:InAttribute|}]"; + const string Out = "[{|#2:OutAttribute|}]"; + const string paramName = "p"; + const string MarshalAsU4 = "[MarshalAs(UnmanagedType.U4)]"; + const string MarshalAsU2 = "[MarshalAs(UnmanagedType.U2)]"; + + string p = $$"""{|#0:{{paramName}}|}"""; + var diagnostic = new DiagnosticResult(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails); + var outAttributeNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.OutAttributeNotSupportedOnByValueParameters, paramName); + var inAttributeIsNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.InAttributeNotSupportedOnByValueParameters, paramName); + var inOutAttributeNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.InOutAttributeNotSupportedOnByValueParameters, paramName); + + DiagnosticResult[] InIsNotSupported = [inAttributeIsNotSupported]; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In, "int", p), InIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In, "byte", p), InIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + MarshalAsU4, "bool", p), InIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + MarshalAsU2, "char", p), InIsNotSupported }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In, "string", p, (StringMarshalling.Utf8, null)), InIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In, "IntClass", p) + CodeSnippets.IntClassAndMarshaller, InIsNotSupported }; + + DiagnosticResult[] OutIsNotSupported = [outAttributeNotSupported]; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out, "int", p), OutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out, "IntStruct", p) + CodeSnippets.IntStructAndMarshaller, OutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + MarshalAsU4, "bool", p), OutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + MarshalAsU2, "char", p), OutIsNotSupported }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out, "string", p, (StringMarshalling.Utf8, null)), OutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out, "IntClass", p) + CodeSnippets.IntClassAndMarshaller, OutIsNotSupported }; + + DiagnosticResult[] InOutIsNotSupported = [inOutAttributeNotSupported]; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out, "int", p), InOutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out, "IntStruct", p) + CodeSnippets.IntStructAndMarshaller, InOutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + MarshalAsU4, "bool", p), InOutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + MarshalAsU2, "char", p), InOutIsNotSupported }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out, "string", p, (StringMarshalling.Utf8, null)), InOutIsNotSupported }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out, "IntClass", p) + CodeSnippets.IntClassAndMarshaller, InOutIsNotSupported }; + + // Any ref keyword is okay for non-collection types + DiagnosticResult[] None = []; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("out", "IntStruct", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("out", "byte", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsU4 + "out", "bool", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsU2 + "out", "char", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("out", "string", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("out", "IntClass", p) + CodeSnippets.IntClassAndMarshaller, None }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("in", "IntStruct", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("in", "byte", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsU4 + "in", "bool", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsU2 + "in", "char", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("in", "string", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("in", "IntClass", p) + CodeSnippets.IntClassAndMarshaller, None }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("ref", "IntStruct", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("ref", "byte", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsU4 + "ref", "bool", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsU2 + "ref", "char", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("ref", "string", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType("ref", "IntClass", p) + CodeSnippets.IntClassAndMarshaller, None }; + } + + public static IEnumerable ByValueMarshalAttributeOnPinnedMarshalledTypes() + { + var codeSnippets = new CodeSnippets(GetAttributeProvider(GeneratorKind.ComInterfaceGenerator)); + const string In = "[{|#1:InAttribute|}]"; + const string Out = "[{|#2:OutAttribute|}]"; + const string paramName = "p"; + string p = $$"""{|#0:{{paramName}}|}"""; + const string Count = @"[MarshalUsing(ConstantElementCount = 10)]"; + const string MarshalAsBoolArray = "[MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.U1, SizeConst = 10)]"; + const string MarshalUsingIntMarshaller = "[MarshalUsing(typeof(IntMarshaller), ElementIndirectionDepth = 1)]"; + const string MarshalUsingIntStructMarshaller = "[MarshalUsing(typeof(IntStructMarshaller), ElementIndirectionDepth = 1)]"; + const string MarshalUsingIntClassMarshaller = "[MarshalUsing(typeof(IntClassMarshaller), ElementIndirectionDepth = 1)]"; + + // Any explicit [In] or [Out] on an array is preferred and should not warn + DiagnosticResult[] None = []; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Count, "int[]", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Count, "char[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + MarshalAsBoolArray, "bool[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + MarshalUsingIntMarshaller + Count, "int[]", p) + CodeSnippets.IntMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Count, "string[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Count, "string[]", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Count + MarshalUsingIntStructMarshaller, "IntStruct[]", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Count + MarshalUsingIntClassMarshaller, "IntClass[]", p) + CodeSnippets.IntClassAndMarshaller, None }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + Count, "int[]", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + Count, "char[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + MarshalAsBoolArray, "bool[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + MarshalUsingIntMarshaller + Count, "int[]", p) + CodeSnippets.IntMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + Count, "string[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + Count, "string[]", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + Count + MarshalUsingIntStructMarshaller, "IntStruct[]", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(In + Out + Count + MarshalUsingIntClassMarshaller, "IntClass[]", p) + CodeSnippets.IntClassAndMarshaller, None }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + Count, "int[]", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + Count, "char[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + MarshalAsBoolArray, "bool[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + MarshalUsingIntMarshaller + Count, "int[]", p) + CodeSnippets.IntMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + Count, "string[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + Count, "string[]", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + Count + MarshalUsingIntStructMarshaller, "IntStruct[]", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Out + Count + MarshalUsingIntClassMarshaller, "IntClass[]", p) + CodeSnippets.IntClassAndMarshaller, None }; + + // Array parameters without [In] or [Out] should provide an Info diagnostic + var preferExplicitAttributesDiagnostic = new DiagnosticResult(GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices) + .WithLocation(0) + .WithArguments(SR.PreferExplicitInOutAttributesOnArrays); + DiagnosticResult[] PreferInOutAttributes = [preferExplicitAttributesDiagnostic]; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "int[]", p), PreferInOutAttributes }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "char[]", p, (StringMarshalling.Utf16, null)), PreferInOutAttributes }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsBoolArray, "bool[]", p, (StringMarshalling.Utf16, null)), PreferInOutAttributes }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalUsingIntMarshaller + Count, "int[]", p) + CodeSnippets.IntMarshaller, PreferInOutAttributes }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "string[]", p, (StringMarshalling.Utf16, null)), PreferInOutAttributes }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "string[]", p, (StringMarshalling.Utf8, null)), PreferInOutAttributes }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntStructMarshaller, "IntStruct[]", p) + CodeSnippets.IntStructAndMarshaller, PreferInOutAttributes }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntClassMarshaller, "IntClass[]", p) + CodeSnippets.IntClassAndMarshaller, PreferInOutAttributes }; + + // Ref Kinds shouldn't warn + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "in int[]", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "in char[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsBoolArray, "in bool[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalUsingIntMarshaller + Count, "in int[]", p) + CodeSnippets.IntMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "in string[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "in string[]", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntStructMarshaller, "in IntStruct[]", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntClassMarshaller, "in IntClass[]", p) + CodeSnippets.IntClassAndMarshaller, None }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "out int[]", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "out char[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsBoolArray, "out bool[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalUsingIntMarshaller + Count, "out int[]", p) + CodeSnippets.IntMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "out string[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "out string[]", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntStructMarshaller, "out IntStruct[]", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntClassMarshaller, "out IntClass[]", p) + CodeSnippets.IntClassAndMarshaller, None }; + + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "ref int[]", p), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "ref char[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalAsBoolArray, "ref bool[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(MarshalUsingIntMarshaller + Count, "ref int[]", p) + CodeSnippets.IntMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "ref string[]", p, (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count, "ref string[]", p, (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntStructMarshaller, "ref IntStruct[]", p) + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(Count + MarshalUsingIntClassMarshaller, "ref IntClass[]", p) + CodeSnippets.IntClassAndMarshaller, None }; + } + + public static IEnumerable ByValueMarshalAttributeOnCustomCollections() + { + var codeSnippets = new CodeSnippets(GetAttributeProvider(GeneratorKind.ComInterfaceGenerator)); + const string In = "[{|#1:InAttribute|}]"; + const string Out = "[{|#2:OutAttribute|}]"; + const string paramName = "p"; + string p = $$"""{|#0:{{paramName}}|}"""; + const string CollectionMarshaller = "StatelessCollectionAllShapesMarshaller<,>"; + + var diagnostic = new DiagnosticResult(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails); + var outAttributeNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.OutAttributeNotSupportedOnByValueParameters, paramName); + var inAttributeIsNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.InAttributeNotSupportedOnByValueParameters, paramName); + var inOutAttributeNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.InOutAttributeNotSupportedOnByValueParameters, paramName); + + DiagnosticResult[] InIsNotSupported = [inAttributeIsNotSupported]; + DiagnosticResult[] OutIsNotSupported = [outAttributeNotSupported]; + DiagnosticResult[] InOutIsNotSupported = [inOutAttributeNotSupported]; + DiagnosticResult[] None = []; + + yield return new object[] { ID(), Source("", "int"), None }; + yield return new object[] { ID(), Source("", "byte"), None }; + yield return new object[] { ID(), Source(MarshalUsing("IntClassMarshaller", 1), "IntClass") + CodeSnippets.IntClassAndMarshaller, None }; + yield return new object[] { ID(), Source(MarshalUsing("IntStructMarshaller", 1), "IntStruct") + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), Source("", "string", (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), Source("", "string", (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), Source(MarshalCollection(1), CodeSnippets.GetCustomCollectionType("int")), None }; + + // [In] and [Out] are not allowed on custom collections + yield return new object[] { ID(), Source(In, "int"), InIsNotSupported }; + yield return new object[] { ID(), Source(In, "byte"), InIsNotSupported }; + yield return new object[] { ID(), Source(In + MarshalUsing("IntClassMarshaller", 1), "IntClass") + CodeSnippets.IntClassAndMarshaller, InIsNotSupported }; + yield return new object[] { ID(), Source(In + MarshalUsing("IntStructMarshaller", 1), "IntStruct") + CodeSnippets.IntStructAndMarshaller, InIsNotSupported }; + yield return new object[] { ID(), Source(In, "string", (StringMarshalling.Utf16, null)), InIsNotSupported }; + yield return new object[] { ID(), Source(In, "string", (StringMarshalling.Utf8, null)), InIsNotSupported }; + yield return new object[] { ID(), Source(In + MarshalCollection(1), CodeSnippets.GetCustomCollectionType("int")), InIsNotSupported }; + + yield return new object[] { ID(), Source(Out, "int"), OutIsNotSupported }; + yield return new object[] { ID(), Source(Out, "byte"), OutIsNotSupported }; + yield return new object[] { ID(), Source(Out + MarshalUsing("IntClassMarshaller", 1), "IntClass") + CodeSnippets.IntClassAndMarshaller, OutIsNotSupported }; + yield return new object[] { ID(), Source(Out + MarshalUsing("IntStructMarshaller", 1), "IntStruct") + CodeSnippets.IntStructAndMarshaller, OutIsNotSupported }; + yield return new object[] { ID(), Source(Out, "string", (StringMarshalling.Utf16, null)), OutIsNotSupported }; + yield return new object[] { ID(), Source(Out, "string", (StringMarshalling.Utf8, null)), OutIsNotSupported }; + yield return new object[] { ID(), Source(Out + MarshalCollection(1), CodeSnippets.GetCustomCollectionType("int")), OutIsNotSupported }; + + yield return new object[] { ID(), Source(In + Out, "int"), InOutIsNotSupported }; + yield return new object[] { ID(), Source(In + Out, "byte"), InOutIsNotSupported }; + yield return new object[] { ID(), Source(In + Out + MarshalUsing("IntClassMarshaller", 1), "IntClass") + CodeSnippets.IntClassAndMarshaller, InOutIsNotSupported }; + yield return new object[] { ID(), Source(In + Out + MarshalUsing("IntStructMarshaller", 1), "IntStruct") + CodeSnippets.IntStructAndMarshaller, InOutIsNotSupported }; + yield return new object[] { ID(), Source(In + Out, "string", (StringMarshalling.Utf16, null)), InOutIsNotSupported }; + yield return new object[] { ID(), Source(In + Out, "string", (StringMarshalling.Utf8, null)), InOutIsNotSupported }; + yield return new object[] { ID(), Source(In + Out + MarshalCollection(1), CodeSnippets.GetCustomCollectionType("int")), InOutIsNotSupported }; + + // RefKind modifiers are okay + yield return new object[] { ID(), SourceWithRefKind("in", "", "int"), None }; + yield return new object[] { ID(), SourceWithRefKind("in", "", "byte"), None }; + yield return new object[] { ID(), SourceWithRefKind("in", MarshalUsing("IntClassMarshaller", 1), "IntClass") + CodeSnippets.IntClassAndMarshaller, None }; + yield return new object[] { ID(), SourceWithRefKind("in", MarshalUsing("IntStructMarshaller", 1), "IntStruct") + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), SourceWithRefKind("in", "", "string", (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), SourceWithRefKind("in", "", "string", (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), SourceWithRefKind("in", MarshalCollection(1), CodeSnippets.GetCustomCollectionType("int")), None }; + + yield return new object[] { ID(), SourceWithRefKind("out", "", "int"), None }; + yield return new object[] { ID(), SourceWithRefKind("out", "", "byte"), None }; + yield return new object[] { ID(), SourceWithRefKind("out", MarshalUsing("IntClassMarshaller", 1), "IntClass") + CodeSnippets.IntClassAndMarshaller, None }; + yield return new object[] { ID(), SourceWithRefKind("out", MarshalUsing("IntStructMarshaller", 1), "IntStruct") + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), SourceWithRefKind("out", "", "string", (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), SourceWithRefKind("out", "", "string", (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), SourceWithRefKind("out", MarshalCollection(1), CodeSnippets.GetCustomCollectionType("int")), None }; + + yield return new object[] { ID(), SourceWithRefKind("ref", "", "int"), None }; + yield return new object[] { ID(), SourceWithRefKind("ref", "", "byte"), None }; + yield return new object[] { ID(), SourceWithRefKind("ref", MarshalUsing("IntClassMarshaller", 1), "IntClass") + CodeSnippets.IntClassAndMarshaller, None }; + yield return new object[] { ID(), SourceWithRefKind("ref", MarshalUsing("IntStructMarshaller", 1), "IntStruct") + CodeSnippets.IntStructAndMarshaller, None }; + yield return new object[] { ID(), SourceWithRefKind("ref", "", "string", (StringMarshalling.Utf16, null)), None }; + yield return new object[] { ID(), SourceWithRefKind("ref", "", "string", (StringMarshalling.Utf8, null)), None }; + yield return new object[] { ID(), SourceWithRefKind("ref", MarshalCollection(1), CodeSnippets.GetCustomCollectionType("int")), None }; + + string Source(string Attributes, string type, (StringMarshalling StringMarshalling, Type? StringMarshallingCustomType)? stringMarshalling = null) + => SourceWithRefKind("", Attributes, type, stringMarshalling); + + string SourceWithRefKind(string refKind, string Attributes, string type, (StringMarshalling StringMarshalling, Type? StringMarshallingCustomType)? stringMarshalling = null) + { + return codeSnippets.ByValueMarshallingOfType(Attributes + MarshalCollection(), CodeSnippets.GetCustomCollectionType(type), p, stringMarshalling) + CodeSnippets.CustomCollectionAndMarshaller; + } + static string MarshalUsing(string marshaller = CollectionMarshaller, int depth = 0) + => $"[MarshalUsing(typeof({marshaller}), ElementIndirectionDepth = {depth})]"; + static string MarshalCollection(int depth = 0) + => $"[MarshalUsing(typeof({CollectionMarshaller}), ElementIndirectionDepth = {depth}, ConstantElementCount = 10)]"; + } + + + + [Theory] + [MemberData(nameof(ByValueMarshalAttributeOnPinnedMarshalledTypes))] + [MemberData(nameof(ByValueMarshalAttributeOnValueTypes))] + [MemberData(nameof(ByValueMarshalAttributeOnCustomCollections))] + public async Task VerifyByValueMarshallingAttributeUsageInfoMessages(string id, string source, DiagnosticResult[] diagnostics) + { + _ = id; + VerifyComInterfaceGenerator.Test test = new(referenceAncillaryInterop: false) + { + TestCode = source, + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + }; + test.DisabledDiagnostics.Remove(GeneratorDiagnostics.Ids.NotRecommendedGeneratedComInterfaceUsage); + test.ExpectedDiagnostics.AddRange(diagnostics); + await test.RunAsync(); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs index b013d8c26c4599..700797fcb3e888 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -12,12 +13,26 @@ namespace ComInterfaceGenerator.Unit.Tests { internal partial class CodeSnippets { + internal static IComInterfaceAttributeProvider GetAttributeProvider(GeneratorKind generator) + => generator switch + { + GeneratorKind.VTableIndexStubGenerator => new VirtualMethodIndexAttributeProvider(), + GeneratorKind.ComInterfaceGeneratorManagedObjectWrapper => new GeneratedComInterfaceAttributeProvider(System.Runtime.InteropServices.Marshalling.ComInterfaceOptions.ManagedObjectWrapper), + GeneratorKind.ComInterfaceGeneratorComObjectWrapper => new GeneratedComInterfaceAttributeProvider(System.Runtime.InteropServices.Marshalling.ComInterfaceOptions.ComObjectWrapper), + GeneratorKind.ComInterfaceGenerator => new GeneratedComInterfaceAttributeProvider(), + _ => throw new UnreachableException(), + }; + private readonly IComInterfaceAttributeProvider _attributeProvider; public CodeSnippets(IComInterfaceAttributeProvider attributeProvider) { _attributeProvider = attributeProvider; } + public CodeSnippets(GeneratorKind generator) : this(GetAttributeProvider(generator)) + { + } + private string VirtualMethodIndex( int index, bool? ImplicitThisParameter = null, @@ -49,6 +64,51 @@ private string UnmanagedCallConv(Type[]? CallConvs = null) + arguments + "]"; } + public static string GetCustomCollectionType(string elementName) => $"StatelessCollectionAllShapes<{elementName}>"; + + public const string CustomCollectionAndMarshaller = CustomCollectionDefinition + CustomCollectionAllShapesMarshaller; + public const string CustomCollectionDefinition = """ + internal class StatelessCollectionAllShapes + { + public T _field; + } + """; + public const string CustomCollectionAllShapesMarshaller = """ + [ContiguousCollectionMarshaller] + [CustomMarshaller(typeof(StatelessCollectionAllShapes<>), MarshalMode.Default, typeof(StatelessCollectionAllShapesMarshaller<,>))] + internal unsafe static class StatelessCollectionAllShapesMarshaller where TUnmanagedElement : unmanaged + { + public static void Free(TUnmanagedElement* unmanaged) { } + + // ToUnmanaged + public static TUnmanagedElement* AllocateContainerForUnmanagedElements(StatelessCollectionAllShapes managed, out int numElements) + => throw null; + public static System.ReadOnlySpan GetManagedValuesSource(StatelessCollectionAllShapes managed) // Can throw exceptions + => throw null; + public static System.Span GetUnmanagedValuesDestination(TUnmanagedElement* unmanaged, int numElements) // Can throw exceptions + => throw null; + public static ref TUnmanagedElement* GetPinnableReference(StatelessCollectionAllShapes managed) + => throw null; + + // Caller Allocated buffer ToUnmanaged + public static int BufferSize { get; } = 10; + public static TUnmanagedElement* AllocateContainerForUnmanagedElements(StatelessCollectionAllShapes managed, System.Span buffer, out int numElements) + => throw null; + + // ToManaged + public static StatelessCollectionAllShapes AllocateContainerForManagedElements(TUnmanagedElement* unmanaged, int numElements) + => throw null; + public static System.Span GetManagedValuesDestination(StatelessCollectionAllShapes managed) + => throw null; + public static System.ReadOnlySpan GetUnmanagedValuesSource(TUnmanagedElement* unmanaged, int numElements) + => throw null; + + //ToManaged Guaranteed marshalling + public static StatelessCollectionAllShapes AllocateContainerForManagedElementsFinally(TUnmanagedElement* unmanaged, int numElements) + => throw null; + } + """; + public static readonly string DisableRuntimeMarshalling = "[assembly:System.Runtime.CompilerServices.DisableRuntimeMarshalling]"; public static readonly string UsingSystemRuntimeInteropServicesMarshalling = "using System.Runtime.InteropServices.Marshalling;"; public const string IntMarshaller = """ diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj index 08ca1a56fa43c8..beda4af3a72283 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj @@ -33,6 +33,7 @@ + diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs index 2e4d25da3837e5..331fa54117fb04 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs @@ -3,9 +3,7 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices.Marshalling; using System.Threading.Tasks; using Microsoft.CodeAnalysis; @@ -14,6 +12,7 @@ using Microsoft.Interop; using Microsoft.Interop.UnitTests; using Xunit; +using static Microsoft.Interop.UnitTests.TestUtils; using StringMarshalling = System.Runtime.InteropServices.StringMarshalling; using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; @@ -21,11 +20,6 @@ namespace ComInterfaceGenerator.Unit.Tests { public class CompileFails { - private static string ID( - [CallerLineNumber] int lineNumber = 0, - [CallerFilePath] string? filePath = null) - => TestUtils.GetFileLineName(lineNumber, filePath); - public static IEnumerable ComInterfaceGeneratorSnippetsToCompile() { CodeSnippets codeSnippets = new(new GeneratedComInterfaceAttributeProvider()); @@ -54,19 +48,9 @@ public async Task ValidateComInterfaceGeneratorSnippets(string id, string source await VerifyComInterfaceGenerator.VerifySourceGeneratorAsync(source, expectedDiagnostics); } - private static IComInterfaceAttributeProvider GetAttributeProvider(GeneratorKind generator) - => generator switch - { - GeneratorKind.VTableIndexStubGenerator => new VirtualMethodIndexAttributeProvider(), - GeneratorKind.ComInterfaceGeneratorManagedObjectWrapper => new GeneratedComInterfaceAttributeProvider(System.Runtime.InteropServices.Marshalling.ComInterfaceOptions.ManagedObjectWrapper), - GeneratorKind.ComInterfaceGeneratorComObjectWrapper => new GeneratedComInterfaceAttributeProvider(System.Runtime.InteropServices.Marshalling.ComInterfaceOptions.ComObjectWrapper), - GeneratorKind.ComInterfaceGenerator => new GeneratedComInterfaceAttributeProvider(), - _ => throw new UnreachableException(), - }; - public static IEnumerable InvalidUnmanagedToManagedCodeSnippetsToCompile(GeneratorKind generator) { - CodeSnippets codeSnippets = new(GetAttributeProvider(generator)); + CodeSnippets codeSnippets = new(generator); string safeHandleMarshallerDoesNotSupportManagedToUnmanaged = string.Format(SR.ManagedToUnmanagedMissingRequiredMarshaller, "global::System.Runtime.InteropServices.Marshalling.SafeHandleMarshaller"); string safeHandleMarshallerDoesNotSupportUnmanagedToManaged = string.Format(SR.UnmanagedToManagedMissingRequiredMarshaller, "global::System.Runtime.InteropServices.Marshalling.SafeHandleMarshaller"); @@ -98,7 +82,7 @@ public static IEnumerable InvalidUnmanagedToManagedCodeSnippetsToCompi DiagnosticResult invalidReturnTypeDiagnostic = VerifyComInterfaceGenerator.Diagnostic(GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails) .WithLocation(0) .WithArguments(marshallerDoesNotSupportManagedToUnmanaged, "Method"); - CustomStructMarshallingCodeSnippets customStructMarshallingCodeSnippets = new(new CodeSnippets.Bidirectional(GetAttributeProvider(generator))); + CustomStructMarshallingCodeSnippets customStructMarshallingCodeSnippets = new(new CodeSnippets.Bidirectional(CodeSnippets.GetAttributeProvider(generator))); yield return new object[] { ID(), customStructMarshallingCodeSnippets.Stateless.NativeToManagedOnlyOutParameter, new[] { invalidManagedToUnmanagedParameterDiagnostic } }; yield return new object[] { ID(), customStructMarshallingCodeSnippets.Stateless.NativeToManagedOnlyReturnValue, new[] { invalidReturnTypeDiagnostic } }; yield return new object[] { ID(), customStructMarshallingCodeSnippets.Stateless.ByValueInParameter, new[] { invalidUnmanagedToManagedParameterDiagnostic } }; @@ -113,7 +97,7 @@ public static IEnumerable StringMarshallingCodeSnippets(GeneratorKind string CustomTypeSpecifiedWithNoStringMarshallingCustom = SR.InvalidStringMarshallingConfigurationNotCustom; string StringMarshallingMustMatchBase = SR.GeneratedComInterfaceStringMarshallingMustMatchBase; - CodeSnippets codeSnippets = new(GetAttributeProvider(generator)); + CodeSnippets codeSnippets = new(generator); (StringMarshalling, Type?) utf8Marshalling = (StringMarshalling.Utf8, null); (StringMarshalling, Type?) utf16Marshalling = (StringMarshalling.Utf16, null); (StringMarshalling, Type?) customUtf16Marshalling = (StringMarshalling.Custom, typeof(Utf16StringMarshaller)); @@ -342,7 +326,7 @@ public static IEnumerable InvalidManagedToUnmanagedCodeSnippetsToCompi { // Marshallers with only support for their expected places in the signatures in // UnmanagedToManaged marshal modes. - CustomStructMarshallingCodeSnippets customStructMarshallingCodeSnippets = new(new CodeSnippets.Bidirectional(GetAttributeProvider(generator))); + CustomStructMarshallingCodeSnippets customStructMarshallingCodeSnippets = new(new CodeSnippets.Bidirectional(CodeSnippets.GetAttributeProvider(generator))); yield return new[] { ID(), customStructMarshallingCodeSnippets.Stateless.NativeToManagedOnlyInParameter }; yield return new[] { ID(), customStructMarshallingCodeSnippets.Stateless.ByValueOutParameter }; @@ -543,272 +527,6 @@ public async Task VerifyInterfaceWithLessVisibilityThanInterfaceWarns(string id, await VerifyComInterfaceGenerator.VerifySourceGeneratorAsync(source, diagnostics); } - public static IEnumerable ByValueMarshalAttributeOnValueTypes() - { - var codeSnippets = new CodeSnippets(GetAttributeProvider(GeneratorKind.ComInterfaceGenerator)); - const string inAttribute = "[{|#1:InAttribute|}]"; - const string outAttribute = "[{|#2:OutAttribute|}]"; - const string paramName = "p"; - string paramNameWithLocation = $$"""{|#0:{{paramName}}|}"""; - var inAttributeIsDefaultDiagnostic = new DiagnosticResult(GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo) - .WithLocation(0) - .WithLocation(1) - .WithArguments(SR.InOutAttributes, paramName, SR.InAttributeOnlyIsDefault); - - - // [In] is default for all non-pinned marshalled types - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute, "int", paramNameWithLocation), new DiagnosticResult[] { - inAttributeIsDefaultDiagnostic } }; - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute, "byte", paramNameWithLocation), new DiagnosticResult[] { - inAttributeIsDefaultDiagnostic } }; - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute + "[MarshalAs(UnmanagedType.U4)]", "bool", paramNameWithLocation), new DiagnosticResult[] { - inAttributeIsDefaultDiagnostic } }; - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute + "[MarshalAs(UnmanagedType.U2)]", "char", paramNameWithLocation), new DiagnosticResult[] { - inAttributeIsDefaultDiagnostic } }; - - // [Out] is not allowed on value types passed by value - there is no indirection for the callee to make visible modifications. - var outAttributeNotSupportedOnValueParameters = new DiagnosticResult(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments(SR.OutAttributeNotSupportedOnByValueParameters, paramName); - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(outAttribute, "int", paramNameWithLocation), new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters } }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(outAttribute, "IntStruct", paramNameWithLocation) + CodeSnippets.IntStructAndMarshaller, - new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters - } }; - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(outAttribute + "[MarshalAs(UnmanagedType.U4)]", "bool", paramNameWithLocation), new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters - } }; - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(outAttribute, "[MarshalAs(UnmanagedType.U2)] char", paramNameWithLocation), new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters - } }; - // [In,Out] should only warn for Out attribute - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute+outAttribute, "int", paramNameWithLocation), new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters } }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute+outAttribute, "IntStruct", paramNameWithLocation) + CodeSnippets.IntStructAndMarshaller, - new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters - } }; - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute + outAttribute + "[MarshalAs(UnmanagedType.U4)]", "bool", paramNameWithLocation), new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters - } }; - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute + outAttribute, "[MarshalAs(UnmanagedType.U2)] char", paramNameWithLocation), new DiagnosticResult[] { - outAttributeNotSupportedOnValueParameters - } }; - - // Any ref keyword is okay for value types - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("out", "IntStruct", paramNameWithLocation) + CodeSnippets.IntStructAndMarshaller, - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("out", "byte", paramNameWithLocation), - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("", "[MarshalAs(UnmanagedType.U2)] out char", paramNameWithLocation), - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("in", "IntStruct", paramNameWithLocation) + CodeSnippets.IntStructAndMarshaller, - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("in", "byte", paramNameWithLocation), - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("", "[MarshalAs(UnmanagedType.U2)] in char", paramNameWithLocation), - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("ref", "IntStruct", paramNameWithLocation) + CodeSnippets.IntStructAndMarshaller, - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("ref", "byte", paramNameWithLocation), - new DiagnosticResult[] {} - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("", "[MarshalAs(UnmanagedType.U2)] ref char", paramNameWithLocation), - new DiagnosticResult[] {} - }; - } - - public static IEnumerable ByValueMarshalAttributeOnReferenceTypes() - { - var codeSnippets = new CodeSnippets(GetAttributeProvider(GeneratorKind.ComInterfaceGenerator)); - const string inAttribute = "[{|#1:InAttribute|}]"; - const string outAttribute = "[{|#2:OutAttribute|}]"; - const string paramName = "p"; - string paramNameWithLocation = $$"""{|#0:{{paramName}}|}"""; - var inAttributeIsDefaultDiagnostic = new DiagnosticResult(GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo) - .WithLocation(0) - .WithLocation(1) - .WithArguments(SR.InOutAttributes, paramName, SR.InAttributeOnlyIsDefault); - - // [In] is default - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute, "string", paramNameWithLocation, (StringMarshalling.Utf8, null)), - new DiagnosticResult[] { inAttributeIsDefaultDiagnostic } - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute, "IntClass", paramNameWithLocation) + CodeSnippets.IntClassAndMarshaller, - new DiagnosticResult[] { inAttributeIsDefaultDiagnostic } - }; - - var outNotAllowedOnRefTypes = new DiagnosticResult(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments(SR.OutAttributeNotSupportedOnByValueParameters, paramName); - - // [Out] is not allowed on strings - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(outAttribute, "string", paramNameWithLocation, (StringMarshalling.Utf8, null)), - new DiagnosticResult[] { outNotAllowedOnRefTypes } - }; - - // [Out] warns on by value reference types - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(outAttribute, "IntClass", paramNameWithLocation) + CodeSnippets.IntClassAndMarshaller, - new DiagnosticResult[] { outNotAllowedOnRefTypes } - }; - - // [In,Out] is fine on classes - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute + outAttribute, "IntClass", paramNameWithLocation) + CodeSnippets.IntClassAndMarshaller, - new DiagnosticResult[] { outNotAllowedOnRefTypes } - }; - - // All refkinds are okay on classes and strings - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("in", "string", paramNameWithLocation, (StringMarshalling.Utf8, null)), - new DiagnosticResult[] { } - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("in", "IntClass", paramNameWithLocation) + CodeSnippets.IntClassAndMarshaller, - new DiagnosticResult[] { } - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("out", "string", paramNameWithLocation, (StringMarshalling.Utf8, null)), - new DiagnosticResult[] { } - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("out", "IntClass", paramNameWithLocation) + CodeSnippets.IntClassAndMarshaller, - new DiagnosticResult[] { } - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("ref", "string", paramNameWithLocation, (StringMarshalling.Utf8, null)), - new DiagnosticResult[] { } - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType("ref", "IntClass", paramNameWithLocation) + CodeSnippets.IntClassAndMarshaller, - new DiagnosticResult[] { } - }; - } - - public static IEnumerable ByValueMarshalAttributeOnPinnedMarshalledTypes() - { - var codeSnippets = new CodeSnippets(GetAttributeProvider(GeneratorKind.ComInterfaceGenerator)); - const string inAttribute = "[{|#1:InAttribute|}]"; - const string outAttribute = "[{|#2:OutAttribute|}]"; - const string paramName = "p"; - string paramNameWithLocation = $$"""{|#0:{{paramName}}|}"""; - const string constElementCount = @"[MarshalUsing(ConstantElementCount = 10)]"; - var inAttributeIsDefaultDiagnostic = new DiagnosticResult(GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo) - .WithLocation(0) - .WithLocation(1) - .WithArguments(SR.InOutAttributes, paramName, SR.InAttributeOnlyIsDefault); - - yield return new object[] { ID(), codeSnippets.ByValueMarshallingOfType(inAttribute + constElementCount, "int[]", paramNameWithLocation), new DiagnosticResult[] { - inAttributeIsDefaultDiagnostic - }}; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute + constElementCount, "char[]", paramNameWithLocation, (StringMarshalling.Utf16, null)), - new DiagnosticResult[] { inAttributeIsDefaultDiagnostic } - }; - - // bools that are marshalled into a new array are in by default - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType( - inAttribute + "[MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.U1, SizeConst = 10)]", - "bool[]", - paramNameWithLocation, - (StringMarshalling.Utf16, null)), - new DiagnosticResult[] { inAttributeIsDefaultDiagnostic } - }; - // Overriding marshalling with a custom marshaller makes it not pinned - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute, "[MarshalUsing(typeof(IntMarshaller), ElementIndirectionDepth = 1), MarshalUsing(ConstantElementCount = 10)]int[]", paramNameWithLocation) + CodeSnippets.IntMarshaller, - new DiagnosticResult[] { inAttributeIsDefaultDiagnostic } - }; - - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute + outAttribute + constElementCount, "int[]", paramNameWithLocation), - new DiagnosticResult[] { } - }; - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(inAttribute + outAttribute + constElementCount, "char[]", paramNameWithLocation, (StringMarshalling.Utf16, null)), - new DiagnosticResult[] { } - }; - - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(outAttribute + constElementCount, "int[]", paramNameWithLocation), - new DiagnosticResult[] { } - }; - - yield return new object[] { - ID(), - codeSnippets.ByValueMarshallingOfType(outAttribute + constElementCount, "char[]", paramNameWithLocation, (StringMarshalling.Utf16, null)), - new DiagnosticResult[] { } - }; - } - - [Theory] - [MemberData(nameof(ByValueMarshalAttributeOnValueTypes))] - [MemberData(nameof(ByValueMarshalAttributeOnReferenceTypes))] - [MemberData(nameof(ByValueMarshalAttributeOnPinnedMarshalledTypes))] - public async Task VerifyByValueMarshallingAttributeUsage(string id, string source, DiagnosticResult[] diagnostics) - { - _ = id; - VerifyComInterfaceGenerator.Test test = new(referenceAncillaryInterop: false) - { - TestCode = source, - TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, - }; - test.ExpectedDiagnostics.AddRange(diagnostics); - await test.RunAsync(); - } - [Fact] public async Task VerifyNonPartialInterfaceWarns() { @@ -967,7 +685,7 @@ partial interface {|#0:J|} public static IEnumerable CountParameterIsOutSnippets() { - var g = GetAttributeProvider(GeneratorKind.ComInterfaceGenerator); + var g = CodeSnippets.GetAttributeProvider(GeneratorKind.ComInterfaceGenerator); CodeSnippets a = new(g); DiagnosticResult returnValueDiag = new DiagnosticResult(GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue) .WithLocation(1) diff --git a/src/libraries/System.Runtime.InteropServices/tests/Common/TestUtils.cs b/src/libraries/System.Runtime.InteropServices/tests/Common/TestUtils.cs index 4d1ec3a5d2c7dd..cd065f54a48c0b 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Common/TestUtils.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Common/TestUtils.cs @@ -1,10 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.Diagnostics; -using Microsoft.CodeAnalysis.Testing; using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -15,6 +11,10 @@ using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Testing; using Xunit; namespace Microsoft.Interop.UnitTests @@ -51,6 +51,10 @@ public enum TestTargetFramework public static class TestUtils { + public static string ID( + [CallerLineNumber] int lineNumber = 0, + [CallerFilePath] string? filePath = null) + => TestUtils.GetFileLineName(lineNumber, filePath); internal static string GetFileLineName( [CallerLineNumber] int lineNumber = 0, [CallerFilePath] string? filePath = null) @@ -298,7 +302,7 @@ public void Stop() int count = Interlocked.Decrement(ref _count); if (count == 0) { - Environment.SetEnvironmentVariable(EnvVarName, null); + Environment.SetEnvironmentVariable(EnvVarName, null); } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/ByValueContentsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/ByValueContentsMarshalling.cs new file mode 100644 index 00000000000000..e632481d52109f --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/ByValueContentsMarshalling.cs @@ -0,0 +1,135 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.Testing; +using Microsoft.Interop; +using Xunit; +using static Microsoft.Interop.UnitTests.TestUtils; +using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; + +namespace LibraryImportGenerator.UnitTests +{ + public class ByValueContentsMarshalling + { + public static IEnumerable ByValueMarshalAttributeOnValueTypes() + { + CodeSnippets.ByValueParameterWithModifier("int[]", "In"); + + const string In = "InAttribute"; + const string Out = "OutAttribute"; + const string InOut = "InAttribute, OutAttribute"; + const string paramName = "p"; + const string MarshalUsingUtf16 = "MarshalUsing(typeof(Utf16StringMarshaller))"; + + var diagnostic = new DiagnosticResult(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails); + var outAttributeNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.OutAttributeNotSupportedOnByValueParameters, paramName); + var inAttributeIsNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.InAttributeNotSupportedOnByValueParameters, paramName); + var inOutAttributeNotSupported = diagnostic + .WithLocation(0) + .WithArguments(SR.InOutAttributeNotSupportedOnByValueParameters, paramName); + + DiagnosticResult[] InIsNotSupported = [inAttributeIsNotSupported]; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("int", In), InIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("byte", In), InIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("string", In + " , " + MarshalUsingUtf16), InIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntStruct", In, CodeSnippets.IntStructAndMarshaller), InIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntClass", In, CodeSnippets.IntClassAndMarshaller), InIsNotSupported }; + + DiagnosticResult[] OutIsNotSupported = [outAttributeNotSupported]; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("int", Out), OutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("byte", Out), OutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("string", Out + " , " + MarshalUsingUtf16), OutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntStruct", Out, CodeSnippets.IntStructAndMarshaller), OutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntClass", Out, CodeSnippets.IntClassAndMarshaller), OutIsNotSupported }; + + DiagnosticResult[] InOutIsNotSupported = [inOutAttributeNotSupported]; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("int", InOut), InOutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("byte", InOut), InOutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("string", InOut + " , " + MarshalUsingUtf16), InOutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntStruct", InOut, CodeSnippets.IntStructAndMarshaller), InOutIsNotSupported }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntClass", InOut, CodeSnippets.IntClassAndMarshaller), InOutIsNotSupported }; + + var inAndOutNotAllowedWithRefKind = new DiagnosticResult(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) + .WithLocation(0) + .WithArguments("The '[In]' and '[Out]' attributes are unsupported on parameters passed by reference. Use the 'in', 'ref', or 'out' keywords instead.", "p"); + DiagnosticResult[] InAndOutNotAllowedWithRefKind = [inAndOutNotAllowedWithRefKind]; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("in int", In), InAndOutNotAllowedWithRefKind }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("ref int", In), InAndOutNotAllowedWithRefKind }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("ref int", InOut), InAndOutNotAllowedWithRefKind }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("out int", Out), InAndOutNotAllowedWithRefKind }; + } + + public static IEnumerable ByValueMarshalAttributeOnArrays() + { + CodeSnippets.ByValueParameterWithModifier("int[]", "In"); + + const string In = "InAttribute"; + const string Out = "OutAttribute"; + const string InOut = "InAttribute, OutAttribute"; + const string MarshalUsingUtf16 = "MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)"; + const string Count = "MarshalUsing(ConstantElementCount = 10)"; + DiagnosticResult[] None = []; + + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("int[]", string.Join(',', Count, In)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("byte[]", string.Join(',', Count, In)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("string[]", string.Join(',', Count, In, MarshalUsingUtf16)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntStruct[]", string.Join(',', Count, In), CodeSnippets.IntStructAndMarshaller), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntClass[]", string.Join(',', Count, In), CodeSnippets.IntClassAndMarshaller), None }; + + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("int[]", string.Join(',', Count, Out)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("byte[]", string.Join(',', Count, Out)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("string[]", string.Join(',', Count, Out, MarshalUsingUtf16)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntStruct[]", string.Join(',', Count, Out), CodeSnippets.IntStructAndMarshaller), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntClass[]", string.Join(',', Count, Out), CodeSnippets.IntClassAndMarshaller), None }; + + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("int[]", string.Join(',', Count, InOut)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("byte[]", string.Join(',', Count, InOut)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("string[]", string.Join(',', Count, InOut, MarshalUsingUtf16)), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntStruct[]", string.Join(',', Count, InOut), CodeSnippets.IntStructAndMarshaller), None }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntClass[]", string.Join(',', Count, InOut), CodeSnippets.IntClassAndMarshaller), None }; + + DiagnosticResult preferAttributes = new DiagnosticResult(GeneratorDiagnostics.LibraryImportUsageDoesNotFollowBestPractices) + .WithArguments(SR.PreferExplicitInOutAttributesOnArrays) + .WithLocation(0); + DiagnosticResult[] PreferAttributes = [preferAttributes]; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("int[]", Count), PreferAttributes }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("byte[]", Count), PreferAttributes }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("string[]", string.Join(',', Count, MarshalUsingUtf16)), PreferAttributes }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntStruct[]", Count, CodeSnippets.IntStructAndMarshaller), PreferAttributes }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("IntClass[]", Count, CodeSnippets.IntClassAndMarshaller), PreferAttributes }; + + var inAndOutNotAllowedWithRefKind = new DiagnosticResult(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) + .WithLocation(0) + .WithArguments("The '[In]' and '[Out]' attributes are unsupported on parameters passed by reference. Use the 'in', 'ref', or 'out' keywords instead.", "p"); + DiagnosticResult[] InAndOutNotAllowedWithRefKind = [inAndOutNotAllowedWithRefKind]; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("in int[]", string.Join(',', Count, In)), InAndOutNotAllowedWithRefKind }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("ref int[]", string.Join(',', Count, In)), InAndOutNotAllowedWithRefKind }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("ref int[]", string.Join(',', Count, In, Out)), InAndOutNotAllowedWithRefKind }; + yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("out int[]", string.Join(',', Count, Out)), InAndOutNotAllowedWithRefKind }; + } + + + [Theory] + [MemberData(nameof(ByValueMarshalAttributeOnValueTypes))] + [MemberData(nameof(ByValueMarshalAttributeOnArrays))] + public async Task VerifyByValueMarshallingAttributeUsageInfoMessages(string id, string source, DiagnosticResult[] diagnostics) + { + _ = id; + VerifyCS.Test test = new(referenceAncillaryInterop: false) + { + TestCode = source, + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + }; + test.DisabledDiagnostics.Remove(GeneratorDiagnostics.Ids.NotRecommendedGeneratedComInterfaceUsage); + test.ExpectedDiagnostics.AddRange(diagnostics); + await test.RunAsync(); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs index e1039f70a90ef9..19832861a38aaa 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs @@ -986,6 +986,22 @@ out int pOutSize } """; + public const string IntClassAndMarshaller = IntClassDefinition + IntClassMarshallerDefinition; + public const string IntClassDefinition = """ + internal struct IntClass + { + public int Field; + } + """; + public const string IntClassMarshallerDefinition = """ + [CustomMarshaller(typeof(IntClass), MarshalMode.Default, typeof(IntClassMarshaller))] + internal static class IntClassMarshaller + { + public static nint ConvertToUnmanaged(IntClass managed) => (nint)0; + public static IntClass ConvertToManaged(nint unmanaged) => default; + } + """; + public const string IntStructAndMarshaller = IntStructDefinition + IntStructMarshallerDefinition; public const string IntStructDefinition = """ internal struct IntStruct diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs index ffb3fb5c6a14ba..2d278dc6237de2 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CompileFails.cs @@ -350,56 +350,6 @@ public static IEnumerable CodeSnippetsToCompile() .WithArguments("MarshalAsAttribute", "t") }}; - // Unsupported [In, Out] attributes usage - - // By ref with [In, Out] attributes - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("in int", "In"), new[] - { - VerifyCS.Diagnostic(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments("The '[In]' and '[Out]' attributes are unsupported on parameters passed by reference. Use the 'in', 'ref', or 'out' keywords instead.", "p") - } }; - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("ref int", "In"), new[] - { - VerifyCS.Diagnostic(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments("The '[In]' and '[Out]' attributes are unsupported on parameters passed by reference. Use the 'in', 'ref', or 'out' keywords instead.", "p") - } }; - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("ref int", "In, Out"), new[] - { - VerifyCS.Diagnostic(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments("The '[In]' and '[Out]' attributes are unsupported on parameters passed by reference. Use the 'in', 'ref', or 'out' keywords instead.", "p") - } }; - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("out int", "Out"), new[] - { - VerifyCS.Diagnostic(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments("The '[In]' and '[Out]' attributes are unsupported on parameters passed by reference. Use the 'in', 'ref', or 'out' keywords instead.", "p") - } }; - - // By value non-array with [In, Out] attributes - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("{|#1:In|}"), new [] { - VerifyCS.Diagnostic(GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo) - .WithLocation(0) - .WithLocation(1) - .WithArguments(SR.InOutAttributes, "p", SR.InAttributeOnlyIsDefault) - .WithSeverity(DiagnosticSeverity.Info) - } }; - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("Out"), new[] - { - VerifyCS.Diagnostic(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments(SR.OutAttributeNotSupportedOnByValueParameters, "p") - } }; - - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("In, Out"), new[] - { - VerifyCS.Diagnostic(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails) - .WithLocation(0) - .WithArguments(SR.OutAttributeNotSupportedOnByValueParameters, "p") - } }; - // LCIDConversion yield return new object[] { ID(), CodeSnippets.LCIDConversionAttribute, new[] { VerifyCS.Diagnostic(GeneratorDiagnostics.ConfigurationNotSupported) @@ -832,13 +782,6 @@ public static IEnumerable CodeSnippetsToCompile() .WithLocation(1) .WithArguments("ref return", "Basic.RefReadonlyReturn()"), } }; - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("{|#10:In|}"), new[] - { - VerifyCS.Diagnostic(GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo) - .WithLocation(0) - .WithLocation(10) - .WithArguments("[In] and [Out] attributes", "p", SR.InAttributeOnlyIsDefault) - } }; } [Theory] diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs index 438d589b2c29d0..5d37fffd52a5b6 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs @@ -17,7 +17,6 @@ using Microsoft.CodeAnalysis.Text; using Microsoft.Interop.UnitTests; using Xunit; -using GeneratorDiagnostics = Microsoft.Interop.GeneratorDiagnostics; using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier; namespace LibraryImportGenerator.UnitTests @@ -744,28 +743,5 @@ public NoChangeTest(TestTargetFramework framework) return (newCompilation, diagnostics); } } - - public static IEnumerable ByValueMarshalKindSnippets() - { - // Blittable array - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("{|#10:Out|}"), new DiagnosticResult[] { } }; - - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("{|#10:In|}, {|#11:Out|}"), new DiagnosticResult[] { } }; - - yield return new object[] { ID(), CodeSnippets.ByValueParameterWithModifier("{|#10:In|}"), new DiagnosticResult[] { - VerifyCS.Diagnostic(GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo) - .WithLocation(0) - .WithLocation(10) - .WithArguments("[In] and [Out] attributes", "p", SR.InAttributeOnlyIsDefault) - } }; - } - - [MemberData(nameof(ByValueMarshalKindSnippets))] - [Theory] - public async Task ValidateDiagnosticsForUnnecessaryByValueMarshalKindAttributes(string id, string source, DiagnosticResult[] diagnostics) - { - _ = id; - await VerifyCS.VerifySourceGeneratorAsync(source, diagnostics); - } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/SharedTypes.csproj b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/SharedTypes.csproj index c68eadac7ca4b4..cb767939315c84 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/SharedTypes.csproj +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/SharedTypes.csproj @@ -11,6 +11,8 @@ + + diff --git a/src/libraries/System.Runtime.Loader/tests/CollectibleAssemblyLoadContextTest.cs b/src/libraries/System.Runtime.Loader/tests/CollectibleAssemblyLoadContextTest.cs index 90527c3d0be9de..83d1c9aebf61f7 100644 --- a/src/libraries/System.Runtime.Loader/tests/CollectibleAssemblyLoadContextTest.cs +++ b/src/libraries/System.Runtime.Loader/tests/CollectibleAssemblyLoadContextTest.cs @@ -16,12 +16,16 @@ public partial class AssemblyLoadContextTest // Tests related to Collectible assemblies [MethodImpl(MethodImplOptions.NoInlining)] - static void CreateAndLoadContext(CollectibleChecker checker) + static void CreateAndLoadContext(CollectibleChecker checker, bool unloadTwice = false) { var alc = new ResourceAssemblyLoadContext(true); checker.SetAssemblyLoadContext(0, alc); alc.Unload(); + if (unloadTwice) + { + alc.Unload(); + } // Check that any attempt to load an assembly after an explicit Unload will fail Assert.Throws(() => alc.LoadFromAssemblyPath(Path.GetFullPath("none.dll"))); @@ -39,6 +43,19 @@ public static void Unload_CollectibleWithNoAssemblyLoaded() checker.GcAndCheck(); } + [Fact] + [ActiveIssue("https://github.com/mono/mono/issues/15142", TestRuntimes.Mono)] + public static void DoubleUnload_CollectibleWithNoAssemblyLoaded() + { + // Use a collectible ALC + Unload + // Check that we receive the Unloading event + + var checker = new CollectibleChecker(1); + CreateAndLoadContext(checker, unloadTwice: true); + checker.GcAndCheck(); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPreciseGcSupported))] public static void Finalizer_CollectibleWithNoAssemblyLoaded() { diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index 900040db349efe..bc7e04a54f96b3 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -2722,7 +2722,6 @@ public static void SuppressFinalize(object obj) { } public static void WaitForPendingFinalizers() { } public static TimeSpan GetTotalPauseDuration() { throw null; } public static System.Collections.Generic.IReadOnlyDictionary GetConfigurationVariables() { throw null; } - [System.Runtime.Versioning.RequiresPreviewFeaturesAttribute("RefreshMemoryLimit is in preview.")] public static void RefreshMemoryLimit() { throw null; } } @@ -10802,7 +10801,7 @@ public partial interface IMultiplyOperators where TSelf static virtual TResult operator checked *(TSelf left, TOther right) { throw null; } static abstract TResult operator *(TSelf left, TOther right); } - public partial interface INumberBase : System.IEquatable, System.IFormattable, System.IParsable, System.ISpanFormattable, System.ISpanParsable, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity, System.Numerics.IDecrementOperators, System.Numerics.IDivisionOperators, System.Numerics.IEqualityOperators, System.Numerics.IIncrementOperators, System.Numerics.IMultiplicativeIdentity, System.Numerics.IMultiplyOperators, System.Numerics.ISubtractionOperators, System.Numerics.IUnaryNegationOperators, System.Numerics.IUnaryPlusOperators, /* System.IUtf8SpanFormattable, */ System.IUtf8SpanParsable where TSelf : System.Numerics.INumberBase? + public partial interface INumberBase : System.IEquatable, System.IFormattable, System.IParsable, System.ISpanFormattable, System.ISpanParsable, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity, System.Numerics.IDecrementOperators, System.Numerics.IDivisionOperators, System.Numerics.IEqualityOperators, System.Numerics.IIncrementOperators, System.Numerics.IMultiplicativeIdentity, System.Numerics.IMultiplyOperators, System.Numerics.ISubtractionOperators, System.Numerics.IUnaryNegationOperators, System.Numerics.IUnaryPlusOperators, System.IUtf8SpanFormattable, System.IUtf8SpanParsable where TSelf : System.Numerics.INumberBase? { static abstract TSelf One { get; } static abstract int Radix { get; } @@ -10844,9 +10843,7 @@ static virtual TSelf CreateTruncating(TOther value) static virtual TSelf Parse(System.ReadOnlySpan utf8Text, System.Globalization.NumberStyles style, System.IFormatProvider? provider) { throw null; } static abstract TSelf Parse(System.ReadOnlySpan s, System.Globalization.NumberStyles style, System.IFormatProvider? provider); static abstract TSelf Parse(string s, System.Globalization.NumberStyles style, System.IFormatProvider? provider); - // Workaround devdiv/#1851707: C++/CLI fails to compile when encountering a Default Interface Method implemented in a derived interface - // bool System.IUtf8SpanFormattable.TryFormat(System.Span utf8Destination, out int bytesWritten, System.ReadOnlySpan format, System.IFormatProvider? provider) { throw null; } - bool TryFormat(System.Span utf8Destination, out int bytesWritten, System.ReadOnlySpan format, System.IFormatProvider? provider) { throw null; } + bool System.IUtf8SpanFormattable.TryFormat(System.Span utf8Destination, out int bytesWritten, System.ReadOnlySpan format, System.IFormatProvider? provider) { throw null; } static TSelf System.IUtf8SpanParsable.Parse(System.ReadOnlySpan utf8Text, System.IFormatProvider? provider) { throw null; } static bool System.IUtf8SpanParsable.TryParse(System.ReadOnlySpan utf8Text, System.IFormatProvider? provider, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TSelf result) { throw null; } protected static abstract bool TryConvertFromChecked(TOther value, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TSelf result) diff --git a/src/libraries/System.Runtime/tests/System/Type/TypePropertyTests.cs b/src/libraries/System.Runtime/tests/System/Type/TypePropertyTests.cs index 04dcbafe1f1664..066711e48297e4 100644 --- a/src/libraries/System.Runtime/tests/System/Type/TypePropertyTests.cs +++ b/src/libraries/System.Runtime/tests/System/Type/TypePropertyTests.cs @@ -406,7 +406,7 @@ public abstract class StructTypeTestBase : TypePropertyTestBase public abstract class InterfaceTypeTestBase : TypePropertyTestBase { - public override TypeAttributes Attributes => TypeAttributes.AutoLayout | TypeAttributes.AnsiClass | TypeAttributes.Class | TypeAttributes.Public | TypeAttributes.ClassSemanticsMask | TypeAttributes.Abstract; + public override TypeAttributes Attributes => TypeAttributes.AutoLayout | TypeAttributes.AnsiClass | TypeAttributes.Class | TypeAttributes.Public | TypeAttributes.ClassSemanticsMask | TypeAttributes.Abstract | TypeAttributes.BeforeFieldInit; public override Type BaseType => null; } diff --git a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsIdentity.cs b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsIdentity.cs index e0d82df3b7863f..4408b11cddf135 100644 --- a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsIdentity.cs +++ b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsIdentity.cs @@ -14,7 +14,6 @@ namespace System.Security.Claims /// An Identity that is represented by a set of claims. /// [DebuggerDisplay("{DebuggerToString(),nq}")] - [DebuggerTypeProxy(typeof(ClaimsIdentityDebugProxy))] public class ClaimsIdentity : IIdentity { private enum SerializationMask @@ -962,26 +961,5 @@ internal string DebuggerToString() return debugText; } - - private sealed class ClaimsIdentityDebugProxy - { - private readonly ClaimsIdentity _identity; - - public ClaimsIdentityDebugProxy(ClaimsIdentity identity) - { - _identity = identity; - } - - public ClaimsIdentity? Actor => _identity.Actor; - public string? AuthenticationType => _identity.AuthenticationType; - public object? BootstrapContext => _identity.BootstrapContext; - // List type has a friendly debugger view - public List Claims => new List(_identity.Claims); - public bool IsAuthenticated => _identity.IsAuthenticated; - public string? Label => _identity.Label; - public string? Name => _identity.Name; - public string NameClaimType => _identity.NameClaimType; - public string RoleClaimType => _identity.RoleClaimType; - } } } diff --git a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs index 13ee10f7f6f4e6..de8f7d89725c68 100644 --- a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs +++ b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs @@ -15,7 +15,6 @@ namespace System.Security.Claims /// Concrete IPrincipal supporting multiple claims-based identities /// [DebuggerDisplay("{DebuggerToString(),nq}")] - [DebuggerTypeProxy(typeof(ClaimsPrincipalDebugProxy))] public class ClaimsPrincipal : IPrincipal { private enum SerializationMask @@ -594,20 +593,5 @@ private string DebuggerToString() return $"Identities = {identitiesCount}, Claims = {claimsCount}"; } - - private sealed class ClaimsPrincipalDebugProxy - { - private readonly ClaimsPrincipal _principal; - - public ClaimsPrincipalDebugProxy(ClaimsPrincipal principal) - { - _principal = principal; - } - - // List type has a friendly debugger view - public List Claims => new List(_principal.Claims); - public List Identities => new List(_principal.Identities); - public IIdentity? Identity => _principal.Identity; - } } } diff --git a/src/libraries/System.Security.Cryptography.Pkcs/src/System/Security/Cryptography/Pkcs/CmsSignature.ECDsa.cs b/src/libraries/System.Security.Cryptography.Pkcs/src/System/Security/Cryptography/Pkcs/CmsSignature.ECDsa.cs index f9900fed9fb6a1..dd098e46362ea6 100644 --- a/src/libraries/System.Security.Cryptography.Pkcs/src/System/Security/Cryptography/Pkcs/CmsSignature.ECDsa.cs +++ b/src/libraries/System.Security.Cryptography.Pkcs/src/System/Security/Cryptography/Pkcs/CmsSignature.ECDsa.cs @@ -22,17 +22,17 @@ static partial void PrepareRegistrationECDsa(Dictionary lo lookup.Add(Oids.ECDsaWithSha3_384, new ECDsaCmsSignature(Oids.ECDsaWithSha3_384, HashAlgorithmName.SHA3_384)); lookup.Add(Oids.ECDsaWithSha3_512, new ECDsaCmsSignature(Oids.ECDsaWithSha3_512, HashAlgorithmName.SHA3_512)); #endif - lookup.Add(Oids.EcPublicKey, new ECDsaCmsSignature(null, default)); + lookup.Add(Oids.EcPublicKey, new ECDsaCmsSignature(null, null)); } private sealed partial class ECDsaCmsSignature : CmsSignature { - private readonly HashAlgorithmName _expectedDigest; + private readonly HashAlgorithmName? _expectedDigest; private readonly string? _signatureAlgorithm; internal override RSASignaturePadding? SignaturePadding => null; - internal ECDsaCmsSignature(string? signatureAlgorithm, HashAlgorithmName expectedDigest) + internal ECDsaCmsSignature(string? signatureAlgorithm, HashAlgorithmName? expectedDigest) { _signatureAlgorithm = signatureAlgorithm; _expectedDigest = expectedDigest; @@ -56,7 +56,7 @@ internal override bool VerifySignature( ReadOnlyMemory? signatureParameters, X509Certificate2 certificate) { - if (_expectedDigest != digestAlgorithmName) + if (_expectedDigest != null && _expectedDigest != digestAlgorithmName) { throw new CryptographicException( SR.Format( diff --git a/src/libraries/System.Security.Cryptography.Pkcs/tests/Oids.cs b/src/libraries/System.Security.Cryptography.Pkcs/tests/Oids.cs index bcaa73082ebfc7..a9e73c30df756e 100644 --- a/src/libraries/System.Security.Cryptography.Pkcs/tests/Oids.cs +++ b/src/libraries/System.Security.Cryptography.Pkcs/tests/Oids.cs @@ -25,6 +25,7 @@ internal static class Oids public const string RsaPss = "1.2.840.113549.1.1.10"; public const string Esdh = "1.2.840.113549.1.9.16.3.5"; public const string Dh = "1.2.840.10046.2.1"; + public const string EcPublicKey = "1.2.840.10045.2.1"; public const string EcdsaSha256 = "1.2.840.10045.4.3.2"; // Cryptographic Attribute Types diff --git a/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedCmsTests.netcoreapp.cs b/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedCmsTests.netcoreapp.cs index 355d9e5763557d..e5ef61d996a88b 100644 --- a/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedCmsTests.netcoreapp.cs +++ b/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedCmsTests.netcoreapp.cs @@ -823,6 +823,20 @@ public static void ExistingDocument_Rsa_Sha3_512() } } + [Fact] + public static void ExistingDocument_Ecdsa_Sha256_FromNetFX() + { + SignedCms cms = new SignedCms(); + cms.Decode(SignedDocuments.Ecdsa_Sha256_FromNetFX_SignedDocument); + + cms.CheckSignature(true); // Assert.NoThrow + Assert.Single(cms.SignerInfos); + + SignerInfo signerInfo = cms.SignerInfos[0]; + Assert.Equal(Oids.Sha256, signerInfo.DigestAlgorithm.Value); + Assert.Equal(Oids.EcPublicKey, signerInfo.SignatureAlgorithm.Value); + } + private static void VerifyWithExplicitPrivateKey(X509Certificate2 cert, AsymmetricAlgorithm key) { using (var pubCert = new X509Certificate2(cert.RawData)) diff --git a/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedDocuments.cs b/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedDocuments.cs index 7001b325cd4029..06c739ef3ba6c0 100644 --- a/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedDocuments.cs +++ b/src/libraries/System.Security.Cryptography.Pkcs/tests/SignedCms/SignedDocuments.cs @@ -1789,5 +1789,38 @@ internal static class SignedDocuments "F11C632B4F605A41821A3F15B4F537FD5F0EE3426A7A03732AC946C3B435" + "776A873A3DAE93FB8312C681144CF51F05CE37A0DB4C1544E178F88E421C" + "0B5456D18C13B335DA808CE60C4E35F507").HexToByteArray(); + + // produced with the below PowerShell code using the pfx Certificates.ECDsaP256Win + // $cert = [System.Security.Cryptography.X509Certificates.X509Certificate2]::new( + // [System.Convert]::FromBase64String($Certificates_ECDsaP256Win), + // 'Test', + // 'EphemeralKeySet') + // $signer = [System.Security.Cryptography.Pkcs.CmsSigner]::new('IssuerAndSerialNumber', $cert) + // $signer.IncludeOption = 'ExcludeRoot' + // $signer.DigestAlgorithm = '2.16.840.1.101.3.4.2.1' + // $contentInfo = [System.Security.Cryptography.Pkcs.ContentInfo]::new([byte[]]@(0)) + // $signedCms = [System.Security.Cryptography.Pkcs.SignedCms]::new($contentInfo, $false) + // $signedCms.ComputeSignature($signer, $true) + // $signedCms.Encode() + internal static readonly byte[] Ecdsa_Sha256_FromNetFX_SignedDocument = ( + "3082023106092A864886F70D010702A08202223082021E020101310F300D" + + "06096086480165030402010500301006092A864886F70D010701A0030401" + + "00A082015C308201583081FFA003020102021035428F3B3C5107AD49E776" + + "D6E74C4DC8300A06082A8648CE3D04030230153113301106035504030C0A" + + "45434453412054657374301E170D3135303530313030333730335A170D31" + + "36303530313030353730335A30153113301106035504030C0A4543445341" + + "20546573743059301306072A8648CE3D020106082A8648CE3D0301070342" + + "00047590F69CA114E92927E034C997B7C882A8C992AC00CEFB4EB8319015" + + "36F291E1B515263BCD20E1EA32496FDAC84E2D8D1B703266A9088F6EAF65" + + "2549D9BB63D5A331302F300E0603551D0F0101FF040403020388301D0603" + + "551D0E0416041411218A92C5EB12273B3C5CCFB8220CCCFDF387DB300A06" + + "082A8648CE3D040302034800304502201AFE595E19F1AE4B6A4B231E8851" + + "926438C55B5DDE632E6ADF13C1023A65898E022100CBDF434FDD197D8B59" + + "4E8026E44263BADE773C2BEBD060CC4109484A498E7C7E31819530819202" + + "0101302930153113301106035504030C0A45434453412054657374021035" + + "428F3B3C5107AD49E776D6E74C4DC8300D06096086480165030402010500" + + "300B06072A8648CE3D020105000446304402203557687B26E650E4F86F4B" + + "77A5BF5851350C96F01142696CC1391632CB95C3370220017FD4D9329F00" + + "1EC74210CD34CAEE3878B2302602DB7930347E104679734291").HexToByteArray(); } } diff --git a/src/libraries/System.Security.Cryptography.ProtectedData/src/PACKAGE.md b/src/libraries/System.Security.Cryptography.ProtectedData/src/PACKAGE.md new file mode 100644 index 00000000000000..7a0f751326f550 --- /dev/null +++ b/src/libraries/System.Security.Cryptography.ProtectedData/src/PACKAGE.md @@ -0,0 +1,72 @@ +## About + + + +System.Security.Cryptography.ProtectedData offers a simplified interface for utilizing Microsoft Windows DPAPI's [CryptProtectData](https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptprotectdata) and [CryptUnprotectData](https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptunprotectdata) functions. + +**Note**: Since it relies on Windows DPAPI, this package is only supported on Windows platforms. +For more complex cryptographic operations or cross-platform support, consider the [System.Security.Cryptography](https://learn.microsoft.com/dotnet/api/system.security.cryptography) namespace. + +## Key Features + + + +* Built upon the robust and secure Windows Data Protection API (DPAPI). +* Data can be protected either for current process or for any process on the machine. +* Scope of protection can be defined either to the current user or the local machine. + +## How to Use + + + +Utilizing this package is quite simple, and it mainly revolves around two methods: `Protect` and `Unprotect`. + +Here, `originalData` is the data you want to protect, `optionalEntropy` is an additional byte array used to increase encryption complexity, and `DataProtectionScope` specifies whether the data protection should apply to the current user or the machine. + +```csharp +using System.Security.Cryptography; +using System.Text; + +byte[] originalData = Encoding.UTF8.GetBytes("This is a secret"); +byte[] optionalEntropy = new byte[64]; +Random.Shared.NextBytes(optionalEntropy); + +// To protect: +byte[] encryptedData = ProtectedData.Protect( + originalData, + optionalEntropy, + DataProtectionScope.CurrentUser); + +// To unprotect: +byte[] decryptedData = ProtectedData.Unprotect( + encryptedData, + optionalEntropy, + DataProtectionScope.CurrentUser); +``` + +## Main Types + + + +The main type provided by this library is: + +* `System.Security.Cryptography.ProtectedData` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/standard/security/how-to-use-data-protection) +* [API documentation](https://learn.microsoft.com/dotnet/api/system.security.cryptography.protecteddata) + +## Related Packages + + + +* PKCS and CMS algorithms: [System.Security.Cryptography.Pkcs](https://www.nuget.org/packages/System.Security.Cryptography.Pkcs/) + +## Feedback & Contributing + + + +System.Security.Cryptography.ProtectedData is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs index 7c3ab590319808..6df9aacc727856 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs @@ -348,16 +348,14 @@ private void BuildAndSetProperties((X509Certificate2, int)[] elementTuples) for (int i = 0; i < elementTuples.Length; i++) { - (X509Certificate2, int) tuple = elementTuples[i]; + (X509Certificate2 cert, int chainStatus) = elementTuples[i]; - elements[i] = BuildElement(tuple.Item1, tuple.Item2); - allStatus |= tuple.Item2; + elements[i] = new X509ChainElement(cert, BuildChainElementStatuses(cert, chainStatus), ""); + allStatus |= chainStatus; } ChainElements = elements; - - X509ChainElement rollupElement = BuildElement(null!, allStatus); - ChainStatus = rollupElement.ChainElementStatus; + ChainStatus = BuildChainElementStatuses(null, allStatus); } private static void FixupRevocationStatus( @@ -457,11 +455,11 @@ private static X509ChainStatusFlags FindUntrustedRootReason(X509Certificate2 cer return X509ChainStatusFlags.UntrustedRoot; } - private X509ChainElement BuildElement(X509Certificate2 cert, int dwStatus) + private X509ChainStatus[] BuildChainElementStatuses(X509Certificate2? cert, int dwStatus) { if (dwStatus == 0) { - return new X509ChainElement(cert, Array.Empty(), ""); + return Array.Empty(); } List statuses = new List(); @@ -499,7 +497,7 @@ private X509ChainElement BuildElement(X509Certificate2 cert, int dwStatus) } } - return new X509ChainElement(cert, statuses.ToArray(), ""); + return statuses.ToArray(); } private readonly struct X509ChainErrorMapping diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs index 4c9643c01e2fcb..e66b3d1ad11022 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs @@ -21,14 +21,14 @@ internal sealed class OpenSslCachedSystemStoreProvider : IStorePal private static readonly TimeSpan s_lastWriteRecheckInterval = TimeSpan.FromSeconds(5); private static readonly TimeSpan s_assumeInvalidInterval = TimeSpan.FromMinutes(5); private static readonly Stopwatch s_recheckStopwatch = new Stopwatch(); - private static DirectoryInfo? s_rootStoreDirectoryInfo = SafeOpenRootDirectoryInfo(); + private static string[]? s_rootStoreDirectories; private static bool s_defaultRootDir; - private static readonly FileInfo? s_rootStoreFileInfo = SafeOpenRootFileInfo(); + private static string? s_rootStoreFile; + private static DateTime[]? s_directoryLastWrite; + private static DateTime s_fileLastWrite; // Use non-Value-Tuple so that it's an atomic update. private static Tuple? s_nativeCollections; - private static DateTime s_directoryCertsLastWrite; - private static DateTime s_fileCertsLastWrite; private readonly bool _isRoot; @@ -93,18 +93,11 @@ private static Tuple GetCollections() { lock (s_recheckStopwatch) { - FileInfo? fileInfo = s_rootStoreFileInfo; - DirectoryInfo? dirInfo = s_rootStoreDirectoryInfo; - - fileInfo?.Refresh(); - dirInfo?.Refresh(); - if (ret == null || elapsed > s_assumeInvalidInterval || - (fileInfo != null && fileInfo.Exists && ContentWriteTime(fileInfo) != s_fileCertsLastWrite) || - (dirInfo != null && dirInfo.Exists && ContentWriteTime(dirInfo) != s_directoryCertsLastWrite)) + LastWriteTimesHaveChanged()) { - ret = LoadMachineStores(dirInfo, fileInfo); + ret = LoadMachineStores(); } } } @@ -113,9 +106,37 @@ private static Tuple GetCollections() return ret; } - private static Tuple LoadMachineStores( - DirectoryInfo? rootStorePath, - FileInfo? rootStoreFile) + private static bool LastWriteTimesHaveChanged() + { + Debug.Assert( + Monitor.IsEntered(s_recheckStopwatch), + "LastWriteTimesHaveChanged assumes a lock(s_recheckStopwatch)"); + + if (s_rootStoreFile != null) + { + _ = TryStatFile(s_rootStoreFile, out DateTime lastModified); + if (lastModified != s_fileLastWrite) + { + return true; + } + } + + if (s_rootStoreDirectories != null && s_directoryLastWrite != null) + { + for (int i = 0; i < s_rootStoreDirectories.Length; i++) + { + _ = TryStatDirectory(s_rootStoreDirectories[i], out DateTime lastModified); + if (lastModified != s_directoryLastWrite[i]) + { + return true; + } + } + } + + return false; + } + + private static Tuple LoadMachineStores() { Debug.Assert( Monitor.IsEntered(s_recheckStopwatch), @@ -126,61 +147,76 @@ private static Tuple LoadMachineStores SafeX509StackHandle intermedStore = Interop.Crypto.NewX509Stack(); Interop.Crypto.CheckValidOpenSslHandle(intermedStore); - DateTime newFileTime = default; - DateTime newDirTime = default; - var uniqueRootCerts = new HashSet(); var uniqueIntermediateCerts = new HashSet(); bool firstLoad = (s_nativeCollections == null); - if (rootStoreFile != null && rootStoreFile.Exists) + if (firstLoad) { - newFileTime = ContentWriteTime(rootStoreFile); - ProcessFile(rootStoreFile); + s_rootStoreDirectories = GetRootStoreDirectories(out s_defaultRootDir); + s_directoryLastWrite = new DateTime[s_rootStoreDirectories.Length]; + s_rootStoreFile = GetRootStoreFile(); + } + else + { + Debug.Assert(s_rootStoreDirectories is not null); + Debug.Assert(s_directoryLastWrite is not null); + } + + if (s_rootStoreFile != null) + { + ProcessFile(s_rootStoreFile, out s_fileLastWrite); } bool hasStoreData = false; - if (rootStorePath != null && rootStorePath.Exists) + for (int i = 0; i < s_rootStoreDirectories.Length; i++) { - newDirTime = ContentWriteTime(rootStorePath); - hasStoreData = ProcessDir(rootStorePath); + hasStoreData = ProcessDir(s_rootStoreDirectories[i], out s_directoryLastWrite[i]); } if (firstLoad && !hasStoreData && s_defaultRootDir) { - DirectoryInfo etcSslCerts = new DirectoryInfo("/etc/ssl/certs"); - - if (etcSslCerts.Exists) + const string DefaultCertDir = "/etc/ssl/certs"; + hasStoreData = ProcessDir(DefaultCertDir, out DateTime lastModified); + if (hasStoreData) { - DateTime tmpTime = ContentWriteTime(etcSslCerts); - hasStoreData = ProcessDir(etcSslCerts); - - if (hasStoreData) - { - newDirTime = tmpTime; - s_rootStoreDirectoryInfo = etcSslCerts; - } + s_rootStoreDirectories = new[] { DefaultCertDir }; + s_directoryLastWrite = new[] { lastModified }; } } - bool ProcessDir(DirectoryInfo dir) + bool ProcessDir(string dir, out DateTime lastModified) { + if (!TryStatDirectory(dir, out lastModified)) + { + return false; + } + bool hasStoreData = false; - foreach (FileInfo file in dir.EnumerateFiles()) + foreach (string file in Directory.EnumerateFiles(dir)) { - hasStoreData |= ProcessFile(file); + hasStoreData |= ProcessFile(file, out _, skipStat: true); } return hasStoreData; } - bool ProcessFile(FileInfo file) + bool ProcessFile(string file, out DateTime lastModified, bool skipStat = false) { bool readData = false; - using (SafeBioHandle fileBio = Interop.Crypto.BioNewFile(file.FullName, "rb")) + if (skipStat) + { + lastModified = default; + } + else if (!TryStatFile(file, out lastModified)) + { + return false; + } + + using (SafeBioHandle fileBio = Interop.Crypto.BioNewFile(file, "rb")) { // The handle may be invalid, for example when we don't have read permission for the file. if (fileBio.IsInvalid) @@ -274,114 +310,78 @@ bool ProcessFile(FileInfo file) // on every call. Volatile.Write(ref s_nativeCollections, newCollections); - s_directoryCertsLastWrite = newDirTime; - s_fileCertsLastWrite = newFileTime; s_recheckStopwatch.Restart(); return newCollections; } - private static FileInfo? SafeOpenRootFileInfo() + private static string? GetRootStoreFile() { string? rootFile = Interop.Crypto.GetX509RootStoreFile(); if (!string.IsNullOrEmpty(rootFile)) { - try - { - return new FileInfo(rootFile); - } - catch (ArgumentException) - { - // If SSL_CERT_FILE is set to the empty string, or anything else which gives - // "The path is not of a legal form", then the GetX509RootStoreFile value is ignored. - } + return Path.GetFullPath(rootFile); } return null; } - private static DirectoryInfo? SafeOpenRootDirectoryInfo() + private static string[] GetRootStoreDirectories(out bool isDefault) { - string? rootDirectory = Interop.Crypto.GetX509RootStorePath(out s_defaultRootDir); + string rootDirectory = Interop.Crypto.GetX509RootStorePath(out isDefault) ?? ""; - if (!string.IsNullOrEmpty(rootDirectory)) - { - try - { - return new DirectoryInfo(rootDirectory); - } - catch (ArgumentException) - { - // If SSL_CERT_DIR is set to the empty string, or anything else which gives - // "The path is not of a legal form", then the GetX509RootStoreFile value is ignored. - } - } - - return null; - } - - private static DateTime ContentWriteTime(FileInfo info) - { - string path = info.FullName; - string? target = Interop.Sys.ReadLink(path); - - if (string.IsNullOrEmpty(target)) - { - return info.LastWriteTimeUtc; - } + string[] directories = rootDirectory.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries); - if (target[0] != '/') + for (int i = 0; i < directories.Length; i++) { - target = Path.Join(info.Directory?.FullName, target); + directories[i] = Path.GetFullPath(directories[i]); } - try + // Remove duplicates. + if (directories.Length > 1) { - var targetInfo = new FileInfo(target); - - if (targetInfo.Exists) + var set = new HashSet(directories, StringComparer.Ordinal); + if (set.Count != directories.Length) { - return targetInfo.LastWriteTimeUtc; + // Preserve the original order. + string[] directoriesTrimmed = new string[set.Count]; + int j = 0; + for (int i = 0; i < directories.Length; i++) + { + string directory = directories[i]; + if (set.Remove(directory)) + { + directoriesTrimmed[j++] = directory; + } + } + Debug.Assert(set.Count == 0); + directories = directoriesTrimmed; } } - catch (ArgumentException) - { - // If we can't load information about the link path, just treat it as not a link. - } - return info.LastWriteTimeUtc; + return directories; } - private static DateTime ContentWriteTime(DirectoryInfo info) - { - string path = info.FullName; - string? target = Interop.Sys.ReadLink(path); - - if (string.IsNullOrEmpty(target)) - { - return info.LastWriteTimeUtc; - } + private static bool TryStatFile(string path, out DateTime lastModified) + => TryStat(path, Interop.Sys.FileTypes.S_IFREG, out lastModified); - if (target[0] != '/') - { - target = Path.Join(info.Parent?.FullName, target); - } + private static bool TryStatDirectory(string path, out DateTime lastModified) + => TryStat(path, Interop.Sys.FileTypes.S_IFDIR, out lastModified); - try - { - var targetInfo = new DirectoryInfo(target); + private static bool TryStat(string path, int fileType, out DateTime lastModified) + { + lastModified = default; - if (targetInfo.Exists) - { - return targetInfo.LastWriteTimeUtc; - } - } - catch (ArgumentException) + Interop.Sys.FileStatus status; + // Use Stat to follow links. + if (Interop.Sys.Stat(path, out status) < 0 || + (status.Mode & Interop.Sys.FileTypes.S_IFMT) != fileType) { - // If we can't load information about the link path, just treat it as not a link. + return false; } - return info.LastWriteTimeUtc; + lastModified = DateTime.UnixEpoch + TimeSpan.FromTicks(status.MTime * TimeSpan.TicksPerSecond + status.MTimeNsec / TimeSpan.NanosecondsPerTick); + return true; } } } diff --git a/src/libraries/System.Security.Cryptography/tests/DSATests.cs b/src/libraries/System.Security.Cryptography/tests/DSATests.cs index b995a5e0920893..8eca860fe4fb9b 100644 --- a/src/libraries/System.Security.Cryptography/tests/DSATests.cs +++ b/src/libraries/System.Security.Cryptography/tests/DSATests.cs @@ -171,7 +171,7 @@ protected override void Dispose(bool disposing) public override void ImportParameters(DSAParameters parameters) => _dsa.ImportParameters(parameters); public override bool VerifySignature(byte[] rgbHash, byte[] rgbSignature) => _dsa.VerifySignature(rgbHash, rgbSignature); protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) => - (byte[])_dsa.GetType().GetMethod( + (byte[])typeof(DSA).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, @@ -179,7 +179,7 @@ protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) null) .Invoke(_dsa, new object[] { data, hashAlgorithm }); protected override byte[] HashData(byte[] data, int offset, int count, HashAlgorithmName hashAlgorithm) => - (byte[])_dsa.GetType().GetMethod( + (byte[])typeof(DSA).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, diff --git a/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs b/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs index 5a871f35c2ef08..c858fd0866213d 100644 --- a/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs +++ b/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs @@ -169,7 +169,7 @@ public byte[] BaseHashData(byte[] data, int offset, int count, HashAlgorithmName base.HashData(data, offset, count, hashAlgorithm); protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) => - (byte[])_ecdsa.GetType().GetMethod( + (byte[])typeof(ECDsa).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, @@ -178,7 +178,7 @@ protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) .Invoke(_ecdsa, new object[] { data, hashAlgorithm }); protected override byte[] HashData(byte[] data, int offset, int count, HashAlgorithmName hashAlgorithm) => - (byte[])_ecdsa.GetType().GetMethod( + (byte[])typeof(ECDsa).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, diff --git a/src/libraries/System.Security.Cryptography/tests/X509Certificates/RevocationTests/AiaTests.cs b/src/libraries/System.Security.Cryptography/tests/X509Certificates/RevocationTests/AiaTests.cs index a1ce84ce5be8e8..5cceafd0e81108 100644 --- a/src/libraries/System.Security.Cryptography/tests/X509Certificates/RevocationTests/AiaTests.cs +++ b/src/libraries/System.Security.Cryptography/tests/X509Certificates/RevocationTests/AiaTests.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Security.Cryptography.X509Certificates.Tests.Common; +using Microsoft.DotNet.RemoteExecutor; using Test.Cryptography; using Xunit; @@ -178,5 +179,44 @@ public static void DisableAiaOptionWorks() }); } } + + [ActiveIssue("https://github.com/dotnet/runtime/issues/57506", typeof(PlatformDetection), nameof(PlatformDetection.IsMonoRuntime), nameof(PlatformDetection.IsMariner))] + [PlatformSpecific(TestPlatforms.Linux)] + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public static void AiaIgnoresCertOverLimit() + { + RemoteExecutor.Invoke(() => + { + AppContext.SetData("System.Security.Cryptography.AiaDownloadLimit", 100); + CertificateAuthority.BuildPrivatePki( + PkiOptions.AllRevocation, + out RevocationResponder responder, + out CertificateAuthority root, + out CertificateAuthority intermediate, + out X509Certificate2 endEntity, + pkiOptionsInSubject: false, + testName: Guid.NewGuid().ToString()); + + using (responder) + using (root) + using (intermediate) + using (endEntity) + using (X509Certificate2 rootCert = root.CloneIssuerCert()) + { + responder.AiaResponseKind = AiaResponseKind.Cert; + + using (ChainHolder holder = new ChainHolder()) + { + X509Chain chain = holder.Chain; + chain.ChainPolicy.CustomTrustStore.Add(rootCert); + chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust; + chain.ChainPolicy.VerificationTime = endEntity.NotBefore.AddMinutes(1); + chain.ChainPolicy.UrlRetrievalTimeout = DynamicRevocationTests.s_urlRetrievalLimit; + + Assert.False(chain.Build(endEntity)); + } + } + }).Dispose(); + } } } diff --git a/src/libraries/System.Security.Cryptography/tests/X509Certificates/TestData.cs b/src/libraries/System.Security.Cryptography/tests/X509Certificates/TestData.cs index 7110c78329778d..e3deb00b354d06 100644 --- a/src/libraries/System.Security.Cryptography/tests/X509Certificates/TestData.cs +++ b/src/libraries/System.Security.Cryptography/tests/X509Certificates/TestData.cs @@ -51,7 +51,8 @@ internal static class TestData // This pfx was generated by new X509Certificate(MsCertificate).Export(X509ContentType.Pfx) // and was choosen when the padding was 01 and caused a false-positive on decryption. - public static byte[] MsCertificateExportedToPfx_NullPassword = Convert.FromBase64String(@" + public static byte[] MsCertificateExportedToPfx_NullPassword = Convert.FromBase64String( + /* [SuppressMessage("Microsoft.Security", "CS002:SecretInNextLine", Justification="This PKCS#12 blob only contains public info.")] */ @" MIIFxAIBAzCCBYoGCSqGSIb3DQEHAaCCBXsEggV3MIIFczCCBW8GCSqGSIb3DQEH BqCCBWAwggVcAgEAMIIFVQYJKoZIhvcNAQcBMBwGCiqGSIb3DQEMAQYwDgQIKpCU u5nlxAACAggAgIIFKG/SLlS1TJmxGUiXBPJ1r4yV+JMehwo6RYPMkCSnpKGaiLyA @@ -3057,7 +3058,8 @@ internal static DSAParameters GetDSA1024Params() "4D7314FCB4041469835268466D1390373566F7034C4736346CD17D020207D0").HexToByteArray(); internal static readonly byte[] Pkcs12NoPasswordRandomCounts = - Convert.FromBase64String(@" + Convert.FromBase64String( + /* [SuppressMessage("Microsoft.Security", "CS002:SecretInNextLine", Justification="Self-signed cert created specifically for inclusion in public-facing unit tests.")] */ @" MIIvdAIBAzCCLtQGCSqGSIb3DQEHAaCCLsUEgi7BMIIuvTCCLrkGCSqGSIb3DQEHAaCCLqoEgi6m MIIuojCCAvkGCyqGSIb3DQEMCgECoIIC6DCCAuQwXgYJKoZIhvcNAQUNMFEwMAYJKoZIhvcNAQUM MCMEELD+7LV5Y9tyUiJnNeZVLwQCASowDAYIKoZIhvcNAgsFADAdBglghkgBZQMEASoEEBzHfelA diff --git a/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs b/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs index 0efb6c12028fb9..f460d6b9bd6c69 100644 --- a/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs +++ b/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs @@ -10,7 +10,6 @@ namespace System.Security.Cryptography.X509Certificates.Tests { public partial class X509StoreTests { - [ConditionalFact(nameof(NotRunningAsRootAndRemoteExecutorSupported))] // root can read '2.pem' [PlatformSpecific(TestPlatforms.Linux)] // Windows/OSX doesn't use SSL_CERT_{DIR,FILE}. private void X509Store_MachineStoreLoadSkipsInvalidFiles() @@ -50,6 +49,47 @@ private void X509Store_MachineStoreLoadSkipsInvalidFiles() }, new RemoteInvokeOptions { StartInfo = psi }).Dispose(); } + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [PlatformSpecific(TestPlatforms.Linux)] // Windows/OSX doesn't use SSL_CERT_{DIR,FILE}. + private void X509Store_MachineStoreLoadsMutipleSslCertDirectories() + { + // Create 3 certificates and place them in two directories that will be passed + // using SSL_CERT_DIR. + string sslCertDir1 = GetTestFilePath(); + Directory.CreateDirectory(sslCertDir1); + File.WriteAllBytes(Path.Combine(sslCertDir1, "1.pem"), TestData.SelfSigned1PemBytes); + File.WriteAllBytes(Path.Combine(sslCertDir1, "2.pem"), TestData.SelfSigned2PemBytes); + string sslCertDir2 = GetTestFilePath(); + Directory.CreateDirectory(sslCertDir2); + File.WriteAllBytes(Path.Combine(sslCertDir2, "3.pem"), TestData.SelfSigned3PemBytes); + + // Add a non-existing directory after each valid directory to verify they are ignored. + string sslCertDir = string.Join(Path.PathSeparator, + new[] { + sslCertDir1, + sslCertDir2, + "", // empty string + sslCertDir2, // duplicate directory + "/invalid2", // path that does not exist + }); + + var psi = new ProcessStartInfo(); + psi.Environment.Add("SSL_CERT_DIR", sslCertDir); + // Set SSL_CERT_FILE to avoid loading the default bundle file. + psi.Environment.Add("SSL_CERT_FILE", "/nonexisting"); + RemoteExecutor.Invoke(() => + { + Assert.NotNull(Environment.GetEnvironmentVariable("SSL_CERT_DIR")); + using (var store = new X509Store(StoreName.Root, StoreLocation.LocalMachine)) + { + store.Open(OpenFlags.OpenExistingOnly); + + // Check nr of certificates in store. + Assert.Equal(3, store.Certificates.Count); + } + }, new RemoteInvokeOptions { StartInfo = psi }).Dispose(); + } + public static bool NotRunningAsRootAndRemoteExecutorSupported => !Environment.IsPrivilegedProcess && RemoteExecutor.IsSupported; } } diff --git a/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml b/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml index 10869bf91f9227..5ff9a82feffe7d 100644 --- a/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml +++ b/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml @@ -218,6 +218,27 @@ lib/netstandard2.0/System.Security.Permissions.dll true + + CP0014 + P:System.Security.Permissions.FileIOPermissionAttribute.All:[T:System.ObsoleteAttribute] + lib/netstandard2.0/System.Security.Permissions.dll + lib/netstandard2.0/System.Security.Permissions.dll + true + + + CP0014 + P:System.Security.Permissions.ReflectionPermissionAttribute.ReflectionEmit:[T:System.ObsoleteAttribute] + lib/netstandard2.0/System.Security.Permissions.dll + lib/netstandard2.0/System.Security.Permissions.dll + true + + + CP0014 + P:System.Security.Permissions.ReflectionPermissionAttribute.TypeInformation:[T:System.ObsoleteAttribute] + lib/netstandard2.0/System.Security.Permissions.dll + lib/netstandard2.0/System.Security.Permissions.dll + true + CP0014 P:System.Security.Permissions.RegistryPermissionAttribute.All:[T:System.ObsoleteAttribute] diff --git a/src/libraries/System.ServiceProcess.ServiceController/src/PACKAGE.md b/src/libraries/System.ServiceProcess.ServiceController/src/PACKAGE.md new file mode 100644 index 00000000000000..902a69d0ae2858 --- /dev/null +++ b/src/libraries/System.ServiceProcess.ServiceController/src/PACKAGE.md @@ -0,0 +1,69 @@ +## About + + +Provides the System.ServiceProcess.ServiceController API, which allows to connect to a Windows service, manipulate it, or get information about it. Not supported on other platforms. + +## Key Features + + + +* Retrieve information from Windows services +* Connect to and manipulate Windows services (start, pause, stop or other operations) + +## How to Use + + + +### Retrieve Windows service information +```C# +using System.ServiceProcess; + +// Loop through all installed Windows services and print the name, status and display name. +foreach (ServiceController serviceController in ServiceController.GetServices()) +{ + Console.WriteLine("Name: " + serviceController.ServiceName); + Console.WriteLine("Status: " + serviceController.Status.ToString()); + Console.WriteLine("Display name: " + serviceController.DisplayName); +} + +// Loop through all installed device driver services +foreach (ServiceController serviceController in ServiceController.GetDevices()) +{ + Console.WriteLine("Name: " + serviceController.ServiceName); + Console.WriteLine("Status: " + serviceController.Status.ToString()); + Console.WriteLine("Display name: " + serviceController.DisplayName); +} +``` + +### Manipulate a Windows service +```C# +using System.ServiceProcess; + +ServiceController service = new("TestServiceName"); +if (service.CanStop && service.Status != ServiceControllerStatus.Stopped && service.Status != ServiceControllerStatus.StopPending) +{ + service.Stop(); +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.ServiceProcess.ServiceController` +* `System.ServiceProcess.ServiceControllerStatus` +* `System.ServiceProcess.ServiceType` + +## Additional Documentation + + + +* [System.ServiceController API documentation](https://learn.microsoft.com/dotnet/api/system.serviceprocess.servicecontroller?view=dotnet-plat-ext-7.0) + +## Feedback & Contributing + + + +System.ServiceProcess.ServiceController is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Speech/src/PACKAGE.md b/src/libraries/System.Speech/src/PACKAGE.md new file mode 100644 index 00000000000000..ac57df4ca354d8 --- /dev/null +++ b/src/libraries/System.Speech/src/PACKAGE.md @@ -0,0 +1,102 @@ +## About + + + +Provides APIs for speech recognition and synthesis built on the [Microsoft Speech API](https://learn.microsoft.com/previous-versions/windows/desktop/ms723627(v=vs.85)) in Windows. Not supported on other platforms. + +This package is provided primarily for compatibility with code being ported from .NET Framework and is not accepting new features. + +## Key Features + + + +* Recognize speech as text in a given language and grammar. +* Synthesize text as speech. +* Support for [Speech Recognition Grammar v1.0](https://www.w3.org/TR/speech-grammar/) documents + +## How to Use + + + +### Synthesis example +```C# +using System.Speech.Synthesis; + +// Initialize a new instance of the SpeechSynthesizer. +SpeechSynthesizer synth = new SpeechSynthesizer(); + +// Configure the audio output. +synth.SetOutputToDefaultAudioDevice(); + +// Speak a string, synchronously +synth.Speak("Hello World!"); + +// Speak a string asynchronously +var prompt = synth.SpeakAsync("Goodnight Moon!"); + +while (!prompt.IsCompleted) +{ + Console.WriteLine("speaking..."); + Thread.Sleep(500); +} +``` + +### Recognition example +```C# +// Create a new SpeechRecognitionEngine instance. +using SpeechRecognizer recognizer = new SpeechRecognizer(); +using ManualResetEvent exit = new ManualResetEvent(false); + +// Create a simple grammar that recognizes "red", "green", "blue", or "exit". +Choices choices = new Choices(); +choices.Add(new string[] { "red", "green", "blue", "exit" }); + +// Create a GrammarBuilder object and append the Choices object. +GrammarBuilder gb = new GrammarBuilder(); +gb.Append(choices); + +// Create the Grammar instance and load it into the speech recognition engine. +Grammar g = new Grammar(gb); +recognizer.LoadGrammar(g); + +// Register a handler for the SpeechRecognized event. +recognizer.SpeechRecognized += (s, e) => +{ + Console.WriteLine($"Recognized: {e.Result.Text}, Confidence: {e.Result.Confidence}"); + if (e.Result.Text == "exit") + { + exit.Set(); + } +}; + +// Emulate +Console.WriteLine("Emulating \"red\"."); +recognizer.EmulateRecognize("red"); + +Console.WriteLine("Speak red, green, blue, or exit please..."); + +exit.WaitOne(); +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Speech.Recognition.SpeechRecognizer` +* `System.Speech.Synthesis.SpeechSynthesizer` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/previous-versions/office/developer/speech-technologies/hh361625(v%3doffice.14)) +* [Speech.Recognition API documentation](https://learn.microsoft.com/dotnet/api/system.speech.recognition) +* [Speech.Synthesis API documentation](https://learn.microsoft.com/dotnet/api/system.speech.synthesis) + +## Feedback & Contributing + + + +System.Speech is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Text.Encoding.CodePages/src/PACKAGE.md b/src/libraries/System.Text.Encoding.CodePages/src/PACKAGE.md new file mode 100644 index 00000000000000..7da4fd98cd740b --- /dev/null +++ b/src/libraries/System.Text.Encoding.CodePages/src/PACKAGE.md @@ -0,0 +1,39 @@ +## About + +`System.Text.Encoding.CodePages` enable creating single and double bytes encodings for code pages that otherwise are available only in the desktop .NET Framework. + +## Key Features + +* Support single and double byte encodings for code pages that are not available in .NET Core. + +## How to Use + +```C# +using System.Text; + +// Register the CodePages encoding provider at application startup to enable using single and double byte encodings. +Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); + +// Now can create single and double byte encodings for code pages that are not available in .NET Core. +Encoding windows1252Encoding = Encoding.GetEncoding(1252); // Western European (Windows) +byte[] encodedBytes = windows1252Encoding.GetBytes("String to encode"); + +``` + +## Main Types + +The main types provided by this library are: + +* `CodePagesEncodingProvider` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.text.codepagesencodingprovider) + +## Related Packages + +* [System.Text.Encodings.Web](https://www.nuget.org/packages/System.Text.Encodings.Web) + +## Feedback & Contributing + +System.Text.Encoding.CodePages is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Text.Json/gen/Helpers/DiagnosticInfo.cs b/src/libraries/System.Text.Json/gen/Helpers/DiagnosticInfo.cs deleted file mode 100644 index 493f79191d4375..00000000000000 --- a/src/libraries/System.Text.Json/gen/Helpers/DiagnosticInfo.cs +++ /dev/null @@ -1,43 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Linq; -using System.Numerics.Hashing; -using Microsoft.CodeAnalysis; - -namespace System.Text.Json.SourceGeneration -{ - /// - /// Descriptor for diagnostic instances using structural equality comparison. - /// Provides a work-around for https://github.com/dotnet/roslyn/issues/68291. - /// - public readonly struct DiagnosticInfo : IEquatable - { - public required DiagnosticDescriptor Descriptor { get; init; } - public required object?[] MessageArgs { get; init; } - public required Location? Location { get; init; } - - public Diagnostic CreateDiagnostic() - => Diagnostic.Create(Descriptor, Location, MessageArgs); - - public override readonly bool Equals(object? obj) => obj is DiagnosticInfo info && Equals(info); - public readonly bool Equals(DiagnosticInfo other) - { - return Descriptor.Equals(other.Descriptor) && - MessageArgs.SequenceEqual(other.MessageArgs) && - Location == other.Location; - } - - public override readonly int GetHashCode() - { - int hashCode = Descriptor.GetHashCode(); - foreach (object? messageArg in MessageArgs) - { - hashCode = HashHelpers.Combine(hashCode, messageArg?.GetHashCode() ?? 0); - } - - hashCode = HashHelpers.Combine(hashCode, Location?.GetHashCode() ?? 0); - return hashCode; - } - } -} diff --git a/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs b/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs index b0cb90e33d1476..3f3ecb506fd83d 100644 --- a/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs +++ b/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs @@ -25,8 +25,6 @@ internal static class RoslynExtensions return compilation.GetBestTypeByMetadataName(type.FullName); } - public static string GetFullyQualifiedName(this ITypeSymbol type) => type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - public static Location? GetLocation(this ISymbol typeSymbol) => typeSymbol.Locations.Length > 0 ? typeSymbol.Locations[0] : null; @@ -36,12 +34,6 @@ internal static class RoslynExtensions return reference?.SyntaxTree.GetLocation(reference.Span); } - /// - /// Creates a copy of the Location instance that does not capture a reference to Compilation. - /// - public static Location GetTrimmedLocation(this Location location) - => Location.Create(location.SourceTree?.FilePath ?? "", location.SourceSpan, location.GetLineSpan().Span); - /// /// Returns true if the specified location is contained in one of the syntax trees in the compilation. /// @@ -209,28 +201,6 @@ public static bool IsNullableValueType(this ITypeSymbol type, [NotNullWhen(true) return false; } - public static ITypeSymbol[] GetAllTypeArgumentsInScope(this INamedTypeSymbol type) - { - if (!type.IsGenericType) - { - return Array.Empty(); - } - - var args = new List(); - TraverseContainingTypes(type); - return args.ToArray(); - - void TraverseContainingTypes(INamedTypeSymbol current) - { - if (current.ContainingType is INamedTypeSymbol parent) - { - TraverseContainingTypes(parent); - } - - args.AddRange(current.TypeArguments); - } - } - public static ITypeSymbol GetMemberType(this ISymbol member) { Debug.Assert(member is IFieldSymbol or IPropertySymbol); diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs index 3242dcc5faa4c2..594f7ad9770c3c 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs @@ -11,6 +11,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { @@ -59,12 +60,7 @@ public void ReportDiagnostic(DiagnosticDescriptor descriptor, Location? location location = _contextClassLocation; } - Diagnostics.Add(new DiagnosticInfo - { - Descriptor = descriptor, - Location = location.GetTrimmedLocation(), - MessageArgs = messageArgs ?? Array.Empty(), - }); + Diagnostics.Add(DiagnosticInfo.Create(descriptor, location, messageArgs)); } public Parser(KnownTypeSymbols knownSymbols) @@ -867,7 +863,7 @@ private List ParsePropertyGenerationSpecs( { Location? typeLocation = typeToGenerate.Location; List properties = new(); - PropertyHierarchyResolutionState state = new(); + PropertyHierarchyResolutionState state = new(options); hasExtensionDataProperty = false; // Walk the type hierarchy starting from the current type up to the base type(s) @@ -974,11 +970,10 @@ bool PropertyIsOverriddenAndIgnored(IPropertySymbol property, Dictionary Properties = new(); - public Dictionary AddedProperties = new(); + public Dictionary AddedProperties = new(options?.PropertyNameCaseInsensitive == true ? StringComparer.OrdinalIgnoreCase : StringComparer.Ordinal); public Dictionary? IgnoredMembers; public bool IsPropertyOrderSpecified; public bool HasInvalidConfigurationForFastPath; @@ -1609,9 +1604,12 @@ private static string GetTypeInfoPropertyName(ITypeSymbol type) sb.Append(name); - foreach (ITypeSymbol genericArg in namedType.GetAllTypeArgumentsInScope()) + if (namedType.GetAllTypeArgumentsInScope() is List typeArgsInScope) { - sb.Append(GetTypeInfoPropertyName(genericArg)); + foreach (ITypeSymbol genericArg in typeArgsInScope) + { + sb.Append(GetTypeInfoPropertyName(genericArg)); + } } return sb.ToString(); diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs index 4c58a3d968ac54..7520f9bc75a6f5 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs @@ -8,6 +8,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs index e3f8b4aacf6c5b..447f54c7f07821 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs @@ -9,6 +9,7 @@ #if !ROSLYN4_4_OR_GREATER using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; #endif +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs index 1e2ee2d737e009..00c7192c3ae58c 100644 --- a/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; -using System.Text.Json.Serialization; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs index 2945b20b730b15..68e32d01531569 100644 --- a/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using SourceGenerators; + namespace System.Text.Json.SourceGeneration { /// diff --git a/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs index 56b42970f68893..214c32b4d19e21 100644 --- a/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Text.Json.Serialization; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs index 9fc68a11928470..608ce8e887d725 100644 --- a/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using SourceGenerators; + namespace System.Text.Json.SourceGeneration { /// diff --git a/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs b/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs index 7e94f824bae8cb..83b587fb962f7e 100644 --- a/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Text.Json.Serialization; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs index 189295bcb971ca..9b71bf16438b89 100644 --- a/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Text.Json.Serialization; using Microsoft.CodeAnalysis; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Resources/Strings.resx b/src/libraries/System.Text.Json/gen/Resources/Strings.resx index 85d64b685f023d..519cbffa3ec09a 100644 --- a/src/libraries/System.Text.Json/gen/Resources/Strings.resx +++ b/src/libraries/System.Text.Json/gen/Resources/Strings.resx @@ -1,17 +1,17 @@ - @@ -193,7 +193,7 @@ C# language version not supported by the source generator. - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. Constructor annotated with JsonConstructorAttribute is inaccessible. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.cs.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.cs.xlf index cc88af350de1e5..96bb2680ffcb70 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.cs.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - Zdrojový generátor System.Text.Json není k dispozici v jazyce C#{0}. Použijte prosím jazykovou verzi {1} nebo vyšší. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + Zdrojový generátor System.Text.Json není k dispozici v jazyce C# {0}. Použijte prosím jazykovou verzi {1} nebo vyšší. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.de.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.de.xlf index d23d07843bb212..dbc6ca2ae2b90a 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.de.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.de.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - Der System.Text.Json-Quellgenerator ist in C# „{0}“ nicht verfügbar. Verwenden Sie die Sprachversion {1} oder höher. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + Der System.Text.Json-Quellgenerator ist in C# {0} nicht verfügbar. Verwenden Sie die Sprachversion {1} oder höher. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.es.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.es.xlf index 989c5a8a02f263..878fa15bc7efb1 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.es.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.es.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - El generador de origen System.Text.Json no está disponible en C# '{0}'. Use la versión de idioma {1} o superior. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + El generador de origen System.Text.Json no está disponible en C# {0}. Use la versión de idioma {1} o superior. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.fr.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.fr.xlf index 4e1998a8f6e376..fbf12b5a733d49 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.fr.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - Le générateur de source System.Text.Json n'est pas disponible en C# '{0}'. Veuillez utiliser la version linguistique {1} ou supérieure. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + Le générateur de source System.Text.Json n'est pas disponible en C# « {0} ». Veuillez utiliser la version linguistique {1} ou supérieure. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.it.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.it.xlf index 5cb68b207cab19..04d5cf68ef4a94 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.it.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.it.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - Il generatore di origine System.Text.Json non è disponibile in C# '{0}'. Usare la versione del linguaggio {1} o successiva. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + Il generatore di origine System.Text.Json non è disponibile in C# {0}. Usare la versione del linguaggio {1} o successiva. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ja.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ja.xlf index a25a901571f489..0224cb65fc26e3 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ja.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - System.Text.Json ソース ジェネレーターは C# '{0}' では使用できません。言語バージョン {1} 以上を使用してください。 + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + System.Text.Json ソース ジェネレーターは、C# {0}では使用できません。言語バージョン {1} 以上を使用してください。 diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ko.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ko.xlf index dec5ff53a5aa12..9bc14e0cc3d556 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ko.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - System.Text.Json 원본 생성기는 C# '{0}'에서 사용할 수 없습니다. {1} 이상의 언어 버전을 사용하세요. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + System.Text.Json 원본 생성기는 C# {0}에서 사용할 수 없습니다. {1} 이상의 언어 버전을 사용하세요. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pl.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pl.xlf index bb1a603dbaf458..d8e7a990013fac 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pl.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - Generator źródła System.Text.Json nie jest dostępny w języku C# „{0}”. Użyj wersji językowej lub nowszej {1} . + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + Generator źródła System.Text.Json nie jest dostępny w języku C# {0}. Użyj wersji językowej {1} lub nowszej. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pt-BR.xlf index a4cb21e691b5dd..d31fa77f16ee47 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.pt-BR.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - O gerador de fonte System.Text.Json não está disponível em C# '{0}'. Use a versão do idioma {1} ou superior. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + O gerador de origem System.Text.Json não está disponível em C# {0}. Use a versão do idioma {1} ou superior. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ru.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ru.xlf index d1c3241bc43f22..de31bb76e6f2d9 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.ru.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - Генератор исходного кода System.Text.Json не доступен в C# "{0}". Используйте языковую версию {1} или выше. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + Генератор исходного кода System.Text.Json не доступен в C# {0}. Используйте языковую версию {1} или выше. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.tr.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.tr.xlf index 72b706fdc31a64..3ab7559fc82284 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.tr.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - System.Text.Json kaynak oluşturucusu C# '{0}' içinde kullanılamıyor. Lütfen dil sürümü {1} veya üstü sürümü kullanın. + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + System.Text.Json kaynak oluşturucusu C# {0}'ta mevcut değildir. Lütfen {1} dil sürümünü veya daha üstünü kullanın. diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hans.xlf index 3450d312816e01..03ce3792f5ec8a 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hans.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - System.Text.Json 源生成器在 C#“{0}”中不可用。请使用{1}或更高版本的语言版本。 + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + System.Text.Json 源生成器在 C# {0} 中不可用。请使用{1}或更高版本的语言版本。 diff --git a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hant.xlf index 176989f5fc53a6..ed207add481d94 100644 --- a/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/System.Text.Json/gen/Resources/xlf/Strings.zh-Hant.xlf @@ -113,8 +113,8 @@ - The System.Text.Json source generator is not available in C# '{0}'. Please use language version {1} or greater. - C# '{0}' 中無法使用 System.Text.Json 來源產生器。請使用 {1} 或更新的語言版本。 + The System.Text.Json source generator is not available in C# {0}. Please use language version {1} or greater. + C# {0} 中無法使用 System.Text.Json 來源產生器。請使用 {1} 或更新的語言版本。 diff --git a/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets b/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets index 364f6e1f6682f7..23add6278d7c07 100644 --- a/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets +++ b/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets @@ -30,7 +30,11 @@ + + + + @@ -53,9 +57,7 @@ - - @@ -73,6 +75,5 @@ - diff --git a/src/libraries/System.Text.Json/src/PACKAGE.md b/src/libraries/System.Text.Json/src/PACKAGE.md index 1bfcd1da44e258..1ddd210a0acb28 100644 --- a/src/libraries/System.Text.Json/src/PACKAGE.md +++ b/src/libraries/System.Text.Json/src/PACKAGE.md @@ -237,10 +237,10 @@ The main types provided by this library are: * `System.Text.Json.Nodes.JsonNode` * `System.Text.Json.Serialization.Metadata.JsonTypeInfo` -## Addtional Documentation +## Additional Documentation -* [Conceptual documentation](https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/overview) -* [API documentation](https://learn.microsoft.com/en-us/dotnet/api/system.text.json) +* [Conceptual documentation](https://learn.microsoft.com/dotnet/standard/serialization/system-text-json/overview) +* [API documentation](https://learn.microsoft.com/dotnet/api/system.text.json) ## Related Packages diff --git a/src/libraries/System.Text.Json/src/Resources/Strings.resx b/src/libraries/System.Text.Json/src/Resources/Strings.resx index f091984783b601..0ebab3e5d27d6f 100644 --- a/src/libraries/System.Text.Json/src/Resources/Strings.resx +++ b/src/libraries/System.Text.Json/src/Resources/Strings.resx @@ -696,6 +696,9 @@ JsonObjectCreationHandling.Populate is incompatible with reference handling. + + JsonObjectCreationHandling.Populate is currently not supported in types with parameterized constructors. + Either the JSON value is not in a supported format, or is out of bounds for an Int128. diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs index 304ca0a26e409c..7c5d6aae1c4051 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs @@ -25,10 +25,10 @@ internal sealed override bool OnTryRead(ref Utf8JsonReader reader, Type typeToCo { JsonTypeInfo jsonTypeInfo = state.Current.JsonTypeInfo; - if (jsonTypeInfo.CreateObject != null || state.Current.IsPopulating) + if (!jsonTypeInfo.UsesParameterizedConstructor || state.Current.IsPopulating) { // Fall back to default object converter in following cases: - // - if user has set a default constructor delegate with contract customization + // - if user configuration has invalidated the parameterized constructor // - we're continuing populating an object. return base.OnTryRead(ref reader, typeToConvert, options, ref state, out value); } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs index 7ec9db9adc7739..3f929d87378485 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs @@ -38,18 +38,26 @@ internal static MemberAccessor MemberAccessor private static JsonTypeInfo CreateTypeInfoCore(Type type, JsonConverter converter, JsonSerializerOptions options) { JsonTypeInfo typeInfo = JsonTypeInfo.CreateJsonTypeInfo(type, converter, options); - typeInfo.NumberHandling = GetNumberHandlingForType(typeInfo.Type); - typeInfo.PreferredPropertyObjectCreationHandling = GetObjectCreationHandlingForType(typeInfo.Type); - if (typeInfo.Kind == JsonTypeInfoKind.Object) + if (GetNumberHandlingForType(typeInfo.Type) is { } numberHandling) { - typeInfo.UnmappedMemberHandling = GetUnmappedMemberHandling(typeInfo.Type); + typeInfo.NumberHandling = numberHandling; + } + + if (GetObjectCreationHandlingForType(typeInfo.Type) is { } creationHandling) + { + typeInfo.PreferredPropertyObjectCreationHandling = creationHandling; + } + + if (GetUnmappedMemberHandling(typeInfo.Type) is { } unmappedMemberHandling) + { + typeInfo.UnmappedMemberHandling = unmappedMemberHandling; } typeInfo.PopulatePolymorphismMetadata(); typeInfo.MapInterfaceTypesToCallbacks(); - Func? createObject = MemberAccessor.CreateConstructor(typeInfo.Type); + Func? createObject = DetermineCreateObjectDelegate(type, converter); typeInfo.SetCreateObjectIfCompatible(createObject); typeInfo.CreateObjectForExtensionDataProperty = createObject; @@ -80,7 +88,7 @@ private static void PopulateProperties(JsonTypeInfo typeInfo) bool constructorHasSetsRequiredMembersAttribute = typeInfo.Converter.ConstructorInfo?.HasSetsRequiredMembersAttribute() ?? false; - JsonTypeInfo.PropertyHierarchyResolutionState state = new(); + JsonTypeInfo.PropertyHierarchyResolutionState state = new(typeInfo.Options); // Walk the type hierarchy starting from the current type up to the base type(s) foreach (Type currentType in typeInfo.Type.GetSortedTypeHierarchy()) @@ -411,5 +419,24 @@ internal static void DeterminePropertyAccessors(JsonPropertyInfo jsonPrope break; } } + + [RequiresUnreferencedCode(JsonSerializer.SerializationUnreferencedCodeMessage)] + [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)] + private static Func? DetermineCreateObjectDelegate(Type type, JsonConverter converter) + { + ConstructorInfo? defaultCtor = null; + + if (converter.ConstructorInfo != null && !converter.ConstructorIsParameterized) + { + // A parameterless constructor has been resolved by the converter + // (e.g. it might be a non-public ctor with JsonConverterAttribute). + defaultCtor = converter.ConstructorInfo; + } + + // Fall back to resolving any public constructors on the type. + defaultCtor ??= type.GetConstructor(BindingFlags.Public | BindingFlags.Instance, binder: null, Type.EmptyTypes, modifiers: null); + + return MemberAccessor.CreateParameterlessConstructor(type, defaultCtor); + } } } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs index 1b7113f9dd758a..965b4cea39570a 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs @@ -137,7 +137,7 @@ internal static void PopulateProperties(JsonTypeInfo typeInfo, JsonTypeInfo.Json // Regardless of the source generator we need to re-run the naming conflict resolution algorithm // at run time since it is possible that the naming policy or other configs can be different then. - JsonTypeInfo.PropertyHierarchyResolutionState state = new(); + JsonTypeInfo.PropertyHierarchyResolutionState state = new(typeInfo.Options); foreach (JsonPropertyInfo jsonPropertyInfo in properties) { diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs index 0ab5c08d7825b3..e2234093474e0d 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs @@ -495,9 +495,17 @@ private void DetermineEffectiveObjectCreationHandlingForProperty() Debug.Assert(ParentTypeInfo != null, "We should have ensured parent is assigned in JsonTypeInfo"); Debug.Assert(!IsConfigured, "Should not be called post-configuration."); + JsonObjectCreationHandling effectiveObjectCreationHandling = JsonObjectCreationHandling.Replace; if (ObjectCreationHandling == null) { - JsonObjectCreationHandling preferredCreationHandling = ParentTypeInfo.PreferredPropertyObjectCreationHandling ?? Options.PreferredObjectCreationHandling; + // Consult type-level configuration, then global configuration. + // Ignore global configuration if we're using a parameterized constructor. + JsonObjectCreationHandling preferredCreationHandling = + ParentTypeInfo.PreferredPropertyObjectCreationHandling + ?? (ParentTypeInfo.DetermineUsesParameterizedConstructor() + ? JsonObjectCreationHandling.Replace + : Options.PreferredObjectCreationHandling); + bool canPopulate = preferredCreationHandling == JsonObjectCreationHandling.Populate && EffectiveConverter.CanPopulate && @@ -506,7 +514,7 @@ private void DetermineEffectiveObjectCreationHandlingForProperty() !ParentTypeInfo.SupportsPolymorphicDeserialization && !(Set == null && IgnoreReadOnlyMember); - EffectiveObjectCreationHandling = canPopulate ? JsonObjectCreationHandling.Populate : JsonObjectCreationHandling.Replace; + effectiveObjectCreationHandling = canPopulate ? JsonObjectCreationHandling.Populate : JsonObjectCreationHandling.Replace; } else if (ObjectCreationHandling == JsonObjectCreationHandling.Populate) { @@ -537,18 +545,24 @@ private void DetermineEffectiveObjectCreationHandlingForProperty() ThrowHelper.ThrowInvalidOperationException_ObjectCreationHandlingPropertyCannotAllowReadOnlyMember(this); } - EffectiveObjectCreationHandling = JsonObjectCreationHandling.Populate; - } - else - { - Debug.Assert(EffectiveObjectCreationHandling == JsonObjectCreationHandling.Replace); + effectiveObjectCreationHandling = JsonObjectCreationHandling.Populate; } - if (EffectiveObjectCreationHandling == JsonObjectCreationHandling.Populate && - Options.ReferenceHandlingStrategy != ReferenceHandlingStrategy.None) + if (effectiveObjectCreationHandling is JsonObjectCreationHandling.Populate) { - ThrowHelper.ThrowInvalidOperationException_ObjectCreationHandlingPropertyCannotAllowReferenceHandling(); + if (ParentTypeInfo.DetermineUsesParameterizedConstructor()) + { + ThrowHelper.ThrowNotSupportedException_ObjectCreationHandlingPropertyDoesNotSupportParameterizedConstructors(); + } + + if (Options.ReferenceHandlingStrategy != ReferenceHandlingStrategy.None) + { + ThrowHelper.ThrowInvalidOperationException_ObjectCreationHandlingPropertyCannotAllowReferenceHandling(); + } } + + // Validation complete, commit configuration. + EffectiveObjectCreationHandling = effectiveObjectCreationHandling; } private bool NumberHandingIsApplicable() diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs index 86a7af256a78a8..5a901fbb80eae6 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs @@ -35,6 +35,14 @@ public abstract partial class JsonTypeInfo // All of the serializable parameters on a POCO constructor keyed on parameter name. // Only parameters which bind to properties are cached. internal JsonPropertyDictionary? ParameterCache { get; private set; } + internal bool UsesParameterizedConstructor + { + get + { + Debug.Assert(IsConfigured); + return ParameterCache != null; + } + } // All of the serializable properties on a POCO (except the optional extension property) keyed on property name. internal JsonPropertyDictionary? PropertyCache { get; private set; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs index b9e9fe60d2b23f..668e0c7b15e1a1 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs @@ -552,17 +552,14 @@ public JsonObjectCreationHandling? PreferredPropertyObjectCreationHandling { VerifyMutable(); - if (value is not null) + if (Kind != JsonTypeInfoKind.Object) { - if (Kind != JsonTypeInfoKind.Object) - { - ThrowHelper.ThrowInvalidOperationException_JsonTypeInfoOperationNotPossibleForKind(Kind); - } + ThrowHelper.ThrowInvalidOperationException_JsonTypeInfoOperationNotPossibleForKind(Kind); + } - if (!JsonSerializer.IsValidCreationHandlingValue(value.Value)) - { - throw new ArgumentOutOfRangeException(nameof(value)); - } + if (value is not null && !JsonSerializer.IsValidCreationHandlingValue(value.Value)) + { + throw new ArgumentOutOfRangeException(nameof(value)); } _preferredPropertyObjectCreationHandling = value; @@ -684,7 +681,7 @@ private void Configure() { ConfigureProperties(); - if (Converter.ConstructorIsParameterized) + if (DetermineUsesParameterizedConstructor()) { ConfigureConstructorParameters(); } @@ -808,6 +805,12 @@ bool IsCurrentNodeCompatible() /// private bool IsCompatibleWithCurrentOptions { get; set; } = true; + /// + /// Determine if the current configuration is compatible with using a parameterized constructor. + /// + internal bool DetermineUsesParameterizedConstructor() + => Converter.ConstructorIsParameterized && CreateObject is null; + #if DEBUG internal string GetPropertyDebugInfo(ReadOnlySpan unescapedPropertyName) { @@ -989,10 +992,9 @@ public JsonPropertyInfo CreateJsonPropertyInfo(Type propertyType, string name) internal abstract ValueTask DeserializeAsObjectAsync(Stream utf8Json, CancellationToken cancellationToken); internal abstract object? DeserializeAsObject(Stream utf8Json); - internal ref struct PropertyHierarchyResolutionState + internal ref struct PropertyHierarchyResolutionState(JsonSerializerOptions options) { - public PropertyHierarchyResolutionState() { } - public Dictionary AddedProperties = new(); + public Dictionary AddedProperties = new(options.PropertyNameCaseInsensitive ? StringComparer.OrdinalIgnoreCase : StringComparer.Ordinal); public Dictionary? IgnoredProperties; public bool IsPropertyOrderSpecified; } @@ -1107,7 +1109,7 @@ internal void ConfigureProperties() internal void ConfigureConstructorParameters() { Debug.Assert(Kind == JsonTypeInfoKind.Object); - Debug.Assert(Converter.ConstructorIsParameterized); + Debug.Assert(DetermineUsesParameterizedConstructor()); Debug.Assert(PropertyCache is not null); Debug.Assert(ParameterCache is null); diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/MemberAccessor.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/MemberAccessor.cs index 326ab657b8899d..ff6c442fa488cb 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/MemberAccessor.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/MemberAccessor.cs @@ -9,8 +9,9 @@ namespace System.Text.Json.Serialization.Metadata { internal abstract class MemberAccessor { - public abstract Func? CreateConstructor( - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type classType); + public abstract Func? CreateParameterlessConstructor( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type, + ConstructorInfo? constructorInfo); public abstract Func CreateParameterizedConstructor(ConstructorInfo constructor); diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitCachingMemberAccessor.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitCachingMemberAccessor.cs index a243f4be4d86a6..e30f87d76da9f1 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitCachingMemberAccessor.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitCachingMemberAccessor.cs @@ -21,12 +21,12 @@ internal sealed partial class ReflectionEmitCachingMemberAccessor : MemberAccess => s_cache.GetOrAdd((nameof(CreateAddMethodDelegate), typeof(TCollection), null), static (_) => s_sourceAccessor.CreateAddMethodDelegate()); - public override Func? CreateConstructor([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type classType) - => s_cache.GetOrAdd((nameof(CreateConstructor), classType, null), + public override Func? CreateParameterlessConstructor([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type, ConstructorInfo? ctorInfo) + => s_cache.GetOrAdd((nameof(CreateParameterlessConstructor), type, ctorInfo), [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2077:UnrecognizedReflectionPattern", Justification = "Cannot apply DynamicallyAccessedMembersAttribute to tuple properties.")] #pragma warning disable IL2077 // The suppression doesn't work for the trim analyzer: https://github.com/dotnet/roslyn/issues/59746 - static (key) => s_sourceAccessor.CreateConstructor(key.declaringType)); + static (key) => s_sourceAccessor.CreateParameterlessConstructor(key.declaringType, (ConstructorInfo?)key.member)); #pragma warning restore IL2077 public override Func CreateFieldGetter(FieldInfo fieldInfo) diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitMemberAccessor.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitMemberAccessor.cs index 6e05e5c8057ac7..4b0b426bdaa9ef 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitMemberAccessor.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionEmitMemberAccessor.cs @@ -13,18 +13,19 @@ namespace System.Text.Json.Serialization.Metadata [RequiresDynamicCode(JsonSerializer.SerializationRequiresDynamicCodeMessage)] internal sealed class ReflectionEmitMemberAccessor : MemberAccessor { - public override Func? CreateConstructor( - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type) + public override Func? CreateParameterlessConstructor( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type, + ConstructorInfo? constructorInfo) { Debug.Assert(type != null); - ConstructorInfo? realMethod = type.GetConstructor(BindingFlags.Public | BindingFlags.Instance, binder: null, Type.EmptyTypes, modifiers: null); + Debug.Assert(constructorInfo is null || constructorInfo.GetParameters().Length == 0); if (type.IsAbstract) { return null; } - if (realMethod == null && !type.IsValueType) + if (constructorInfo is null && !type.IsValueType) { return null; } @@ -38,8 +39,10 @@ internal sealed class ReflectionEmitMemberAccessor : MemberAccessor ILGenerator generator = dynamicMethod.GetILGenerator(); - if (realMethod == null) + if (constructorInfo is null) { + Debug.Assert(type.IsValueType); + LocalBuilder local = generator.DeclareLocal(type); generator.Emit(OpCodes.Ldloca_S, local); @@ -49,7 +52,7 @@ internal sealed class ReflectionEmitMemberAccessor : MemberAccessor } else { - generator.Emit(OpCodes.Newobj, realMethod); + generator.Emit(OpCodes.Newobj, constructorInfo); if (type.IsValueType) { // Since C# 10 it's now possible to have parameterless constructors in structs diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionMemberAccessor.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionMemberAccessor.cs index 606ad6aba79c26..8627a24f3f4926 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionMemberAccessor.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/ReflectionMemberAccessor.cs @@ -10,35 +10,26 @@ namespace System.Text.Json.Serialization.Metadata { internal sealed class ReflectionMemberAccessor : MemberAccessor { - private sealed class ConstructorContext - { - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] - private readonly Type _type; - - public ConstructorContext([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type) - => _type = type; - - public object? CreateInstance() - => Activator.CreateInstance(_type, nonPublic: false); - } - - public override Func? CreateConstructor( - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type) + public override Func? CreateParameterlessConstructor( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type type, + ConstructorInfo? ctorInfo) { Debug.Assert(type != null); - ConstructorInfo? realMethod = type.GetConstructor(BindingFlags.Public | BindingFlags.Instance, binder: null, Type.EmptyTypes, modifiers: null); + Debug.Assert(ctorInfo is null || ctorInfo.GetParameters().Length == 0); if (type.IsAbstract) { return null; } - if (realMethod == null && !type.IsValueType) + if (ctorInfo is null) { - return null; + return type.IsValueType + ? () => Activator.CreateInstance(type, nonPublic: false)! + : null; } - return new ConstructorContext(type).CreateInstance!; + return () => ctorInfo.Invoke(null); } public override Func CreateParameterizedConstructor(ConstructorInfo constructor) diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs index 59a47bc3ac7bc6..25c067cc10930a 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs @@ -386,20 +386,20 @@ public JsonTypeInfo GetTopJsonTypeInfoWithParameterizedConstructor() for (int i = 0; i < _count - 1; i++) { - if (_stack[i].JsonTypeInfo.Converter.ConstructorIsParameterized) + if (_stack[i].JsonTypeInfo.UsesParameterizedConstructor) { return _stack[i].JsonTypeInfo; } } - Debug.Assert(Current.JsonTypeInfo.Converter.ConstructorIsParameterized); + Debug.Assert(Current.JsonTypeInfo.UsesParameterizedConstructor); return Current.JsonTypeInfo; } [MethodImpl(MethodImplOptions.AggressiveInlining)] private void SetConstructorArgumentState() { - if (Current.JsonTypeInfo.Converter.ConstructorIsParameterized) + if (Current.JsonTypeInfo.UsesParameterizedConstructor) { Current.CtorArgumentState ??= new(); } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs b/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs index 7072d9e3020085..5b05ff243a80a9 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs @@ -99,6 +99,12 @@ public static void ThrowInvalidOperationException_ObjectCreationHandlingProperty throw new InvalidOperationException(SR.ObjectCreationHandlingPropertyCannotAllowReferenceHandling); } + [DoesNotReturn] + public static void ThrowNotSupportedException_ObjectCreationHandlingPropertyDoesNotSupportParameterizedConstructors() + { + throw new NotSupportedException(SR.ObjectCreationHandlingPropertyDoesNotSupportParameterizedConstructors); + } + [DoesNotReturn] public static void ThrowJsonException_SerializationConverterRead(JsonConverter? converter) { diff --git a/src/libraries/System.Text.Json/tests/Common/ConstructorTests/ConstructorTests.AttributePresence.cs b/src/libraries/System.Text.Json/tests/Common/ConstructorTests/ConstructorTests.AttributePresence.cs index 681d495b8c752c..f54692a5c59a44 100644 --- a/src/libraries/System.Text.Json/tests/Common/ConstructorTests/ConstructorTests.AttributePresence.cs +++ b/src/libraries/System.Text.Json/tests/Common/ConstructorTests/ConstructorTests.AttributePresence.cs @@ -39,6 +39,24 @@ public async Task NonPublicCtors_WithJsonConstructorAttribute_WorksAsExpected(Ty } } + [Theory] + [InlineData(typeof(PrivateParameterlessCtor_WithAttribute), false)] + [InlineData(typeof(InternalParameterlessCtor_WithAttribute), true)] + [InlineData(typeof(ProtectedParameterlessCtor_WithAttribute), false)] + public async Task NonPublicParameterlessCtors_WithJsonConstructorAttribute_WorksAsExpected(Type type, bool isAccessibleBySourceGen) + { + if (!Serializer.IsSourceGeneratedSerializer || isAccessibleBySourceGen) + { + object? result = await Serializer.DeserializeWrapper("{}", type); + Assert.IsType(type, result); + } + else + { + NotSupportedException ex = await Assert.ThrowsAsync(() => Serializer.DeserializeWrapper("{}", type)); + Assert.Contains("JsonConstructorAttribute", ex.ToString()); + } + } + [Fact] public async Task SinglePublicParameterizedCtor_SingleParameterlessCtor_NoAttribute_Supported_UseParameterlessCtor() { diff --git a/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs b/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs index aae9da6b2c628b..e25adffef53fa2 100644 --- a/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs +++ b/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs @@ -1165,4 +1165,52 @@ public class ClassWithInvalidPropertyAnnotation [JsonObjectCreationHandling((JsonObjectCreationHandling)(-1))] public List Property { get; } } + + [Theory] + [InlineData(typeof(ClassWithParameterizedConstructorWithPopulateProperty))] + [InlineData(typeof(ClassWithParameterizedConstructorWithPopulateType))] + public async Task ClassWithParameterizedCtor_UsingPopulateConfiguration_ThrowsNotSupportedException(Type type) + { + object instance = Activator.CreateInstance(type, "Jim"); + string json = """{"Username":"Jim","PhoneNumbers":["123456"]}"""; + + await Assert.ThrowsAsync(() => Serializer.SerializeWrapper(instance, type)); + await Assert.ThrowsAsync(() => Serializer.DeserializeWrapper(json, type)); + Assert.Throws(() => Serializer.GetTypeInfo(type)); + } + + public class ClassWithParameterizedConstructorWithPopulateProperty(string name) + { + public string Name { get; } = name; + + [JsonObjectCreationHandling(JsonObjectCreationHandling.Populate)] + public List PhoneNumbers { get; } = new(); + } + + [JsonObjectCreationHandling(JsonObjectCreationHandling.Populate)] + public class ClassWithParameterizedConstructorWithPopulateType(string name) + { + public string Name { get; } = name; + + public List PhoneNumbers { get; } = new(); + } + + [Fact] + public async Task ClassWithParameterizedCtor_NoPopulateConfiguration_WorksWithGlobalPopulateConfiguration() + { + string json = """{"Username":"Jim","PhoneNumbers":["123456"]}"""; + + JsonSerializerOptions options = Serializer.CreateOptions(makeReadOnly: false); + options.PreferredObjectCreationHandling = JsonObjectCreationHandling.Populate; + + ClassWithParameterizedConstructorNoPopulate result = await Serializer.DeserializeWrapper(json, options); + Assert.Empty(result.PhoneNumbers); + } + + public class ClassWithParameterizedConstructorNoPopulate(string name) + { + public string Name { get; } = name; + + public List PhoneNumbers { get; } = new(); + } } diff --git a/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs b/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs index 4295359c6f0380..021481ae5a1362 100644 --- a/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs +++ b/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs @@ -494,5 +494,34 @@ public class ClassWithSpecialCharacters [JsonPropertyName("\uA000_2")] // Valid C# property name: \uA000_2 public int YiIt_2 { get; set; } } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ClassWithIgnoredCaseInsensitiveConflict_RespectsIgnoredMember(bool propertyNameCaseInsensitive) + { + // Regression test for https://github.com/dotnet/runtime/issues/93903 + // specifically for propertyNameCaseInsensitive := true + + JsonSerializerOptions options = Serializer.CreateOptions(makeReadOnly: false); + options.PropertyNameCaseInsensitive = propertyNameCaseInsensitive; + + var value = new ClassWithIgnoredCaseInsensitiveConflict { name = "lowercase", Name = "uppercase" }; + string json = await Serializer.SerializeWrapper(value, options); + + Assert.Equal("""{"name":"lowercase"}""", json); + + value = await Serializer.DeserializeWrapper(json, options); + Assert.Equal("lowercase", value.name); + Assert.Null(value.Name); + } + + public class ClassWithIgnoredCaseInsensitiveConflict + { + public string name { get; set; } + + [JsonIgnore] + public string Name { get; set; } + } } } diff --git a/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.Constructor.cs b/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.Constructor.cs index 285c5b78934733..f12cf89e41de59 100644 --- a/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.Constructor.cs +++ b/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.Constructor.cs @@ -74,6 +74,33 @@ public class ProtectedParameterizedCtor_WithAttribute protected ProtectedParameterizedCtor_WithAttribute(int x) => X = x; } + public class PrivateParameterlessCtor_WithAttribute + { + public int X { get; } + + [JsonConstructor] + private PrivateParameterlessCtor_WithAttribute() + => X = 42; + } + + public class ProtectedParameterlessCtor_WithAttribute + { + public int X { get; } + + [JsonConstructor] + protected ProtectedParameterlessCtor_WithAttribute() + => X = 42; + } + + public class InternalParameterlessCtor_WithAttribute + { + public int X { get; } + + [JsonConstructor] + internal InternalParameterlessCtor_WithAttribute() + => X = 42; + } + public class PrivateParameterlessCtor_InternalParameterizedCtor_WithMultipleAttributes { [JsonConstructor] diff --git a/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs b/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs index 79859c1e73cc5f..470d624d3646fb 100644 --- a/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs +++ b/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs @@ -1913,69 +1913,81 @@ public override string ConvertName(string name) } } + public static class ReflectionExtensions + { +#if NET6_0_OR_GREATER + [return: System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] + public static Type WithConstructors( + [System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] + this Type type) => type; +#else + public static Type WithConstructors(this Type type) => type; +#endif + } + public static class CollectionTestTypes { public static IEnumerable EnumerableTypes() { - yield return typeof(TElement[]); // ArrayConverter - yield return typeof(ConcurrentQueue); // ConcurrentQueueOfTConverter - yield return typeof(GenericICollectionWrapper); // ICollectionOfTConverter - yield return typeof(WrapperForIEnumerable); // IEnumerableConverter - yield return typeof(WrapperForIReadOnlyCollectionOfT); // IEnumerableOfTConverter - yield return typeof(Queue); // IEnumerableWithAddMethodConverter - yield return typeof(WrapperForIList); // IListConverter - yield return typeof(Collection); // IListOfTConverter - yield return typeof(ImmutableList); // ImmutableEnumerableOfTConverter - yield return typeof(HashSet); // ISetOfTConverter - yield return typeof(List); // ListOfTConverter - yield return typeof(Queue); // QueueOfTConverter + yield return typeof(TElement[]).WithConstructors(); // ArrayConverter + yield return typeof(ConcurrentQueue).WithConstructors(); // ConcurrentQueueOfTConverter + yield return typeof(GenericICollectionWrapper).WithConstructors(); // ICollectionOfTConverter + yield return typeof(WrapperForIEnumerable).WithConstructors(); // IEnumerableConverter + yield return typeof(WrapperForIReadOnlyCollectionOfT).WithConstructors(); // IEnumerableOfTConverter + yield return typeof(Queue).WithConstructors(); // IEnumerableWithAddMethodConverter + yield return typeof(WrapperForIList).WithConstructors(); // IListConverter + yield return typeof(Collection).WithConstructors(); // IListOfTConverter + yield return typeof(ImmutableList).WithConstructors(); // ImmutableEnumerableOfTConverter + yield return typeof(HashSet).WithConstructors(); // ISetOfTConverter + yield return typeof(List).WithConstructors(); // ListOfTConverter + yield return typeof(Queue).WithConstructors(); // QueueOfTConverter } public static IEnumerable DeserializableGenericEnumerableTypes() { - yield return typeof(TElement[]); // ArrayConverter - yield return typeof(ConcurrentQueue); // ConcurrentQueueOfTConverter - yield return typeof(GenericICollectionWrapper); // ICollectionOfTConverter - yield return typeof(IEnumerable); // IEnumerableConverter - yield return typeof(Collection); // IListOfTConverter - yield return typeof(ImmutableList); // ImmutableEnumerableOfTConverter - yield return typeof(HashSet); // ISetOfTConverter - yield return typeof(List); // ListOfTConverter - yield return typeof(Queue); // QueueOfTConverter + yield return typeof(TElement[]).WithConstructors(); // ArrayConverter + yield return typeof(ConcurrentQueue).WithConstructors(); // ConcurrentQueueOfTConverter + yield return typeof(GenericICollectionWrapper).WithConstructors(); // ICollectionOfTConverter + yield return typeof(IEnumerable).WithConstructors(); // IEnumerableConverter + yield return typeof(Collection).WithConstructors(); // IListOfTConverter + yield return typeof(ImmutableList).WithConstructors(); // ImmutableEnumerableOfTConverter + yield return typeof(HashSet).WithConstructors(); // ISetOfTConverter + yield return typeof(List).WithConstructors(); // ListOfTConverter + yield return typeof(Queue).WithConstructors(); // QueueOfTConverter } public static IEnumerable DeserializableNonGenericEnumerableTypes() { - yield return typeof(Queue); // IEnumerableWithAddMethodConverter - yield return typeof(WrapperForIList); // IListConverter + yield return typeof(Queue).WithConstructors(); // IEnumerableWithAddMethodConverter + yield return typeof(WrapperForIList).WithConstructors(); // IListConverter } public static IEnumerable DictionaryTypes() { - yield return typeof(Dictionary); // DictionaryOfStringTValueConverter - yield return typeof(Hashtable); // IDictionaryConverter - yield return typeof(ConcurrentDictionary); // IDictionaryOfStringTValueConverter - yield return typeof(GenericIDictionaryWrapper); // IDictionaryOfStringTValueConverter - yield return typeof(ImmutableDictionary); // ImmutableDictionaryOfStringTValueConverter - yield return typeof(GenericIReadOnlyDictionaryWrapper); // IReadOnlyDictionaryOfStringTValueConverter + yield return typeof(Dictionary).WithConstructors(); // DictionaryOfStringTValueConverter + yield return typeof(Hashtable).WithConstructors(); // IDictionaryConverter + yield return typeof(ConcurrentDictionary).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(GenericIDictionaryWrapper).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(ImmutableDictionary).WithConstructors(); // ImmutableDictionaryOfStringTValueConverter + yield return typeof(GenericIReadOnlyDictionaryWrapper).WithConstructors(); // IReadOnlyDictionaryOfStringTValueConverter } public static IEnumerable DeserializableDictionaryTypes() { - yield return typeof(Dictionary); // DictionaryOfStringTValueConverter - yield return typeof(Hashtable); // IDictionaryConverter - yield return typeof(IDictionary); // IDictionaryConverter - yield return typeof(ConcurrentDictionary); // IDictionaryOfStringTValueConverter - yield return typeof(IDictionary); // IDictionaryOfStringTValueConverter - yield return typeof(GenericIDictionaryWrapper); // IDictionaryOfStringTValueConverter - yield return typeof(ImmutableDictionary); // ImmutableDictionaryOfStringTValueConverter - yield return typeof(IReadOnlyDictionary); // IReadOnlyDictionaryOfStringTValueConverter + yield return typeof(Dictionary).WithConstructors(); // DictionaryOfStringTValueConverter + yield return typeof(Hashtable).WithConstructors(); // IDictionaryConverter + yield return typeof(IDictionary).WithConstructors(); // IDictionaryConverter + yield return typeof(ConcurrentDictionary).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(IDictionary).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(GenericIDictionaryWrapper).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(ImmutableDictionary).WithConstructors(); // ImmutableDictionaryOfStringTValueConverter + yield return typeof(IReadOnlyDictionary).WithConstructors(); // IReadOnlyDictionaryOfStringTValueConverter } public static IEnumerable DeserializableNonGenericDictionaryTypes() { - yield return typeof(Hashtable); // IDictionaryConverter - yield return typeof(SortedList); // IDictionaryConverter + yield return typeof(Hashtable).WithConstructors(); // IDictionaryConverter + yield return typeof(SortedList).WithConstructors(); // IDictionaryConverter } } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/ConstructorTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/ConstructorTests.cs index b2a5fd465da766..d1769f491104db 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/ConstructorTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/ConstructorTests.cs @@ -41,6 +41,9 @@ protected ConstructorTests_Metadata(JsonSerializerWrapper stringWrapper) [JsonSerializable(typeof(PrivateParameterizedCtor_WithAttribute))] [JsonSerializable(typeof(InternalParameterizedCtor_WithAttribute))] [JsonSerializable(typeof(ProtectedParameterizedCtor_WithAttribute))] + [JsonSerializable(typeof(PrivateParameterlessCtor_WithAttribute))] + [JsonSerializable(typeof(InternalParameterlessCtor_WithAttribute))] + [JsonSerializable(typeof(ProtectedParameterlessCtor_WithAttribute))] [JsonSerializable(typeof(SinglePublicParameterizedCtor))] [JsonSerializable(typeof(SingleParameterlessCtor_MultiplePublicParameterizedCtor))] [JsonSerializable(typeof(SingleParameterlessCtor_MultiplePublicParameterizedCtor_Struct))] @@ -186,6 +189,9 @@ public ConstructorTests_Default(JsonSerializerWrapper jsonSerializer) : base(jso [JsonSerializable(typeof(PrivateParameterizedCtor_WithAttribute))] [JsonSerializable(typeof(InternalParameterizedCtor_WithAttribute))] [JsonSerializable(typeof(ProtectedParameterizedCtor_WithAttribute))] + [JsonSerializable(typeof(PrivateParameterlessCtor_WithAttribute))] + [JsonSerializable(typeof(InternalParameterlessCtor_WithAttribute))] + [JsonSerializable(typeof(ProtectedParameterlessCtor_WithAttribute))] [JsonSerializable(typeof(SinglePublicParameterizedCtor))] [JsonSerializable(typeof(SingleParameterlessCtor_MultiplePublicParameterizedCtor))] [JsonSerializable(typeof(SingleParameterlessCtor_MultiplePublicParameterizedCtor_Struct))] diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs index eab6f939b93674..5862387200f6a4 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs @@ -278,6 +278,9 @@ public sealed class JsonCreationHandlingTests_AsyncStreamWithSmallBuffer() [JsonSerializable(typeof(SimpleClassWitNonPopulatableProperty))] [JsonSerializable(typeof(ClassWithInvalidTypeAnnotation))] [JsonSerializable(typeof(ClassWithInvalidPropertyAnnotation))] + [JsonSerializable(typeof(ClassWithParameterizedConstructorWithPopulateProperty))] + [JsonSerializable(typeof(ClassWithParameterizedConstructorWithPopulateType))] + [JsonSerializable(typeof(ClassWithParameterizedConstructorNoPopulate))] internal partial class CreationHandlingTestContext : JsonSerializerContext { } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs index 82566bf7123ce7..e512451eed72bc 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs @@ -28,6 +28,7 @@ public PropertyNameTests_Metadata() [JsonSerializable(typeof(ObjectPropertyNamesDifferentByCaseOnly_TestClass))] [JsonSerializable(typeof(OverridePropertyNameDesignTime_TestClass))] [JsonSerializable(typeof(SimpleTestClass))] + [JsonSerializable(typeof(ClassWithIgnoredCaseInsensitiveConflict))] internal sealed partial class PropertyNameTestsContext_Metadata : JsonSerializerContext { } @@ -53,6 +54,7 @@ public PropertyNameTests_Default() [JsonSerializable(typeof(ObjectPropertyNamesDifferentByCaseOnly_TestClass))] [JsonSerializable(typeof(OverridePropertyNameDesignTime_TestClass))] [JsonSerializable(typeof(SimpleTestClass))] + [JsonSerializable(typeof(ClassWithIgnoredCaseInsensitiveConflict))] internal sealed partial class PropertyNameTestsContext_Default : JsonSerializerContext { } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorDiagnosticsTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorDiagnosticsTests.cs index 98f2b71042332b..a554d2681d43d1 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorDiagnosticsTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorDiagnosticsTests.cs @@ -416,7 +416,7 @@ public void SupportedLanguageVersions_SucceedCompilation(LanguageVersion langVer using System.Text.Json.Serialization; namespace HelloWorld - { + { public class MyClass { public MyClass(int value) @@ -457,7 +457,7 @@ public void SupportedLanguageVersions_Memory_SucceedCompilation(LanguageVersion using System.Text.Json.Serialization; namespace HelloWorld - { + { public class MyClass { public MyClass( @@ -504,7 +504,7 @@ public void UnsupportedLanguageVersions_FailCompilation(LanguageVersion langVers using System.Text.Json.Serialization; namespace HelloWorld - { + { public class MyClass { public MyClass(int value) @@ -531,7 +531,7 @@ public partial class MyJsonContext : JsonSerializerContext var expectedDiagnostics = new DiagnosticData[] { - new(DiagnosticSeverity.Error, contextLocation, $"The System.Text.Json source generator is not available in C# '{langVersion.ToDisplayString()}'. Please use language version 9.0 or greater.") + new(DiagnosticSeverity.Error, contextLocation, $"The System.Text.Json source generator is not available in C# {langVersion.ToDisplayString()}. Please use language version 9.0 or greater.") }; CompilationHelper.AssertEqualDiagnosticMessages(expectedDiagnostics, result.Diagnostics); diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs index 7a38a7e5fb5128..daa6498cbc9b2d 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Reflection; using Microsoft.CodeAnalysis; +using SourceGenerators.Tests; using Xunit; namespace System.Text.Json.SourceGeneration.UnitTests @@ -29,7 +30,7 @@ public static void CompilingTheSameSourceResultsInEqualModels(Func ContextGenerationSpec ctx2 = result2.ContextGenerationSpecs[i]; Assert.NotSame(ctx1, ctx2); - AssertStructurallyEqual(ctx1, ctx2); + GeneratorTestHelpers.AssertStructurallyEqual(ctx1, ctx2); Assert.Equal(ctx1, ctx2); Assert.Equal(ctx1.GetHashCode(), ctx2.GetHashCode()); @@ -86,7 +87,7 @@ public partial class JsonContext : JsonSerializerContext { } ContextGenerationSpec ctx2 = result2.ContextGenerationSpecs[0]; Assert.NotSame(ctx1, ctx2); - AssertStructurallyEqual(ctx1, ctx2); + GeneratorTestHelpers.AssertStructurallyEqual(ctx1, ctx2); Assert.Equal(ctx1, ctx2); Assert.Equal(ctx1.GetHashCode(), ctx2.GetHashCode()); @@ -377,74 +378,5 @@ public static IEnumerable GetCompilationHelperFactories() .Where(m => m.ReturnType == typeof(Compilation) && m.GetParameters().Length == 0) .Select(m => new object[] { Delegate.CreateDelegate(typeof(Func), m) }); } - - /// - /// Asserts for structural equality, returning a path to the mismatching data when not equal. - /// - private static void AssertStructurallyEqual(T expected, T actual) - { - CheckAreEqualCore(expected, actual, new()); - static void CheckAreEqualCore(object expected, object actual, Stack path) - { - if (expected is null || actual is null) - { - if (expected is not null || actual is not null) - { - FailNotEqual(); - } - - return; - } - - Type type = expected.GetType(); - if (type != actual.GetType()) - { - FailNotEqual(); - return; - } - - if (expected is IEnumerable leftCollection) - { - if (actual is not IEnumerable rightCollection) - { - FailNotEqual(); - return; - } - - object?[] expectedValues = leftCollection.Cast().ToArray(); - object?[] actualValues = rightCollection.Cast().ToArray(); - - for (int i = 0; i < Math.Max(expectedValues.Length, actualValues.Length); i++) - { - object? expectedElement = i < expectedValues.Length ? expectedValues[i] : ""; - object? actualElement = i < actualValues.Length ? actualValues[i] : ""; - - path.Push($"[{i}]"); - CheckAreEqualCore(expectedElement, actualElement, path); - path.Pop(); - } - } - - if (type.GetProperty("EqualityContract", BindingFlags.Instance | BindingFlags.NonPublic, null, returnType: typeof(Type), types: Array.Empty(), null) != null) - { - // Type is a C# record, run pointwise equality comparison. - foreach (PropertyInfo property in type.GetProperties(BindingFlags.Public | BindingFlags.Instance)) - { - path.Push("." + property.Name); - CheckAreEqualCore(property.GetValue(expected), property.GetValue(actual), path); - path.Pop(); - } - - return; - } - - if (!expected.Equals(actual)) - { - FailNotEqual(); - } - - void FailNotEqual() => Assert.Fail($"Value not equal in ${string.Join("", path.Reverse())}: expected {expected}, but was {actual}."); - } - } } } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets index a700b2a9f3a385..56bf105dc1fddf 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets @@ -12,6 +12,7 @@ + diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs index 06fd59bae037e9..bc6e3a28ff78f7 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs @@ -1440,11 +1440,10 @@ public static void PreferredPropertyObjectCreationHandling_NonObjectKind_ThrowsI { JsonTypeInfo jsonTypeInfo = JsonTypeInfo.CreateJsonTypeInfo(type, new()); - // Invalid kinds default to null and can be set to null. - Assert.Null(jsonTypeInfo.PreferredPropertyObjectCreationHandling); - jsonTypeInfo.PreferredPropertyObjectCreationHandling = null; + // Invalid kinds default to null. Assert.Null(jsonTypeInfo.PreferredPropertyObjectCreationHandling); + Assert.Throws(() => jsonTypeInfo.PreferredPropertyObjectCreationHandling = null); Assert.Throws(() => jsonTypeInfo.PreferredPropertyObjectCreationHandling = JsonObjectCreationHandling.Populate); Assert.Throws(() => jsonTypeInfo.PreferredPropertyObjectCreationHandling = JsonObjectCreationHandling.Replace); Assert.Null(jsonTypeInfo.PreferredPropertyObjectCreationHandling); diff --git a/src/libraries/System.Threading.Channels/src/PACKAGE.md b/src/libraries/System.Threading.Channels/src/PACKAGE.md new file mode 100644 index 00000000000000..f022aaf5ba32f2 --- /dev/null +++ b/src/libraries/System.Threading.Channels/src/PACKAGE.md @@ -0,0 +1,73 @@ +## About + + + +The `System.Threading.Channels` library provides types for passing data asynchronously between producers and consumers. + +## Key Features + + + +* Abstractions representing channels for one or more producers to publish data to one or more consumers +* APIs focused on asynchronous production and consumption of data +* Factory methods for producing multiple kinds of channels + +## How to Use + + + +```C# +using System; +using System.Threading.Channels; +using System.Threading.Tasks; + +Channel channel = Channel.CreateUnbounded(); + +Task producer = Task.Run(async () => +{ + int i = 0; + while (true) + { + channel.Writer.TryWrite(i++); + await Task.Delay(TimeSpan.FromSeconds(1)); + } +}); + +Task consumer = Task.Run(async () => +{ + await foreach (int value in channel.Reader.ReadAllAsync()) + { + Console.WriteLine(value); + } +}); + +await Task.WhenAll(producer, consumer); +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Threading.Channel` +* `System.Threading.Channel` + +## Additional Documentation + + + +* [Overview](https://devblogs.microsoft.com/dotnet/an-introduction-to-system-threading-channels/) +* [API documentation](https://learn.microsoft.com/dotnet/api/system.threading.channels) + +## Related Packages + + + +https://www.nuget.org/packages/System.Threading.Tasks.Dataflow/ + +## Feedback & Contributing + + + +System.Threading.Channels is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/ConcurrencyLimiter.cs b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/ConcurrencyLimiter.cs index 6b5a4014990ef0..7131b4fe1d7999 100644 --- a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/ConcurrencyLimiter.cs +++ b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/ConcurrencyLimiter.cs @@ -156,8 +156,17 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C Debug.Assert(_queueCount >= 0); if (!oldestRequest.TrySetResult(FailedLease)) { - // Updating queue count is handled by the cancellation code - _queueCount += oldestRequest.Count; + if (!oldestRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + oldestRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += oldestRequest.Count; + } } else { @@ -277,10 +286,19 @@ private void Release(int releaseCount) // Check if request was canceled if (!nextPendingRequest.TrySetResult(lease)) { - // Queued item was canceled so add count back + // Queued item was canceled so add count back, permits weren't acquired _permitCount += nextPendingRequest.Count; - // Updating queue count is handled by the cancellation code - _queueCount += nextPendingRequest.Count; + if (!nextPendingRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + nextPendingRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += nextPendingRequest.Count; + } } else { @@ -399,6 +417,9 @@ private sealed class RequestRegistration : TaskCompletionSource private readonly CancellationToken _cancellationToken; private CancellationTokenRegistration _cancellationTokenRegistration; + // Update under the limiter lock and only if the queue count was updated by the calling code + public bool QueueCountModified { get; set; } + // this field is used only by the disposal mechanics and never shared between threads private RequestRegistration? _next; @@ -429,7 +450,14 @@ private static void Cancel(object? state) var limiter = (ConcurrencyLimiter)registration.Task.AsyncState!; lock (limiter.Lock) { - limiter._queueCount -= registration.Count; + // Queuing and replenishing code might modify the _queueCount, since there is no guarantee of when the cancellation + // code runs and we only want to update the _queueCount once, we set a bool (under a lock) so either method + // can update the count and not double count. + if (!registration.QueueCountModified) + { + limiter._queueCount -= registration.Count; + registration.QueueCountModified = true; + } } } } diff --git a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/FixedWindowRateLimiter.cs b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/FixedWindowRateLimiter.cs index d09c7973b18aa7..daaed9cf5ce422 100644 --- a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/FixedWindowRateLimiter.cs +++ b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/FixedWindowRateLimiter.cs @@ -173,7 +173,17 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C Debug.Assert(_queueCount >= 0); if (!oldestRequest.TrySetResult(FailedLease)) { - _queueCount += oldestRequest.Count; + if (!oldestRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + oldestRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += oldestRequest.Count; + } } else { @@ -330,10 +340,19 @@ private void ReplenishInternal(long nowTicks) if (!nextPendingRequest.TrySetResult(SuccessfulLease)) { - // Queued item was canceled so add count back + // Queued item was canceled so add count back, permits weren't acquired _permitCount += nextPendingRequest.Count; - // Updating queue count is handled by the cancellation code - _queueCount += nextPendingRequest.Count; + if (!nextPendingRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + nextPendingRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += nextPendingRequest.Count; + } } else { @@ -435,6 +454,9 @@ private sealed class RequestRegistration : TaskCompletionSource private readonly CancellationToken _cancellationToken; private CancellationTokenRegistration _cancellationTokenRegistration; + // Update under the limiter lock and only if the queue count was updated by the calling code + public bool QueueCountModified { get; set; } + // this field is used only by the disposal mechanics and never shared between threads private RequestRegistration? _next; @@ -465,7 +487,14 @@ private static void Cancel(object? state) var limiter = (FixedWindowRateLimiter)registration.Task.AsyncState!; lock (limiter.Lock) { - limiter._queueCount -= registration.Count; + // Queuing and replenishing code might modify the _queueCount, since there is no guarantee of when the cancellation + // code runs and we only want to update the _queueCount once, we set a bool (under a lock) so either method + // can update the count and not double count. + if (!registration.QueueCountModified) + { + limiter._queueCount -= registration.Count; + registration.QueueCountModified = true; + } } } } diff --git a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/SlidingWindowRateLimiter.cs b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/SlidingWindowRateLimiter.cs index a179720ede33fa..23dbf98e0fcdea 100644 --- a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/SlidingWindowRateLimiter.cs +++ b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/SlidingWindowRateLimiter.cs @@ -185,7 +185,17 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C Debug.Assert(_queueCount >= 0); if (!oldestRequest.TrySetResult(FailedLease)) { - _queueCount += oldestRequest.Count; + if (!oldestRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + oldestRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += oldestRequest.Count; + } } else { @@ -342,11 +352,20 @@ private void ReplenishInternal(long nowTicks) if (!nextPendingRequest.TrySetResult(SuccessfulLease)) { - // Queued item was canceled so add count back + // Queued item was canceled so add count back, permits weren't acquired _permitCount += nextPendingRequest.Count; _requestsPerSegment[_currentSegmentIndex] -= nextPendingRequest.Count; - // Updating queue count is handled by the cancellation code - _queueCount += nextPendingRequest.Count; + if (!nextPendingRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + nextPendingRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += nextPendingRequest.Count; + } } else { @@ -448,6 +467,9 @@ private sealed class RequestRegistration : TaskCompletionSource private readonly CancellationToken _cancellationToken; private CancellationTokenRegistration _cancellationTokenRegistration; + // Update under the limiter lock and only if the queue count was updated by the calling code + public bool QueueCountModified { get; set; } + // this field is used only by the disposal mechanics and never shared between threads private RequestRegistration? _next; @@ -478,7 +500,14 @@ private static void Cancel(object? state) var limiter = (SlidingWindowRateLimiter)registration.Task.AsyncState!; lock (limiter.Lock) { - limiter._queueCount -= registration.Count; + // Queuing and replenishing code might modify the _queueCount, since there is no guarantee of when the cancellation + // code runs and we only want to update the _queueCount once, we set a bool (under a lock) so either method + // can update the count and not double count. + if (!registration.QueueCountModified) + { + limiter._queueCount -= registration.Count; + registration.QueueCountModified = true; + } } } } diff --git a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/TokenBucketRateLimiter.cs b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/TokenBucketRateLimiter.cs index 5ad7859792ff7f..67a3a55a29ad03 100644 --- a/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/TokenBucketRateLimiter.cs +++ b/src/libraries/System.Threading.RateLimiting/src/System/Threading/RateLimiting/TokenBucketRateLimiter.cs @@ -178,8 +178,17 @@ protected override ValueTask AcquireAsyncCore(int tokenCount, Ca Debug.Assert(_queueCount >= 0); if (!oldestRequest.TrySetResult(FailedLease)) { - // Updating queue count is handled by the cancellation code - _queueCount += oldestRequest.Count; + if (!oldestRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + oldestRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += oldestRequest.Count; + } } else { @@ -345,10 +354,19 @@ private void ReplenishInternal(long nowTicks) if (!nextPendingRequest.TrySetResult(SuccessfulLease)) { - // Queued item was canceled so add count back + // Queued item was canceled so add count back, permits weren't acquired _tokenCount += nextPendingRequest.Count; - // Updating queue count is handled by the cancellation code - _queueCount += nextPendingRequest.Count; + if (!nextPendingRequest.QueueCountModified) + { + // We already updated the queue count, the Cancel code is about to run or running and waiting on our lock, + // tell Cancel not to do anything + nextPendingRequest.QueueCountModified = true; + } + else + { + // Updating queue count was handled by the cancellation code, don't double count + _queueCount += nextPendingRequest.Count; + } } else { @@ -450,6 +468,9 @@ private sealed class RequestRegistration : TaskCompletionSource private readonly CancellationToken _cancellationToken; private CancellationTokenRegistration _cancellationTokenRegistration; + // Update under the limiter lock and only if the queue count was updated by the calling code + public bool QueueCountModified { get; set; } + // this field is used only by the disposal mechanics and never shared between threads private RequestRegistration? _next; @@ -480,7 +501,14 @@ private static void Cancel(object? state) var limiter = (TokenBucketRateLimiter)registration.Task.AsyncState!; lock (limiter.Lock) { - limiter._queueCount -= registration.Count; + // Queuing and replenishing code might modify the _queueCount, since there is no guarantee of when the cancellation + // code runs and we only want to update the _queueCount once, we set a bool (under a lock) so either method + // can update the count and not double count. + if (!registration.QueueCountModified) + { + limiter._queueCount -= registration.Count; + registration.QueueCountModified = true; + } } } } diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs index b3c6d50cafe2e7..48d76dff065eaa 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs +++ b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs @@ -186,6 +186,7 @@ static bool CompareExchange(ref T location, T value, T comparand) => // If we're the last worker to complete, complete the operation. if (state.SignalWorkerCompletedIterating()) { + state.Dispose(); state.Complete(); } } @@ -745,7 +746,7 @@ public ValueTask DisposeAsync() /// Stores the state associated with an IAsyncEnumerable ForEachAsync operation, shared between all its workers. /// Specifies the type of data being enumerated. - private sealed class ForEachState : ForEachAsyncState + private sealed class ForEachState : ForEachAsyncState, IDisposable { public T NextAvailable; public readonly T ToExclusive; @@ -759,6 +760,8 @@ public ForEachState( NextAvailable = fromExclusive; ToExclusive = toExclusive; } + + public void Dispose() => _registration.Dispose(); } } } diff --git a/src/libraries/System.ValueTuple/tests/ValueTupleTests.cs b/src/libraries/System.ValueTuple/tests/ValueTupleTests.cs index 9507b31388b1f8..147b4ea9647cbf 100644 --- a/src/libraries/System.ValueTuple/tests/ValueTupleTests.cs +++ b/src/libraries/System.ValueTuple/tests/ValueTupleTests.cs @@ -799,6 +799,12 @@ public static void OneTuples() Assert.False(((IStructuralEquatable)sc).Equals(sc, DummyTestEqualityComparer.Instance)); + Assert.True(ValueTuple.Create(1).Equals(ValueTuple.Create(1))); + Assert.False(ValueTuple.Create(1).Equals(ValueTuple.Create(0))); + + Assert.True(((IStructuralEquatable)ValueTuple.Create(1)).Equals(ValueTuple.Create(1), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1)).Equals(ValueTuple.Create(0), TestEqualityComparer.Instance)); + Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1)", CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(1)).ToString()); var vtWithNull = new ValueTuple(null); @@ -831,6 +837,14 @@ public static void TwoTuples() Assert.False(((IStructuralEquatable)sc).Equals(sc, DummyTestEqualityComparer.Instance)); + Assert.True(ValueTuple.Create(1, 2).Equals(ValueTuple.Create(1, 2))); + Assert.False(ValueTuple.Create(1, 2).Equals(ValueTuple.Create(0, 2))); + Assert.False(ValueTuple.Create(1, 2).Equals(ValueTuple.Create(1, 0))); + + Assert.True(((IStructuralEquatable)ValueTuple.Create(1, 2)).Equals(ValueTuple.Create(1, 2), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2)).Equals(ValueTuple.Create(0, 2), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2)).Equals(ValueTuple.Create(1, 0), TestEqualityComparer.Instance)); + Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1, 2)", CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(1, 2)).ToString()); var vtWithNull = new ValueTuple(null, null); @@ -866,6 +880,16 @@ public static void ThreeTuples() Assert.False(((IStructuralEquatable)sc).Equals(sc, DummyTestEqualityComparer.Instance)); + Assert.True(ValueTuple.Create(1, 2, 3).Equals(ValueTuple.Create(1, 2, 3))); + Assert.False(ValueTuple.Create(1, 2, 3).Equals(ValueTuple.Create(0, 2, 3))); + Assert.False(ValueTuple.Create(1, 2, 3).Equals(ValueTuple.Create(1, 0, 3))); + Assert.False(ValueTuple.Create(1, 2, 3).Equals(ValueTuple.Create(1, 2, 0))); + + Assert.True(((IStructuralEquatable)ValueTuple.Create(1, 2, 3)).Equals(ValueTuple.Create(1, 2, 3), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3)).Equals(ValueTuple.Create(0, 2, 3), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3)).Equals(ValueTuple.Create(1, 0, 3), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3)).Equals(ValueTuple.Create(1, 2, 0), TestEqualityComparer.Instance)); + Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1, 2, 3)", CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(1, 2, 3)).ToString()); var vtWithNull = new ValueTuple(null, null, null); @@ -905,6 +929,18 @@ public static void FourTuples() Assert.False(((IStructuralEquatable)sc).Equals(sc, DummyTestEqualityComparer.Instance)); + Assert.True(ValueTuple.Create(1, 2, 3, 4).Equals(ValueTuple.Create(1, 2, 3, 4))); + Assert.False(ValueTuple.Create(1, 2, 3, 4).Equals(ValueTuple.Create(0, 2, 3, 4))); + Assert.False(ValueTuple.Create(1, 2, 3, 4).Equals(ValueTuple.Create(1, 0, 3, 4))); + Assert.False(ValueTuple.Create(1, 2, 3, 4).Equals(ValueTuple.Create(1, 2, 0, 4))); + Assert.False(ValueTuple.Create(1, 2, 3, 4).Equals(ValueTuple.Create(1, 2, 3, 0))); + + Assert.True(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4)).Equals(ValueTuple.Create(1, 2, 3, 4), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4)).Equals(ValueTuple.Create(0, 2, 3, 4), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4)).Equals(ValueTuple.Create(1, 0, 3, 4), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4)).Equals(ValueTuple.Create(1, 2, 0, 4), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4)).Equals(ValueTuple.Create(1, 2, 3, 0), TestEqualityComparer.Instance)); + Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4)", CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(1, 2, 3, 4)).ToString()); var vtWithNull = new ValueTuple(null, null, null, null); @@ -947,6 +983,20 @@ public static void FiveTuples() Assert.False(((IStructuralEquatable)sc).Equals(sc, DummyTestEqualityComparer.Instance)); + Assert.True(ValueTuple.Create(1, 2, 3, 4, 5).Equals(ValueTuple.Create(1, 2, 3, 4, 5))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5).Equals(ValueTuple.Create(0, 2, 3, 4, 5))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5).Equals(ValueTuple.Create(1, 0, 3, 4, 5))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5).Equals(ValueTuple.Create(1, 2, 0, 4, 5))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5).Equals(ValueTuple.Create(1, 2, 3, 0, 5))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5).Equals(ValueTuple.Create(1, 2, 3, 4, 0))); + + Assert.True(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5)).Equals(ValueTuple.Create(1, 2, 3, 4, 5), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5)).Equals(ValueTuple.Create(0, 2, 3, 4, 5), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5)).Equals(ValueTuple.Create(1, 0, 3, 4, 5), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5)).Equals(ValueTuple.Create(1, 2, 0, 4, 5), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5)).Equals(ValueTuple.Create(1, 2, 3, 0, 5), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5)).Equals(ValueTuple.Create(1, 2, 3, 4, 0), TestEqualityComparer.Instance)); + Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5)", CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(1, 2, 3, 4, 5)).ToString()); var vtWithNull = new ValueTuple(null, null, null, null, null); @@ -992,6 +1042,22 @@ public static void SixTuples() Assert.False(((IStructuralEquatable)sc).Equals(sc, DummyTestEqualityComparer.Instance)); + Assert.True(ValueTuple.Create(1, 2, 3, 4, 5, 6).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 6))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6).Equals(ValueTuple.Create(0, 2, 3, 4, 5, 6))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6).Equals(ValueTuple.Create(1, 0, 3, 4, 5, 6))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6).Equals(ValueTuple.Create(1, 2, 0, 4, 5, 6))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6).Equals(ValueTuple.Create(1, 2, 3, 0, 5, 6))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6).Equals(ValueTuple.Create(1, 2, 3, 4, 0, 6))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 0))); + + Assert.True(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6)).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 6), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6)).Equals(ValueTuple.Create(0, 2, 3, 4, 5, 6), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6)).Equals(ValueTuple.Create(1, 0, 3, 4, 5, 6), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6)).Equals(ValueTuple.Create(1, 2, 0, 4, 5, 6), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6)).Equals(ValueTuple.Create(1, 2, 3, 0, 5, 6), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6)).Equals(ValueTuple.Create(1, 2, 3, 4, 0, 6), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6)).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 0), TestEqualityComparer.Instance)); + Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6)", CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(1, 2, 3, 4, 5, 6)).ToString()); var vtWithNull = new ValueTuple(null, null, null, null, null, null); @@ -1040,6 +1106,24 @@ public static void SevenTuples() Assert.False(((IStructuralEquatable)sc).Equals(sc, DummyTestEqualityComparer.Instance)); + Assert.True(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(0, 2, 3, 4, 5, 6, 7))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(1, 0, 3, 4, 5, 6, 7))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(1, 2, 0, 4, 5, 6, 7))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(1, 2, 3, 0, 5, 6, 7))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(1, 2, 3, 4, 0, 6, 7))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 0, 7))); + Assert.False(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 6, 0))); + + Assert.True(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 6, 7), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(0, 2, 3, 4, 5, 6, 7), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(1, 0, 3, 4, 5, 6, 7), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(1, 2, 0, 4, 5, 6, 7), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(1, 2, 3, 0, 5, 6, 7), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(1, 2, 3, 4, 0, 6, 7), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 0, 7), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).Equals(ValueTuple.Create(1, 2, 3, 4, 5, 6, 0), TestEqualityComparer.Instance)); + Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7)", CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(1, 2, 3, 4, 5, 6, 7)).ToString()); var vtWithNull = new ValueTuple(null, null, null, null, null, null, null); @@ -1109,6 +1193,26 @@ public static void EightTuples() Assert.False(se.Equals(t, DummyTestEqualityComparer.Instance)); + Assert.True(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(0, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 0, 3, 4, 5, 6, 7, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 2, 0, 4, 5, 6, 7, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 2, 3, 0, 5, 6, 7, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 2, 3, 4, 0, 6, 7, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 2, 3, 4, 5, 0, 7, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 2, 3, 4, 5, 6, 0, ValueTuple.Create(8)))); + Assert.False(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)).Equals(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(0)))); + + Assert.True(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(0, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 0, 3, 4, 5, 6, 7, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 2, 0, 4, 5, 6, 7, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 2, 3, 0, 5, 6, 7, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 2, 3, 4, 0, 6, 7, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 2, 3, 4, 5, 0, 7, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 2, 3, 4, 5, 6, 0, ValueTuple.Create(8)), TestEqualityComparer.Instance)); + Assert.False(((IStructuralEquatable)CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(8))).Equals(CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create(0)), TestEqualityComparer.Instance)); + // Notice that 0-tuple prints as empty position Assert.Equal("(1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, )", CreateLong(1, 2, 3, 4, 5, 6, 7, CreateLong(1, 2, 3, 4, 5, 6, 7, ValueTuple.Create())).ToString()); diff --git a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml index f19da8b94090d0..12fce752d1b3d7 100644 --- a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml +++ b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml @@ -2773,10 +2773,22 @@ netstandard2.0/netstandard.dll net8.0/netstandard.dll + + CP0015 + T:System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute:[T:System.AttributeUsageAttribute] + netstandard2.0/netstandard.dll + net8.0/netstandard.dll + CP0015 P:System.Timers.Timer.Interval:[T:System.ComponentModel.DefaultValueAttribute] netstandard2.0/System.dll net8.0/System.dll - \ No newline at end of file + + CP0015 + T:System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute:[T:System.AttributeUsageAttribute] + netstandard2.0/System.dll + net8.0/System.dll + + diff --git a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml index de4ff50c7dc0ed..a6009206387156 100644 --- a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml +++ b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml @@ -853,6 +853,12 @@ netstandard2.1/netstandard.dll net8.0/netstandard.dll + + CP0015 + T:System.Runtime.CompilerServices.AsyncMethodBuilderAttribute:[T:System.AttributeUsageAttribute] + netstandard2.1/netstandard.dll + net8.0/netstandard.dll + CP0015 T:System.Runtime.InteropServices.ManagedToNativeComInteropStubAttribute:[T:System.AttributeUsageAttribute] @@ -871,4 +877,4 @@ netstandard2.1/netstandard.dll net8.0/netstandard.dll - \ No newline at end of file + diff --git a/src/libraries/shims/mscorlib/src/mscorlib.cs b/src/libraries/shims/mscorlib/src/mscorlib.cs index 857e2a83bfc48e..eb77227a1130a0 100644 --- a/src/libraries/shims/mscorlib/src/mscorlib.cs +++ b/src/libraries/shims/mscorlib/src/mscorlib.cs @@ -109,3 +109,9 @@ [assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Threading.Tasks.Sources.ManualResetValueTaskSourceCore<>))] [assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Threading.Tasks.Sources.ValueTaskSourceOnCompletedFlags))] [assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Threading.Tasks.Sources.ValueTaskSourceStatus))] +// These types are required for back-compatibility with .NET Framework and previous versions of .NETCoreApp. --> +[assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Reflection.Emit.PEFileKinds))] +[assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.InteropServices.AssemblyRegistrationFlags))] +[assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.InteropServices.ExporterEventKind))] +[assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.InteropServices.RegistrationClassContext))] +[assembly:System.Runtime.CompilerServices.TypeForwardedTo(typeof(System.Runtime.InteropServices.RegistrationConnectionType))] diff --git a/src/libraries/testPackages/testPackages.proj b/src/libraries/testPackages/testPackages.proj index 6480f5fc5928bf..9d60aeb8e6e64b 100644 --- a/src/libraries/testPackages/testPackages.proj +++ b/src/libraries/testPackages/testPackages.proj @@ -63,12 +63,24 @@ - + + + + + + + + + + + + + + diff --git a/src/libraries/tests.proj b/src/libraries/tests.proj index 7ef3d7b0235ebb..a550e9cf9cc827 100644 --- a/src/libraries/tests.proj +++ b/src/libraries/tests.proj @@ -266,6 +266,7 @@ + + diff --git a/src/mono/CMakeLists.txt b/src/mono/CMakeLists.txt index 66e92d80564363..05766210ea6be5 100644 --- a/src/mono/CMakeLists.txt +++ b/src/mono/CMakeLists.txt @@ -15,6 +15,9 @@ if (MSVC) if(EXISTS ${CLR_SOURCELINK_FILE_PATH}) add_link_options("/sourcelink:${CLR_SOURCELINK_FILE_PATH}") endif() + + # FIXME: Remove the line below when https://github.com/dotnet/runtime/issues/91249 is fixed. + add_compile_options($<$:/wd4244>) # conversion from 'type1' to 'type2', possible loss of data endif(MSVC) set(CROSS_ROOTFS $ENV{ROOTFS_DIR}) @@ -242,14 +245,6 @@ elseif(CLR_CMAKE_HOST_OS STREQUAL "emscripten") add_compile_options(-Wno-strict-prototypes) add_compile_options(-Wno-unused-but-set-variable) add_compile_options(-Wno-single-bit-bitfield-constant-conversion) - # Allow using WASM simd intrinsics in the interpreter - add_compile_options(-msimd128) - # Disable autovectorization (it is automatically turned on by msimd128) - add_compile_options(-disable-loop-vectorization) - add_compile_options(-disable-vectorization) - add_compile_options(-fno-vectorize) - add_compile_options(-fno-tree-vectorize) - add_compile_options(-fno-slp-vectorize) set(DISABLE_EXECUTABLES 1) # FIXME: Is there a cmake option for this ? set(DISABLE_SHARED_LIBS 1) diff --git a/src/mono/System.Private.CoreLib/src/System/GC.Mono.cs b/src/mono/System.Private.CoreLib/src/System/GC.Mono.cs index 9e87bfe1e167a9..485576645537f0 100644 --- a/src/mono/System.Private.CoreLib/src/System/GC.Mono.cs +++ b/src/mono/System.Private.CoreLib/src/System/GC.Mono.cs @@ -320,7 +320,6 @@ public static System.Collections.Generic.IReadOnlyDictionary Get return new System.Collections.Generic.Dictionary(); } - [System.Runtime.Versioning.RequiresPreviewFeaturesAttribute("RefreshMemoryLimit is in preview.")] public static void RefreshMemoryLimit() { throw new PlatformNotSupportedException(); diff --git a/src/mono/System.Private.CoreLib/src/System/RuntimeType.Mono.cs b/src/mono/System.Private.CoreLib/src/System/RuntimeType.Mono.cs index e7feb145b18846..5924d5c0640df7 100644 --- a/src/mono/System.Private.CoreLib/src/System/RuntimeType.Mono.cs +++ b/src/mono/System.Private.CoreLib/src/System/RuntimeType.Mono.cs @@ -2048,6 +2048,16 @@ public override bool ContainsGenericParameters if (HasElementType) return GetElementType().ContainsGenericParameters; + if (IsFunctionPointer) + { + if (GetFunctionPointerReturnType().ContainsGenericParameters) + return true; + + foreach (Type arg in GetFunctionPointerParameterTypes()) + if (arg.ContainsGenericParameters) + return true; + } + return false; } } diff --git a/src/mono/mono.proj b/src/mono/mono.proj index cf99f325ea5565..6f15a31d94dca4 100644 --- a/src/mono/mono.proj +++ b/src/mono/mono.proj @@ -254,9 +254,10 @@ - - <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' == 'wasm' and '$(MonoUseLLVMPackage)' == 'true'">$(BuildArchitecture) - <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' != 'wasm' and '$(MonoUseLLVMPackage)' == 'true'">$(TargetArchitecture) + + <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' == 'wasm'">$(BuildArchitecture) + <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' != 'wasm'">$(TargetArchitecture) + <_MonoLLVMHostArchitecture Condition="'$(AotHostArchitecture)' != ''">$(AotHostArchitecture) <_MonoCMakeArgs Condition="'$(_MonoUseNinja)' == 'true'" Include="-G Ninja"/> @@ -698,14 +699,14 @@ $(MonoCrossDir)/usr/lib/gcc/aarch64-linux-gnu/5 - - <_MonoLLVMTargetArchitecture Condition="'$(MonoUseLLVMPackage)' == 'true'">$(BuildArchitecture) - <_MonoLLVMTargetArchitecture Condition="'$(MonoUseLLVMPackage)' == 'true'">$(AotHostArchitecture) + + <_MonoLLVMTargetArchitecture>$(TargetArchitecture) + <_MonoLLVMHostArchitecture>$(AotHostArchitecture) - <_MonoAOTCXXFLAGS Include="-I$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\include\c++\v1" /> - <_MonoAOTCXXFLAGS Include="-L$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib" /> + <_MonoAOTCXXFLAGS Include="-I$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\include\c++\v1" /> + <_MonoAOTCXXFLAGS Include="-L$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\lib" /> <_MonoAOTCXXFLAGS Include="-stdlib=libc++" /> @@ -844,7 +845,7 @@ - + @@ -936,8 +937,7 @@ <_MonoAotCrossPdbFilePath>$(MonoObjCrossDir)out\bin\$(MonoAotCrossPdbFileName) - <_MonoLLVMTargetArchitecture>$(BuildArchitecture) - <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' != 'wasm'">$(AotHostArchitecture) + <_MonoLLVMHostArchitecture>$(AotHostArchitecture) @@ -977,25 +977,25 @@ <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ('$(MonoBundleLLVMOptimizer)' == 'true' or '$(MonoEnableLLVM)' == 'true') and '$(TargetArchitecture)' != 'wasm' and '$(MonoUseLibCxx)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib\libc++abi.so.1"> $(RuntimeBinDir)libc++abi.so.1 - <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib\libc++.so.1"> + <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\lib\libc++.so.1"> $(RuntimeBinDir)cross\$(OutputRID)\libc++.so.1 - <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib\libc++abi.so.1"> + <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\lib\libc++abi.so.1"> $(RuntimeBinDir)cross\$(OutputRID)\libc++abi.so.1 <_MonoRuntimeArtifacts Include="$(_MonoAotCrossPdbFilePath)" Condition="Exists('$(_MonoAotCrossPdbFilePath)')"> $(RuntimeBinDir)cross\$(OutputRID)\$(MonoAotCrossPdbFileName) - <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\llc$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\llc$(ExeSuffix)"> $(RuntimeBinDir)\llc$(ExeSuffix) - <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\opt$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\opt$(ExeSuffix)"> $(RuntimeBinDir)\opt$(ExeSuffix) - <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\llc$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\llc$(ExeSuffix)"> $(RuntimeBinDir)cross\$(OutputRID)\llc$(ExeSuffix) - <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\opt$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\opt$(ExeSuffix)"> $(RuntimeBinDir)cross\$(OutputRID)\opt$(ExeSuffix) <_MonoIncludeArtifacts Include="$(MonoObjDir)out\include\**" /> @@ -1061,6 +1061,12 @@ <_MonoRuntimeArtifacts Condition="'$(TargetsBrowser)' == 'true' and '$(BuildMonoAOTCrossCompilerOnly)' != 'true'" Include="$(MonoObjDir)out\lib\libmono-wasm-eh-wasm.a"> $(RuntimeBinDir)libmono-wasm-eh-wasm.a + <_MonoRuntimeArtifacts Condition="('$(TargetsBrowser)' == 'true' or '$(TargetsWasi)' == 'true') and '$(BuildMonoAOTCrossCompilerOnly)' != 'true'" Include="$(MonoObjDir)out\lib\libmono-wasm-simd.a"> + $(RuntimeBinDir)libmono-wasm-simd.a + + <_MonoRuntimeArtifacts Condition="('$(TargetsBrowser)' == 'true' or '$(TargetsWasi)' == 'true') and '$(BuildMonoAOTCrossCompilerOnly)' != 'true'" Include="$(MonoObjDir)out\lib\libmono-wasm-nosimd.a"> + $(RuntimeBinDir)libmono-wasm-nosimd.a + <_MonoICorDebugArtifacts Condition="'$(MonoMsCorDbi)' == 'true'" Include="$(MonoObjDir)out\lib\$(LibPrefix)mscordbi$(LibSuffix)"> $(RuntimeBinDir)$(LibPrefix)mscordbi$(LibSuffix) diff --git a/src/mono/mono/component/debugger-agent.c b/src/mono/mono/component/debugger-agent.c index a210242c1eb691..9a6906f6e272ef 100644 --- a/src/mono/mono/component/debugger-agent.c +++ b/src/mono/mono/component/debugger-agent.c @@ -5258,6 +5258,13 @@ buffer_add_value_full (Buffer *buf, MonoType *t, void *addr, MonoDomain *domain, nfields = 0; iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -5275,6 +5282,13 @@ buffer_add_value_full (Buffer *buf, MonoType *t, void *addr, MonoDomain *domain, iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -5375,6 +5389,13 @@ decode_vtype (MonoType *t, MonoDomain *domain, gpointer void_addr, gpointer void nfields = decode_int (buf, &buf, limit); while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -5476,6 +5497,13 @@ decode_vtype_compute_size (MonoType *t, MonoDomain *domain, gpointer void_buf, g nfields = decode_int (buf, &buf, limit); while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -8481,6 +8509,13 @@ type_commands_internal (int command, MonoClass *klass, MonoDomain *domain, guint buffer_add_int (buf, nfields); while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(field_error); + mono_field_resolve_type (f, field_error); + mono_error_cleanup (field_error); + if (!f->type) + continue; + } buffer_add_fieldid (buf, domain, f); buffer_add_string (buf, f->name); buffer_add_typeid (buf, domain, mono_class_from_mono_type_internal (f->type)); @@ -8861,6 +8896,13 @@ type_commands_internal (int command, MonoClass *klass, MonoDomain *domain, guint int nfields = 0; gpointer iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(field_error); + mono_field_resolve_type (f, field_error); + mono_error_cleanup (field_error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -8871,6 +8913,13 @@ type_commands_internal (int command, MonoClass *klass, MonoDomain *domain, guint iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(field_error); + mono_field_resolve_type (f, field_error); + mono_error_cleanup (field_error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) diff --git a/src/mono/mono/metadata/class.c b/src/mono/mono/metadata/class.c index 284e3153a1a253..c5fcd2a8d7a18d 100644 --- a/src/mono/mono/metadata/class.c +++ b/src/mono/mono/metadata/class.c @@ -4100,7 +4100,7 @@ mono_class_is_assignable_from_general (MonoClass *klass, MonoClass *oklass, gboo return; } - if (m_class_is_array_special_interface (klass) && m_class_get_rank (oklass) == 1) { + if (m_class_is_array_special_interface (klass) && m_class_get_rank (oklass) == 1 && m_class_get_byval_arg (oklass)->type == MONO_TYPE_SZARRAY) { if (mono_class_is_gtd (klass)) { /* klass is an array special gtd like * IList`1<>, and oklass is X[] for some X. diff --git a/src/mono/mono/metadata/metadata.c b/src/mono/mono/metadata/metadata.c index 7c7cd4433eecf6..f9abe652b019a3 100644 --- a/src/mono/mono/metadata/metadata.c +++ b/src/mono/mono/metadata/metadata.c @@ -3434,7 +3434,8 @@ mono_metadata_get_canonical_generic_inst (MonoGenericInst *candidate) MonoMemoryManager *mm = mono_mem_manager_get_generic (data.images, data.nimages); collect_data_free (&data); - mono_mem_manager_lock (mm); + // Hashtable key equal func can take loader lock + mono_loader_lock (); if (!mm->ginst_cache) mm->ginst_cache = g_hash_table_new_full (mono_metadata_generic_inst_hash, mono_metadata_generic_inst_equal, NULL, (GDestroyNotify)free_generic_inst); @@ -3456,7 +3457,7 @@ mono_metadata_get_canonical_generic_inst (MonoGenericInst *candidate) g_hash_table_insert (mm->ginst_cache, ginst, ginst); } - mono_mem_manager_unlock (mm); + mono_loader_unlock (); return ginst; } @@ -3467,7 +3468,8 @@ mono_metadata_get_canonical_aggregate_modifiers (MonoAggregateModContainer *cand g_assert (candidate->count > 0); MonoMemoryManager *mm = mono_metadata_get_mem_manager_for_aggregate_modifiers (candidate); - mono_mem_manager_lock (mm); + // Hashtable key equal func can take loader lock + mono_loader_lock (); if (!mm->aggregate_modifiers_cache) mm->aggregate_modifiers_cache = g_hash_table_new_full (aggregate_modifiers_hash, aggregate_modifiers_equal, NULL, (GDestroyNotify)free_aggregate_modifiers); @@ -3484,7 +3486,7 @@ mono_metadata_get_canonical_aggregate_modifiers (MonoAggregateModContainer *cand g_hash_table_insert (mm->aggregate_modifiers_cache, amods, amods); } - mono_mem_manager_unlock (mm); + mono_loader_unlock (); return amods; } @@ -3543,7 +3545,8 @@ mono_metadata_lookup_generic_class (MonoClass *container_class, MonoGenericInst if (gclass) return gclass; - mono_mem_manager_lock (mm); + // Hashtable key equal func can take loader lock + mono_loader_lock (); gclass = mono_mem_manager_alloc0 (mm, sizeof (MonoGenericClass)); if (is_dynamic) @@ -3563,7 +3566,7 @@ mono_metadata_lookup_generic_class (MonoClass *container_class, MonoGenericInst // g_hash_table_insert (set->gclass_cache, gclass, gclass); - mono_mem_manager_unlock (mm); + mono_loader_unlock (); return gclass2; } diff --git a/src/mono/mono/metadata/sre.c b/src/mono/mono/metadata/sre.c index de94e8c4bf9a3b..3f47f9e84d008e 100644 --- a/src/mono/mono/metadata/sre.c +++ b/src/mono/mono/metadata/sre.c @@ -1289,6 +1289,7 @@ image_module_basic_init (MonoReflectionModuleBuilderHandle moduleb, MonoError *e * determined at assembly save time. */ /*image = (MonoDynamicImage*)ab->dynamic_assembly->assembly.image; */ + MonoAssemblyLoadContext *alc = mono_alc_get_default (); MonoStringHandle abname = MONO_HANDLE_NEW_GET (MonoString, ab, name); char *name = mono_string_handle_to_utf8 (abname, error); return_val_if_nok (error, FALSE); @@ -1300,6 +1301,7 @@ image_module_basic_init (MonoReflectionModuleBuilderHandle moduleb, MonoError *e } MonoDynamicAssembly *dynamic_assembly = MONO_HANDLE_GETVAL (ab, dynamic_assembly); image = mono_dynamic_image_create (dynamic_assembly, name, fqname); + image->image.alc = alc; MONO_HANDLE_SETVAL (MONO_HANDLE_CAST (MonoReflectionModule, moduleb), image, MonoImage*, &image->image); MONO_HANDLE_SETVAL (moduleb, dynamic_image, MonoDynamicImage*, image); diff --git a/src/mono/mono/mini/CMakeLists.txt b/src/mono/mono/mini/CMakeLists.txt index 884b43c1b1eec4..5d6ef3dfa3c31e 100644 --- a/src/mono/mono/mini/CMakeLists.txt +++ b/src/mono/mono/mini/CMakeLists.txt @@ -288,7 +288,6 @@ set(interp_sources interp/interp.h interp/interp-internals.h interp/interp.c - interp/interp-simd.c interp/interp-intrins.h interp/interp-intrins.c interp/mintops.h @@ -297,11 +296,17 @@ set(interp_sources interp/tiering.h interp/tiering.c interp/jiterpreter.c) +set(interp_simd_sources + interp/interp-simd.c) set(interp_stub_sources interp-stubs.c) if(NOT DISABLE_INTERPRETER) -set(mini_interp_sources ${interp_sources}) + if(HOST_WASM) + set(mini_interp_sources ${interp_sources}) + else() + set(mini_interp_sources ${interp_sources} ${interp_simd_sources}) + endif() else() set(mini_interp_sources ${interp_stub_sources}) endif() @@ -504,6 +509,19 @@ if(HOST_BROWSER) install(TARGETS mono-wasm-eh-wasm LIBRARY) endif() +if(HOST_BROWSER OR HOST_WASI) + add_library(mono-wasm-simd STATIC interp/interp-simd.c) + target_link_libraries (mono-wasm-simd PRIVATE monoapi eglib_api) + set_target_properties(mono-wasm-simd PROPERTIES COMPILE_FLAGS "-msimd128") + install(TARGETS mono-wasm-simd LIBRARY) +endif() + +if(HOST_BROWSER OR HOST_WASI OR TARGET_WASM) + add_library(mono-wasm-nosimd STATIC interp/interp-nosimd.c) + target_link_libraries (mono-wasm-nosimd PRIVATE monoapi eglib_api) + install(TARGETS mono-wasm-nosimd LIBRARY) +endif() + find_package(Python3 COMPONENTS Interpreter) add_custom_command( @@ -576,6 +594,9 @@ if(NOT DISABLE_EXECUTABLES) endif() endif() target_link_libraries(mono-sgen PRIVATE monoapi eglib_api monosgen-static) + if (HOST_WASM) + target_link_libraries(mono-sgen PRIVATE mono-wasm-nosimd) + endif() if(HAVE_ICU_SHIM) target_link_libraries(mono-sgen PRIVATE icu_shim_objects) endif() diff --git a/src/mono/mono/mini/aot-compiler.c b/src/mono/mono/mini/aot-compiler.c index 1944a77d2c00b0..c342ef5d007571 100644 --- a/src/mono/mono/mini/aot-compiler.c +++ b/src/mono/mono/mini/aot-compiler.c @@ -3844,10 +3844,12 @@ encode_method_ref (MonoAotCompile *acfg, MonoMethod *method, guint8 *buf, guint8 else if (info->subtype == WRAPPER_SUBTYPE_UNSAFE_ACCESSOR) { encode_method_ref (acfg, info->d.unsafe_accessor.method, p, &p); encode_value (info->d.unsafe_accessor.kind, p, &p); - /* WISH: is there some kind of string heap token we could use here? */ - uint32_t len = (uint32_t) strlen (info->d.unsafe_accessor.member_name); - encode_value (len, p, &p); - encode_string (info->d.unsafe_accessor.member_name, p, &p); + if (info->d.unsafe_accessor.member_name) { + /* WISH: is there some kind of string heap token we could use here? */ + uint32_t len = (uint32_t) strlen (info->d.unsafe_accessor.member_name); + encode_value (len, p, &p); + encode_string (info->d.unsafe_accessor.member_name, p, &p); + } } else if (info->subtype == WRAPPER_SUBTYPE_INTERP_IN) encode_signature (acfg, info->d.interp_in.sig, p, &p); @@ -4324,6 +4326,7 @@ get_method_index (MonoAotCompile *acfg, MonoMethod *method) return index - 1; } +/* Return TRUE if the method can be skipped */ static gboolean collect_dedup_method (MonoAotCompile *acfg, MonoMethod *method) { @@ -4332,14 +4335,16 @@ collect_dedup_method (MonoAotCompile *acfg, MonoMethod *method) if (acfg->dedup_phase == DEDUP_SKIP) return TRUE; // Remember for later - if (acfg->dedup_phase == DEDUP_COLLECT && !g_hash_table_lookup (dedup_methods, method)) + g_assert (acfg->dedup_phase == DEDUP_COLLECT); + if (!g_hash_table_lookup (dedup_methods, method)) g_hash_table_insert (dedup_methods, method, method); + else + // Already processed when compiling another assembly + return TRUE; } return FALSE; } - - static int add_method_full (MonoAotCompile *acfg, MonoMethod *method, gboolean extra, int depth) { @@ -10703,6 +10708,18 @@ execute_system (const char * command) #ifdef ENABLE_LLVM +#ifdef HOST_WIN32 +#define OPT_NAME "opt.exe" +#else +#define OPT_NAME "opt" +#endif + +#ifdef HOST_WIN32 +#define LLC_NAME "llc.exe" +#else +#define LLC_NAME "llc" +#endif + /* * emit_llvm_file: * @@ -10771,11 +10788,11 @@ emit_llvm_file (MonoAotCompile *acfg) } else { #if LLVM_API_VERSION >= 1600 /* The safepoints pass requires new pass manager syntax*/ - opts = g_strdup ("-disable-tail-calls -passes='"); + opts = g_strdup ("-disable-tail-calls -passes=\""); if (!acfg->aot_opts.llvm_only) { opts = g_strdup_printf ("%sdefault,", opts); } - opts = g_strdup_printf ("%splace-safepoints' -spp-all-backedges", opts); + opts = g_strdup_printf ("%splace-safepoints\" -spp-all-backedges", opts); #elif LLVM_API_VERSION >= 1300 /* The safepoints pass requires the old pass manager */ opts = g_strdup ("-disable-tail-calls -place-safepoints -spp-all-backedges -enable-new-pm=0"); @@ -10805,7 +10822,7 @@ emit_llvm_file (MonoAotCompile *acfg) opts = g_strdup_printf ("%s -fp-contract=fast -enable-no-infs-fp-math -enable-no-nans-fp-math -enable-no-signed-zeros-fp-math -enable-no-trapping-fp-math -enable-unsafe-fp-math", opts); } - command = g_strdup_printf ("\"%sopt\" -f %s -o \"%s\" \"%s\"", acfg->aot_opts.llvm_path, opts, optbc, tempbc); + command = g_strdup_printf ("\"%s" OPT_NAME "\" -f %s -o \"%s\" \"%s\"", acfg->aot_opts.llvm_path, opts, optbc, tempbc); aot_printf (acfg, "Executing opt: %s\n", command); if (execute_system (command) != 0) return FALSE; @@ -10880,7 +10897,7 @@ emit_llvm_file (MonoAotCompile *acfg) g_string_append_printf (acfg->llc_args, " -mattr=%s", acfg->aot_opts.llvm_cpu_attr); } - command = g_strdup_printf ("\"%sllc\" %s -o \"%s\" \"%s.opt.bc\"", acfg->aot_opts.llvm_path, acfg->llc_args->str, output_fname, acfg->tmpbasename); + command = g_strdup_printf ("\"%s" LLC_NAME "\" %s -o \"%s\" \"%s.opt.bc\"", acfg->aot_opts.llvm_path, acfg->llc_args->str, output_fname, acfg->tmpbasename); g_free (output_fname); aot_printf (acfg, "Executing llc: %s\n", command); diff --git a/src/mono/mono/mini/decompose.c b/src/mono/mono/mini/decompose.c index b4570cc0c7429d..2be1ff52e416d0 100644 --- a/src/mono/mono/mini/decompose.c +++ b/src/mono/mono/mini/decompose.c @@ -1226,7 +1226,7 @@ mono_decompose_vtype_opts (MonoCompile *cfg) dest_var = get_vreg_to_inst (cfg, ins->dreg); if (!src_var) - src_var = mono_compile_create_var_for_vreg (cfg, m_class_get_byval_arg (ins->klass), OP_LOCAL, ins->dreg); + src_var = mono_compile_create_var_for_vreg (cfg, m_class_get_byval_arg (ins->klass), OP_LOCAL, ins->sreg1); if (!dest_var) dest_var = mono_compile_create_var_for_vreg (cfg, m_class_get_byval_arg (ins->klass), OP_LOCAL, ins->dreg); diff --git a/src/mono/mono/mini/driver.c b/src/mono/mono/mini/driver.c index c5f04a816a1fa8..a49820113d18ad 100644 --- a/src/mono/mono/mini/driver.c +++ b/src/mono/mono/mini/driver.c @@ -1617,6 +1617,7 @@ mini_usage (void) #endif " --handlers Install custom handlers, use --help-handlers for details.\n" " --aot-path=PATH List of additional directories to search for AOT images.\n" + " --path=DIR Add DIR to the list of directories to search for assemblies.\n" ); g_print ("\nOptions:\n"); @@ -2069,6 +2070,7 @@ mono_main (int argc, char* argv[]) char *aot_options = NULL; GPtrArray *agents = NULL; char *extra_bindings_config_file = NULL; + GList *paths = NULL; #ifdef MONO_JIT_INFO_TABLE_TEST int test_jit_info_table = FALSE; #endif @@ -2294,6 +2296,8 @@ mono_main (int argc, char* argv[]) g_free (tmp); split++; } + } else if (strncmp (argv [i], "--path=", 7) == 0) { + paths = g_list_append (paths, argv [i] + 7); } else if (strncmp (argv [i], "--compile-all=", 14) == 0) { action = DO_COMPILE; recompilation_times = atoi (argv [i] + 14); @@ -2503,6 +2507,16 @@ mono_main (int argc, char* argv[]) if (g_hasenv ("MONO_XDEBUG")) enable_debugging = TRUE; + if (paths) { + char **p = g_new0 (char *, g_list_length (paths) + 1); + int pindex = 0; + for (GList *l = paths; l; l = l->next) + p [pindex ++] = (char*)l->data; + g_list_free (paths); + + mono_set_assemblies_path_direct (p); + } + #ifdef MONO_CROSS_COMPILE if (!mono_compile_aot) { fprintf (stderr, "This mono runtime is compiled for cross-compiling. Only the --aot option is supported.\n"); diff --git a/src/mono/mono/mini/interp/interp-nosimd.c b/src/mono/mono/mini/interp/interp-nosimd.c new file mode 100644 index 00000000000000..63bcf2783ec087 --- /dev/null +++ b/src/mono/mono/mini/interp/interp-nosimd.c @@ -0,0 +1,31 @@ + +#include "interp-internals.h" +#include "interp-simd.h" + +#ifdef INTERP_ENABLE_SIMD + +gboolean interp_simd_enabled = FALSE; + +#ifdef HOST_BROWSER + +int interp_simd_p_p_wasm_opcode_table [] = { +}; + +int interp_simd_p_pp_wasm_opcode_table [] = { +}; + +int interp_simd_p_ppp_wasm_opcode_table [] = { +}; + +#endif // HOST_BROWSER + +PP_SIMD_Method interp_simd_p_p_table [] = { +}; + +PPP_SIMD_Method interp_simd_p_pp_table [] = { +}; + +PPPP_SIMD_Method interp_simd_p_ppp_table [] = { +}; + +#endif // INTERP_ENABLE_SIMD diff --git a/src/mono/mono/mini/interp/interp-simd.c b/src/mono/mono/mini/interp/interp-simd.c index 65e60b4c6e7017..f21fdec5aefd3b 100644 --- a/src/mono/mono/mini/interp/interp-simd.c +++ b/src/mono/mono/mini/interp/interp-simd.c @@ -8,6 +8,8 @@ #ifdef INTERP_ENABLE_SIMD +gboolean interp_simd_enabled = TRUE; + typedef gint64 v128_i8 __attribute__ ((vector_size (SIZEOF_V128))); typedef guint64 v128_u8 __attribute__ ((vector_size (SIZEOF_V128))); typedef gint32 v128_i4 __attribute__ ((vector_size (SIZEOF_V128))); @@ -213,57 +215,57 @@ interp_v128_i2_op_left_shift (gpointer res, gpointer v1, gpointer s1) static void interp_v128_i4_op_left_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_i4*)res = *(v128_i4*)v1 << *(gint32*)s1; + *(v128_i4*)res = *(v128_i4*)v1 << (*(gint32*)s1 & 31); } static void interp_v128_i8_op_left_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_i8*)res = *(v128_i8*)v1 << *(gint32*)s1; + *(v128_i8*)res = *(v128_i8*)v1 << (*(gint32*)s1 & 63); } // op_RightShift static void interp_v128_i1_op_right_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_i1*)res = *(v128_i1*)v1 >> *(gint32*)s1; + *(v128_i1*)res = *(v128_i1*)v1 >> (*(gint32*)s1 & 7); } static void interp_v128_i2_op_right_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_i2*)res = *(v128_i2*)v1 >> *(gint32*)s1; + *(v128_i2*)res = *(v128_i2*)v1 >> (*(gint32*)s1 & 15); } static void interp_v128_i4_op_right_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_i4*)res = *(v128_i4*)v1 >> *(gint32*)s1; + *(v128_i4*)res = *(v128_i4*)v1 >> (*(gint32*)s1 & 31); } // op_UnsignedRightShift static void interp_v128_i1_op_uright_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_u1*)res = *(v128_u1*)v1 >> *(gint32*)s1; + *(v128_u1*)res = *(v128_u1*)v1 >> (*(gint32*)s1 & 7); } static void interp_v128_i2_op_uright_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_u2*)res = *(v128_u2*)v1 >> *(gint32*)s1; + *(v128_u2*)res = *(v128_u2*)v1 >> (*(gint32*)s1 & 15); } static void interp_v128_i4_op_uright_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_u4*)res = *(v128_u4*)v1 >> *(gint32*)s1; + *(v128_u4*)res = *(v128_u4*)v1 >> (*(gint32*)s1 & 31); } static void interp_v128_i8_op_uright_shift (gpointer res, gpointer v1, gpointer s1) { - *(v128_u8*)res = *(v128_u8*)v1 >> *(gint32*)s1; + *(v128_u8*)res = *(v128_u8*)v1 >> (*(gint32*)s1 & 63); } // op_OnesComplement diff --git a/src/mono/mono/mini/interp/interp-simd.h b/src/mono/mono/mini/interp/interp-simd.h index e3306a251fc9f6..8e0222613e44a2 100644 --- a/src/mono/mono/mini/interp/interp-simd.h +++ b/src/mono/mono/mini/interp/interp-simd.h @@ -3,6 +3,8 @@ #include +extern gboolean interp_simd_enabled; + typedef void (*PP_SIMD_Method) (gpointer, gpointer); typedef void (*PPP_SIMD_Method) (gpointer, gpointer, gpointer); typedef void (*PPPP_SIMD_Method) (gpointer, gpointer, gpointer, gpointer); diff --git a/src/mono/mono/mini/interp/transform-simd.c b/src/mono/mono/mini/interp/transform-simd.c index 255a2aba595634..7df7f92ab6d7c0 100644 --- a/src/mono/mono/mini/interp/transform-simd.c +++ b/src/mono/mono/mini/interp/transform-simd.c @@ -3,6 +3,7 @@ */ #include "config.h" +#include "interp-simd.h" #include #include #include @@ -900,6 +901,9 @@ interp_emit_simd_intrinsics (TransformData *td, MonoMethod *cmethod, MonoMethodS if (image != mono_get_corlib ()) return FALSE; + if (!interp_simd_enabled) + return FALSE; + class_ns = m_class_get_name_space (cmethod->klass); class_name = m_class_get_name (cmethod->klass); diff --git a/src/mono/mono/mini/intrinsics.c b/src/mono/mono/mini/intrinsics.c index b1e5e767231470..ef77b7dc89f2ee 100644 --- a/src/mono/mono/mini/intrinsics.c +++ b/src/mono/mono/mini/intrinsics.c @@ -2079,7 +2079,9 @@ mini_emit_inst_for_method (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign MonoType *t = method_context->method_inst->type_argv [0]; MonoClass *arg0 = mono_class_from_mono_type_internal (t); if (m_class_is_valuetype (arg0) && !mono_class_has_default_constructor (arg0, FALSE)) { - if (m_class_is_primitive (arg0)) { + if (m_class_is_primitive (arg0) || m_class_is_enumtype (arg0)) { + if (m_class_is_enumtype (arg0)) + t = mono_class_enum_basetype_internal (arg0); int dreg = alloc_dreg (cfg, mini_type_to_stack_type (cfg, t)); mini_emit_init_rvar (cfg, dreg, t); ins = cfg->cbb->last_ins; diff --git a/src/mono/mono/mini/jit-icalls.c b/src/mono/mono/mini/jit-icalls.c index 04a15e94aea165..628d18abe35cfd 100644 --- a/src/mono/mono/mini/jit-icalls.c +++ b/src/mono/mono/mini/jit-icalls.c @@ -1358,7 +1358,7 @@ constrained_gsharedvt_call_setup (gpointer mp, MonoMethod *cmethod, MonoClass *k error_init (error); - if (mono_class_is_interface (klass) || !m_class_is_valuetype (klass)) { + if ((mono_class_is_interface (klass) || !m_class_is_valuetype (klass)) && !m_method_is_static (cmethod)) { MonoObject *this_obj; is_iface = mono_class_is_interface (klass); @@ -1390,7 +1390,12 @@ constrained_gsharedvt_call_setup (gpointer mp, MonoMethod *cmethod, MonoClass *k } } - if (m_class_is_valuetype (klass) && (m->klass == mono_defaults.object_class || m->klass == m_class_get_parent (mono_defaults.enum_class) || m->klass == mono_defaults.enum_class)) { + if (m_method_is_static (cmethod)) { + /* + * Static calls don't have this arg + */ + *this_arg = NULL; + } else if (m_class_is_valuetype (klass) && (m->klass == mono_defaults.object_class || m->klass == m_class_get_parent (mono_defaults.enum_class) || m->klass == mono_defaults.enum_class)) { /* * Calling a non-vtype method with a vtype receiver, has to box. */ diff --git a/src/mono/mono/mini/llvm-intrinsics.h b/src/mono/mono/mini/llvm-intrinsics.h index 1bb09bf0388ae7..be73ec309dfa85 100644 --- a/src/mono/mono/mini/llvm-intrinsics.h +++ b/src/mono/mono/mini/llvm-intrinsics.h @@ -291,6 +291,8 @@ INTRINS_OVR_2_ARG(WASM_NARROW_UNSIGNED_V16, wasm_narrow_unsigned, Wasm, sse_i1_t INTRINS_OVR_2_ARG(WASM_NARROW_UNSIGNED_V8, wasm_narrow_unsigned, Wasm, sse_i2_t, sse_i4_t) INTRINS_OVR_2_ARG(WASM_CONV_R8_TO_I4, fptosi_sat, Generic, v64_i4_t, v128_r8_t) INTRINS_OVR_2_ARG(WASM_CONV_R8_TO_U4, fptoui_sat, Generic, v64_i4_t, v128_r8_t) +INTRINS_OVR_TAG(WASM_FMAX, maximum, Generic, V128 | R4 | R8) +INTRINS_OVR_TAG(WASM_FMIN, minimum, Generic, V128 | R4 | R8) INTRINS_OVR_TAG(WASM_PMAX, wasm_pmax, Wasm, V128 | R4 | R8) INTRINS_OVR_TAG(WASM_PMIN, wasm_pmin, Wasm, V128 | R4 | R8) INTRINS_OVR(WASM_PMAX_V4, fabs, Generic, sse_r4_t) diff --git a/src/mono/mono/mini/method-to-ir.c b/src/mono/mono/mini/method-to-ir.c index b5f75ea1c3133d..dedc6804c3cbd8 100644 --- a/src/mono/mono/mini/method-to-ir.c +++ b/src/mono/mono/mini/method-to-ir.c @@ -6452,8 +6452,18 @@ mono_method_to_ir (MonoCompile *cfg, MonoMethod *method, MonoBasicBlock *start_b generic_context = &generic_container->context; cfg->generic_context = generic_context; - if (!cfg->gshared) - g_assert (!sig->has_type_parameters); + if (!cfg->gshared) { + gboolean check_type_parameter = TRUE; + if (method->wrapper_type == MONO_WRAPPER_OTHER) { + WrapperInfo *info = mono_marshal_get_wrapper_info (method); + g_assert (info); + if (info->subtype == WRAPPER_SUBTYPE_UNSAFE_ACCESSOR) + check_type_parameter = FALSE; + } + + if (check_type_parameter) + g_assert (!sig->has_type_parameters); + } if (sig->generic_param_count && method->wrapper_type == MONO_WRAPPER_NONE) { g_assert (method->is_inflated); @@ -10840,12 +10850,15 @@ mono_method_to_ir (MonoCompile *cfg, MonoMethod *method, MonoBasicBlock *start_b EMIT_NEW_TEMPLOADA (cfg, addr, vtvar->inst_c0); MONO_EMIT_NEW_STORE_MEMBASE (cfg, OP_STORE_MEMBASE_REG, addr->dreg, 0, ins->dreg); EMIT_NEW_TEMPLOAD (cfg, ins, vtvar->inst_c0); - ins->opcode = OP_LDTOKEN_FIELD; - ins->inst_c0 = n; - ins->inst_p1 = handle; + if (handle_class == mono_defaults.fieldhandle_class) { + ins->opcode = OP_LDTOKEN_FIELD; + ins->inst_c0 = n; + ins->inst_p1 = handle; + + cfg->flags |= MONO_CFG_NEEDS_DECOMPOSE; + cfg->cbb->needs_decompose = TRUE; + } - cfg->flags |= MONO_CFG_NEEDS_DECOMPOSE; - cfg->cbb->needs_decompose = TRUE; } } diff --git a/src/mono/mono/mini/mini-amd64.c b/src/mono/mono/mini/mini-amd64.c index 1a2f9fff59d34e..e03b142feef1f3 100644 --- a/src/mono/mono/mini/mini-amd64.c +++ b/src/mono/mono/mini/mini-amd64.c @@ -357,7 +357,7 @@ collect_field_info_nested (MonoClass *klass, GArray *fields_array, int offset, g g_assert(info); for (guint32 i = 0; i < info->num_fields; ++i) { if (MONO_TYPE_ISSTRUCT (info->fields [i].field->type)) { - collect_field_info_nested (mono_class_from_mono_type_internal (info->fields [i].field->type), fields_array, info->fields [i].offset, pinvoke, unicode); + collect_field_info_nested (mono_class_from_mono_type_internal (info->fields [i].field->type), fields_array, (offset + info->fields [i].offset), pinvoke, unicode); } else { guint32 align; StructFieldInfo f; @@ -367,7 +367,7 @@ collect_field_info_nested (MonoClass *klass, GArray *fields_array, int offset, g info->fields [i].mspec, &align, TRUE, unicode); f.offset = offset + info->fields [i].offset; - if (i == info->num_fields - 1 && f.size + f.offset < info->native_size) { + if ((i == info->num_fields - 1) && ((f.size + f.offset) < info->native_size)) { /* This can happen with .pack directives eg. 'fixed' arrays */ if (MONO_TYPE_IS_PRIMITIVE (f.type)) { /* Replicate the last field to fill out the remaining place, since the code in add_valuetype () needs type information */ diff --git a/src/mono/mono/mini/mini-generic-sharing.c b/src/mono/mono/mini/mini-generic-sharing.c index 6ad8dcb0075cfc..c131d51a6bd070 100644 --- a/src/mono/mono/mini/mini-generic-sharing.c +++ b/src/mono/mono/mini/mini-generic-sharing.c @@ -2886,7 +2886,8 @@ info_equal (gpointer data1, gpointer data2, MonoRgctxInfoType info_type) return data1 == data2; case MONO_RGCTX_INFO_VIRT_METHOD: case MONO_RGCTX_INFO_VIRT_METHOD_CODE: - case MONO_RGCTX_INFO_VIRT_METHOD_BOX_TYPE: { + case MONO_RGCTX_INFO_VIRT_METHOD_BOX_TYPE: + case MONO_RGCTX_INFO_GSHAREDVT_CONSTRAINED_CALL_INFO: { MonoJumpInfoVirtMethod *info1 = (MonoJumpInfoVirtMethod *)data1; MonoJumpInfoVirtMethod *info2 = (MonoJumpInfoVirtMethod *)data2; diff --git a/src/mono/mono/mini/mini-llvm.c b/src/mono/mono/mini/mini-llvm.c index a6c23aa59b67a2..657bb23d5cf63d 100644 --- a/src/mono/mono/mini/mini-llvm.c +++ b/src/mono/mono/mini/mini-llvm.c @@ -8183,9 +8183,13 @@ MONO_RESTORE_WARNING result = fcmp_and_select (builder, ins, l, r); } -#elif defined(TARGET_ARM64) +#elif defined(TARGET_ARM64) || defined(TARGET_WASM) LLVMValueRef min_max_args [] = { l, r }; +#ifdef TARGET_WASM + IntrinsicId iid = ins->inst_c0 == OP_FMAX ? INTRINS_WASM_FMAX : INTRINS_WASM_FMIN; +#else IntrinsicId iid = ins->inst_c0 == OP_FMAX ? INTRINS_AARCH64_ADV_SIMD_FMAX : INTRINS_AARCH64_ADV_SIMD_FMIN; +#endif llvm_ovr_tag_t ovr_tag = ovr_tag_from_mono_vector_class (ins->klass); result = call_overloaded_intrins (ctx, iid, ovr_tag, min_max_args, ""); #else diff --git a/src/mono/mono/mini/mini-ppc.c b/src/mono/mono/mini/mini-ppc.c index 3b9bdf2bcb985c..b3514123b7e941 100644 --- a/src/mono/mono/mini/mini-ppc.c +++ b/src/mono/mono/mini/mini-ppc.c @@ -297,7 +297,7 @@ gboolean mono_ppc_is_direct_call_sequence (guint32 *code) { #ifdef TARGET_POWERPC64 - g_assert(*code == 0x4e800021 || *code == 0x4e800020 || *code == 0x4e800420); + g_assert(*code == 0x4e800021 || *code == 0x4e800020 || *code == 0x4e800421 || *code == 0x4e800420); /* the thunk-less direct call sequence: lis/ori/sldi/oris/ori/mtlr/blrl */ if (ppc_opcode (code [-1]) == 31) { /* mtlr */ @@ -2939,7 +2939,7 @@ ppc_patch_full (MonoCompile *cfg, guchar *code, const guchar *target, gboolean i return; } - if (prim == 15 || ins == 0x4e800021 || ins == 0x4e800020 || ins == 0x4e800420) { + if (prim == 15 || ins == 0x4e800021 || ins == 0x4e800020 || ins == 0x4e800421 || ins == 0x4e800420) { #ifdef TARGET_POWERPC64 #if !defined(PPC_USES_FUNCTION_DESCRIPTOR) handle_thunk (cfg, code, target); @@ -2948,7 +2948,7 @@ ppc_patch_full (MonoCompile *cfg, guchar *code, const guchar *target, gboolean i guint32 *branch_ins; /* the trampoline code will try to patch the blrl, blr, bcctr */ - if (ins == 0x4e800021 || ins == 0x4e800020 || ins == 0x4e800420) { + if (ins == 0x4e800021 || ins == 0x4e800020 || ins == 0x4e800421 || ins == 0x4e800420) { branch_ins = seq; if (ppc_is_load_op (seq [-3]) || ppc_opcode (seq [-3]) == 31) /* ld || lwz || mr */ code -= 32; @@ -2996,7 +2996,7 @@ ppc_patch_full (MonoCompile *cfg, guchar *code, const guchar *target, gboolean i #else guint32 *seq; /* the trampoline code will try to patch the blrl, blr, bcctr */ - if (ins == 0x4e800021 || ins == 0x4e800020 || ins == 0x4e800420) { + if (ins == 0x4e800021 || ins == 0x4e800020 || ins == 0x4e800421 || ins == 0x4e800420) { code -= 12; } /* this is the lis/ori/mtlr/blrl sequence */ @@ -3004,7 +3004,7 @@ ppc_patch_full (MonoCompile *cfg, guchar *code, const guchar *target, gboolean i g_assert ((seq [0] >> 26) == 15); g_assert ((seq [1] >> 26) == 24); g_assert ((seq [2] >> 26) == 31); - g_assert (seq [3] == 0x4e800021 || seq [3] == 0x4e800020 || seq [3] == 0x4e800420); + g_assert (seq [3] == 0x4e800021 || seq [3] == 0x4e800020 || seq [3] == 0x4e800421 || seq [3] == 0x4e800420); /* FIXME: make this thread safe */ ppc_lis (code, PPC_CALL_REG, (guint32)(target) >> 16); ppc_ori (code, PPC_CALL_REG, PPC_CALL_REG, (guint32)(target) & 0xffff); @@ -3426,8 +3426,8 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb) ppc_ldr (code, PPC_CALL_REG, 0, PPC_CALL_REG); cfg->thunk_area += THUNK_SIZE; #endif - ppc_mtlr (code, PPC_CALL_REG); - ppc_blrl (code); + ppc_mtctr (code, PPC_CALL_REG); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); } else { ppc_bl (code, 0); } @@ -3913,8 +3913,8 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb) ppc_ldr (code, PPC_CALL_REG, 0, PPC_CALL_REG); cfg->thunk_area += THUNK_SIZE; #endif - ppc_mtlr (code, PPC_CALL_REG); - ppc_blrl (code); + ppc_mtctr (code, PPC_CALL_REG); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); } else { ppc_bl (code, 0); } @@ -3945,9 +3945,9 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb) } } #endif - ppc_mtlr (code, ins->sreg1); + ppc_mtctr (code, ins->sreg1); #endif - ppc_blrl (code); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); /* FIXME: this should be handled somewhere else in the new jit */ code = emit_move_return_value (cfg, ins, code); break; @@ -3965,8 +3965,8 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb) } else { ppc_ldptr (code, ppc_r12, ins->inst_offset, ins->sreg1); } - ppc_mtlr (code, ppc_r12); - ppc_blrl (code); + ppc_mtctr (code, ppc_r12); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); /* FIXME: this should be handled somewhere else in the new jit */ code = emit_move_return_value (cfg, ins, code); break; @@ -4022,8 +4022,8 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb) ppc_ldr (code, PPC_CALL_REG, 0, PPC_CALL_REG); cfg->thunk_area += THUNK_SIZE; #endif - ppc_mtlr (code, PPC_CALL_REG); - ppc_blrl (code); + ppc_mtctr (code, PPC_CALL_REG); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); } else { ppc_bl (code, 0); } @@ -4040,8 +4040,8 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb) ppc_ldr (code, PPC_CALL_REG, 0, PPC_CALL_REG); cfg->thunk_area += THUNK_SIZE; #endif - ppc_mtlr (code, PPC_CALL_REG); - ppc_blrl (code); + ppc_mtctr (code, PPC_CALL_REG); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); } else { ppc_bl (code, 0); } @@ -4725,8 +4725,8 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb) ppc_ldr (code, PPC_CALL_REG, 0, PPC_CALL_REG); cfg->thunk_area += THUNK_SIZE; #endif - ppc_mtlr (code, PPC_CALL_REG); - ppc_blrl (code); + ppc_mtctr (code, PPC_CALL_REG); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); } else { ppc_bl (code, 0); } @@ -5348,8 +5348,8 @@ mono_arch_emit_prolog (MonoCompile *cfg) ppc_ldr (code, PPC_CALL_REG, 0, PPC_CALL_REG); cfg->thunk_area += THUNK_SIZE; #endif - ppc_mtlr (code, PPC_CALL_REG); - ppc_blrl (code); + ppc_mtctr (code, PPC_CALL_REG); + ppc_bcctrl (code, PPC_BR_ALWAYS, 0); } else { ppc_bl (code, 0); } diff --git a/src/mono/mono/utils/mono-mmap-wasm.c b/src/mono/mono/utils/mono-mmap-wasm.c index 5c38aac9f36641..b2f417b0860385 100644 --- a/src/mono/mono/utils/mono-mmap-wasm.c +++ b/src/mono/mono/utils/mono-mmap-wasm.c @@ -90,8 +90,8 @@ mono_setmmapjit (int flag) /* Ignored on HOST_WASM */ } -void* -mono_valloc (void *addr, size_t size, int flags, MonoMemAccountType type) +static void* +valloc_impl (void *addr, size_t size, int flags, MonoMemAccountType type) { void *ptr; int mflags = 0; @@ -119,6 +119,19 @@ mono_valloc (void *addr, size_t size, int flags, MonoMemAccountType type) return ptr; } +void* +mono_valloc (void *addr, size_t size, int flags, MonoMemAccountType type) +{ +#if HOST_WASI + // WASI implements mmap using malloc, so the returned address is not page aligned + // and our code depends on it + g_assert (!addr); + return mono_valloc_aligned (size, mono_pagesize (), flags, type); +#else + return valloc_impl (addr, size, flags, type); +#endif +} + static GHashTable *valloc_hash; typedef struct { @@ -130,7 +143,7 @@ void* mono_valloc_aligned (size_t size, size_t alignment, int flags, MonoMemAccountType type) { /* Allocate twice the memory to be able to put the block on an aligned address */ - char *mem = (char *) mono_valloc (NULL, size + alignment, flags, type); + char *mem = (char *) valloc_impl (NULL, size + alignment, flags, type); char *aligned; if (!mem) diff --git a/src/mono/nuget/Microsoft.NET.Sdk.WebAssembly.Pack/build/Microsoft.NET.Sdk.WebAssembly.Browser.targets b/src/mono/nuget/Microsoft.NET.Sdk.WebAssembly.Pack/build/Microsoft.NET.Sdk.WebAssembly.Browser.targets index 64539a49fdaa8a..324f36cad7957f 100644 --- a/src/mono/nuget/Microsoft.NET.Sdk.WebAssembly.Pack/build/Microsoft.NET.Sdk.WebAssembly.Browser.targets +++ b/src/mono/nuget/Microsoft.NET.Sdk.WebAssembly.Pack/build/Microsoft.NET.Sdk.WebAssembly.Browser.targets @@ -19,9 +19,13 @@ Copyright (c) .NET Foundation. All rights reserved. dotnet $([MSBuild]::NormalizeDirectory($(MSBuildThisFileDirectory), '..', 'WasmAppHost')) - <_RuntimeConfigJsonPath>$([MSBuild]::NormalizePath($(OutputPath), '$(AssemblyName).runtimeconfig.json')) + + <_RunWorkingDirectory>$(OutputPath) + <_RunWorkingDirectory Condition="'$(_RunWorkingDirectory)' != '' and !$([System.IO.Path]::IsPathRooted($(_RunWorkingDirectory)))">$([System.IO.Path]::Combine($(MSBuildProjectDirectory), $(_RunWorkingDirectory))) + <_RuntimeConfigJsonPath>$([MSBuild]::NormalizePath($(_RunWorkingDirectory), '$(AssemblyName).runtimeconfig.json')) + exec "$([MSBuild]::NormalizePath($(WasmAppHostDir), 'WasmAppHost.dll'))" --use-staticwebassets --runtime-config "$(_RuntimeConfigJsonPath)" $(WasmHostArguments) - $(OutputPath) + $(_RunWorkingDirectory) @@ -224,7 +228,7 @@ Copyright (c) .NET Foundation. All rights reserved. <_WasmRuntimePackVersion>%(ResolvedRuntimePack.NuGetPackageVersion) - + diff --git a/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in b/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in index 5af8dcbd94e17e..076e642d2b6209 100644 --- a/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in +++ b/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in @@ -5,7 +5,7 @@ }, "workloads": { "wasm-tools": { - "description": ".NET WebAssembly build tools", + "description": ".NET WebAssembly build tools for net8.0", "packs": [ "Microsoft.NET.Runtime.WebAssembly.Sdk", "Microsoft.NETCore.App.Runtime.Mono.browser-wasm", @@ -15,7 +15,7 @@ "platforms": [ "win-x64", "win-arm64", "linux-x64", "linux-arm64", "osx-x64", "osx-arm64"] }, "wasm-experimental": { - "description": ".NET WebAssembly experimental tooling", + "description": ".NET WebAssembly experimental tooling for net8.0", "packs": [ "Microsoft.NET.Runtime.WebAssembly.Templates", "Microsoft.NETCore.App.Runtime.Mono.multithread.browser-wasm", @@ -24,7 +24,7 @@ "platforms": [ "win-x64", "win-arm64", "linux-x64", "linux-arm64", "osx-x64", "osx-arm64" ] }, "wasi-experimental": { - "description": ".NET WASI experimental", + "description": ".NET WASI experimental for net8.0", "packs": [ "Microsoft.NET.Runtime.WebAssembly.Wasi.Sdk", "Microsoft.NETCore.App.Runtime.Mono.wasi-wasm", diff --git a/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadTelemetry.targets b/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadTelemetry.targets index f9cd6efe170a2c..d94c7de33c3f60 100644 --- a/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadTelemetry.targets +++ b/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadTelemetry.targets @@ -19,6 +19,7 @@ + @@ -29,17 +30,18 @@ <_WorkloadUsesBlazorWasm>$(UsingMicrosoftNETSdkBlazorWebAssembly) <_WorkloadUsesWasmSDK>$(UsingMicrosoftNETSdkWebAssembly) <_WorkloadUsesMonoAOT>$(RunAOTCompilation) - <_WorkloadUsesMonoAOT Condition="'$(RunAOTCompilation)' == '' and '$(PublishAot)' != 'true' and ('$(TargetPlatformIdentifier)' == 'maccatalyst' or '$(TargetPlatformIdentifier)' == 'ios' or '$(TargetPlatformIdentifier)' == 'tvos')">$(_RunAotCompiler) + <_WorkloadUsesMonoAOT Condition="'$(_WorkloadUsesMonoAOT)' == '' and '$(PublishAot)' != 'true' and ('$(TargetPlatformIdentifier)' == 'maccatalyst' or '$(TargetPlatformIdentifier)' == 'ios' or '$(TargetPlatformIdentifier)' == 'tvos')">$(_RunAotCompiler) <_WorkloadUsesNativeAOT>$(PublishAot) <_WorkloadUsesInterpreter>$(MonoForceInterpreter) <_WorkloadUsesInterpreter Condition="'$(_WorkloadUsesInterpreter)' == '' and '$(UseInterpreter)' == 'true'">true <_WorkloadUsesInterpreter Condition="'$(_WorkloadUsesInterpreter)' == '' and '$(RunAOTCompilation)' != 'true' and ('$(_WorkloadUsesBlazorWasm)' == 'true' or '$(_WorkloadUsesWasmSDK)' == 'true')">true <_WorkloadUsesLibraryMode Condition="'$(NativeLib)' != '' and ('$(_WorkloadUsesMonoAOT)' == 'true' or '$(_WorkloadUsesNativeAOT)' == 'true')">true + <_WorkloadUsesStripILAfterAOT Condition="'$(WasmStripILAfterAOT)' == 'true' or '$(AndroidStripILAfterAOT)' == 'true'">true - <_WorkloadUsesOther Condition="'$([System.IO.Path]::GetFileName(%(ReferencePath.Identity)).ToLower())' == 'avalonia.dll'">true - <_WorkloadUsesOther Condition="'$([System.IO.Path]::GetFileName(%(ReferencePath.Identity)).ToLower())' == 'uno.dll'">true + <_WorkloadUsesOther Condition="'$([System.IO.Path]::GetFileName("%(ReferencePath.Identity)").ToLowerInvariant())' == 'avalonia.dll'">true + <_WorkloadUsesOther Condition="'$([System.IO.Path]::GetFileName("%(ReferencePath.Identity)").ToLowerInvariant())' == 'uno.dll'">true <_WorkloadUsesMobileSDKOnly Condition="'$(RuntimeIdentifier)' != 'browser-wasm' and '$(UseMaui)' != 'true' and '$(_WorkloadUsesOther)' != 'true'">true diff --git a/src/mono/sample/wasm/browser-advanced/index.html b/src/mono/sample/wasm/browser-advanced/index.html index c8961d7c715408..24d51ea29672d6 100644 --- a/src/mono/sample/wasm/browser-advanced/index.html +++ b/src/mono/sample/wasm/browser-advanced/index.html @@ -13,10 +13,7 @@ - - - - + diff --git a/src/mono/sample/wasm/browser-advanced/main.js b/src/mono/sample/wasm/browser-advanced/main.js index f95fcbf9903be6..b5c414322fefd0 100644 --- a/src/mono/sample/wasm/browser-advanced/main.js +++ b/src/mono/sample/wasm/browser-advanced/main.js @@ -31,6 +31,7 @@ try { // here we show how emscripten could be further configured // It is preferred to use specific 'with***' methods instead in all other cases. .withConfig({ + startupMemoryCache: true, resources: { modulesAfterConfigLoaded: { "advanced-sample.lib.module.js": "" diff --git a/src/mono/sample/wasm/browser-bench/appstart-frame.html b/src/mono/sample/wasm/browser-bench/appstart-frame.html index 5bb75e32aa1b75..481ffa4e27301f 100644 --- a/src/mono/sample/wasm/browser-bench/appstart-frame.html +++ b/src/mono/sample/wasm/browser-bench/appstart-frame.html @@ -12,9 +12,6 @@ - - - diff --git a/src/mono/sample/wasm/browser-bench/frame-main.js b/src/mono/sample/wasm/browser-bench/frame-main.js index c1042928d78d07..88358e0310a10b 100644 --- a/src/mono/sample/wasm/browser-bench/frame-main.js +++ b/src/mono/sample/wasm/browser-bench/frame-main.js @@ -31,6 +31,10 @@ try { } const runtime = await dotnet + .withConfig({ + maxParallelDownloads: 10000, + // diagnosticTracing:true, + }) .withModuleConfig({ printErr: () => undefined, print: () => undefined, @@ -38,7 +42,6 @@ try { if (window.parent != window) { window.parent.resolveAppStartEvent("onConfigLoaded"); } - // config.diagnosticTracing = true; } }) .create(); diff --git a/src/mono/wasi/README.md b/src/mono/wasi/README.md index 66e59384297bb3..4c8759006359da 100644 --- a/src/mono/wasi/README.md +++ b/src/mono/wasi/README.md @@ -17,6 +17,7 @@ or for just native rebuild ./build.sh -bl -os wasi -subset mono.runtime+libs.native+mono.wasiruntime -c Debug ``` + ### 3. Run it Finally, you can build and run the sample: @@ -49,4 +50,4 @@ Download the Mono Debug extension and configure a launch.json like this: } ] } -``` \ No newline at end of file +``` diff --git a/src/mono/wasi/build/WasiApp.Native.targets b/src/mono/wasi/build/WasiApp.Native.targets index 36c66585870a2c..ba9c085924d46a 100644 --- a/src/mono/wasi/build/WasiApp.Native.targets +++ b/src/mono/wasi/build/WasiApp.Native.targets @@ -273,6 +273,10 @@ <_WasmEHLibToExclude Condition="'$(WasmEnableExceptionHandling)' != 'true'">libmono-wasm-eh-wasm.a + <_WasmSIMDLib Condition="'$(WasmEnableSIMD)' == 'true'">libmono-wasm-simd.a + <_WasmSIMDLib Condition="'$(WasmEnableSIMD)' != 'true'">libmono-wasm-nosimd.a + <_WasmSIMDLibToExclude Condition="'$(WasmEnableSIMD)' != 'true'">libmono-wasm-simd.a + <_WasmSIMDLibToExclude Condition="'$(WasmEnableSIMD)' == 'true'">libmono-wasm-nosimd.a @@ -286,7 +290,9 @@ Include="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)*.a" Exclude="@(_MonoRuntimeComponentDontLink->'$(MicrosoftNetCoreAppRuntimePackRidNativeDir)%(Identity)')" /> <_WasmNativeFileForLinking Condition="'$(_WasmEHLib)' != ''" Include="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmEHLib)" /> + <_WasmNativeFileForLinking Condition="'$(_WasmSIMDLib)' != ''" Include="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmSIMDLib)" /> <_WasmNativeFileForLinking Remove="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmEHLibToExclude)" /> + <_WasmNativeFileForLinking Remove="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmSIMDLibToExclude)" /> <_WasmNativeFileForLinking Include="$(WasiSysRoot)\lib\wasm32-wasi\libc++.a" /> <_WasmNativeFileForLinking Include="$(WasiSysRoot)\lib\wasm32-wasi\libc++abi.a" /> diff --git a/src/mono/wasi/build/WasiApp.targets b/src/mono/wasi/build/WasiApp.targets index 1b8fc4e3fa9333..099ba8c3ebea27 100644 --- a/src/mono/wasi/build/WasiApp.targets +++ b/src/mono/wasi/build/WasiApp.targets @@ -4,8 +4,6 @@ @@ -133,7 +133,6 @@ true true - false @@ -147,10 +146,23 @@ true + + false + + + true + false + + + + true + true + + @@ -160,14 +172,6 @@ - - - - false - true - - - @@ -178,7 +182,6 @@ <_MonoAotCrossCompilerPath>@(MonoAotCrossCompiler->WithMetadataValue('RuntimeIdentifier','browser-wasm')) <_EmccDefaultFlagsRsp>$([MSBuild]::NormalizePath($(_WasmRuntimePackSrcDir), 'emcc-default.rsp')) <_EmccDefaultLinkFlagsRsp>$([MSBuild]::NormalizePath($(_WasmRuntimePackSrcDir), 'emcc-link.rsp')) - true $(WasmBuildNative) <_WasmICallTablePath>$(_WasmIntermediateOutputPath)icall-table.h @@ -221,10 +224,11 @@ <_EmccCommonFlags Include="$(_DefaultEmccFlags)" /> <_EmccCommonFlags Include="$(EmccFlags)" /> - <_EmccCommonFlags Include="-g" Condition="'$(WasmNativeDebugSymbols)' == 'true'" /> - <_EmccCommonFlags Include="-v" Condition="'$(EmccVerbose)' != 'false'" /> - <_EmccCommonFlags Include="-s DISABLE_EXCEPTION_CATCHING=0" Condition="'$(WasmEnableExceptionHandling)' == 'false'" /> - <_EmccCommonFlags Include="-fwasm-exceptions" Condition="'$(WasmEnableExceptionHandling)' == 'true'" /> + <_EmccCommonFlags Include="-g" Condition="'$(WasmNativeStrip)' == 'false'" /> + <_EmccCommonFlags Include="-v" Condition="'$(EmccVerbose)' != 'false'" /> + <_EmccCommonFlags Include="-s DISABLE_EXCEPTION_CATCHING=0" Condition="'$(WasmEnableExceptionHandling)' == 'false'" /> + <_EmccCommonFlags Include="-fwasm-exceptions" Condition="'$(WasmEnableExceptionHandling)' == 'true'" /> + <_EmccCommonFlags Include="-s MAXIMUM_MEMORY=$(EmccMaximumHeapSize)" Condition="'$(EmccMaximumHeapSize)' != ''" /> <_EmccIncludePaths Include="$(_WasmIntermediateOutputPath.TrimEnd('\/'))" /> <_EmccIncludePaths Include="$(_WasmRuntimePackIncludeDir)mono-2.0" /> @@ -248,6 +252,7 @@ <_EmccCFlags Include="-emit-llvm" /> <_EmccCFlags Include=""-I%(_EmccIncludePaths.Identity)"" /> + <_EmccCFlags Include="-g" Condition="'$(WasmNativeDebugSymbols)' == 'true'" /> <_EmccLDFlags Include="$(EmccLinkOptimizationFlag)" /> @@ -440,6 +445,10 @@ <_WasmEHLib Condition="'$(WasmEnableExceptionHandling)' != 'true'">libmono-wasm-eh-js.a <_WasmEHLibToExclude Condition="'$(WasmEnableExceptionHandling)' == 'true'">libmono-wasm-eh-js.a <_WasmEHLibToExclude Condition="'$(WasmEnableExceptionHandling)' != 'true'">libmono-wasm-eh-wasm.a + <_WasmSIMDLib Condition="'$(WasmEnableSIMD)' == 'true'">libmono-wasm-simd.a + <_WasmSIMDLib Condition="'$(WasmEnableSIMD)' != 'true'">libmono-wasm-nosimd.a + <_WasmSIMDLibToExclude Condition="'$(WasmEnableSIMD)' != 'true'">libmono-wasm-simd.a + <_WasmSIMDLibToExclude Condition="'$(WasmEnableSIMD)' == 'true'">libmono-wasm-nosimd.a <_EmccExportedLibraryFunction>"[@(EmccExportedLibraryFunction -> '%27%(Identity)%27', ',')]" <_EmccExportedRuntimeMethods>"[@(EmccExportedRuntimeMethod -> '%27%(Identity)%27', ',')]" <_EmccExportedFunctions>@(EmccExportedFunction -> '%(Identity)',',') @@ -460,7 +469,9 @@ Include="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)*.a" Exclude="@(_MonoRuntimeComponentDontLink->'$(MicrosoftNetCoreAppRuntimePackRidNativeDir)%(Identity)')" /> <_WasmNativeFileForLinking Include="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmEHLib)" /> + <_WasmNativeFileForLinking Condition="'$(_WasmSIMDLib)' != ''" Include="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmSIMDLib)" /> <_WasmNativeFileForLinking Remove="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmEHLibToExclude)" /> + <_WasmNativeFileForLinking Remove="$(MicrosoftNetCoreAppRuntimePackRidNativeDir)$(_WasmSIMDLibToExclude)" /> <_WasmExtraJSFile Include="@(Content)" Condition="'%(Content.Extension)' == '.js'" /> diff --git a/src/mono/wasm/build/WasmApp.targets b/src/mono/wasm/build/WasmApp.targets index be1b7214e2c0bd..0a22f3f82d528f 100644 --- a/src/mono/wasm/build/WasmApp.targets +++ b/src/mono/wasm/build/WasmApp.targets @@ -57,6 +57,8 @@ - $(EmccInitialHeapSize) - Initial heap size specified with `emcc`. Default value: 16777216 or size of the DLLs, whichever is larger. Corresponds to `-s INITIAL_MEMORY=...` emcc arg. (previously named EmccTotalMemory, which is still kept as an alias) + - $(EmccMaximumHeapSize) - Maximum heap size specified with `emcc`. Default value: 2147483648 or size of the DLLs, whichever is larger. + Corresponds to `-s MAXIMUM_MEMORY=...` emcc arg. - $(EmccStackSize) - Stack size. Default value: 5MB. Corresponds to `-s STACK_SIZE=...` emcc arg. diff --git a/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs b/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs index c384f74e9c7223..45d700e122ea7f 100644 --- a/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs +++ b/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs @@ -387,12 +387,14 @@ private static async Task> ResolveElementAccess(ExpressionSyntaxR { var values = new List(); JObject index = null; + List nestedIndexers = new(); IEnumerable elementAccesses = replacer.elementAccess; foreach (ElementAccessExpressionSyntax elementAccess in elementAccesses.Reverse()) { - index = await resolver.Resolve(elementAccess, replacer.memberAccessValues, index, replacer.variableDefinitions, token); + index = await resolver.Resolve(elementAccess, replacer.memberAccessValues, nestedIndexers, replacer.variableDefinitions, token); if (index == null) throw new ReturnAsErrorException($"Failed to resolve element access for {elementAccess}", "ReferenceError"); + nestedIndexers.Add(index); } values.Add(index); return values; diff --git a/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs b/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs index 650583a9dc7bf2..e1b9583ddbe3e1 100644 --- a/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs +++ b/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs @@ -366,7 +366,12 @@ async Task ResolveAsInstanceMember(ArraySegment parts, JObject } } - public async Task Resolve(ElementAccessExpressionSyntax elementAccess, Dictionary memberAccessValues, JObject indexObject, List variableDefinitions, CancellationToken token) + public async Task Resolve( + ElementAccessExpressionSyntax elementAccess, + Dictionary memberAccessValues, + List nestedIndexObject, + List variableDefinitions, + CancellationToken token) { try { @@ -376,12 +381,13 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, if (rootObject == null) { - // it might be a jagged array where indexObject should be treated as a new rootObject - rootObject = indexObject; - indexObject = null; + // it might be a jagged array where the previously added nestedIndexObject should be treated as a new rootObject + rootObject = nestedIndexObject.LastOrDefault(); + if (rootObject != null) + nestedIndexObject.RemoveAt(nestedIndexObject.Count - 1); } - ElementIndexInfo elementIdxInfo = await GetElementIndexInfo(); + ElementIndexInfo elementIdxInfo = await GetElementIndexInfo(nestedIndexObject); if (elementIdxInfo is null) return null; @@ -394,6 +400,7 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, if (!DotnetObjectId.TryParse(rootObject?["objectId"]?.Value(), out DotnetObjectId objectId)) throw new InvalidOperationException($"Cannot apply indexing with [] to a primitive object of type '{type}'"); + bool isMultidimensional = elementIdxInfo.DimensionsCount != 1; switch (objectId.Scheme) { case "valuetype": //can be an inlined array @@ -407,7 +414,7 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, } case "array": rootObject["value"] = await context.SdbAgent.GetArrayValues(objectId.Value, token); - if (!elementIdxInfo.IsMultidimensional) + if (!isMultidimensional) { int.TryParse(elementIdxInfo.ElementIdxStr, out elementIdx); return (JObject)rootObject["value"][elementIdx]["value"]; @@ -417,10 +424,8 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, return (JObject)(((JArray)rootObject["value"]).FirstOrDefault(x => x["name"].Value() == elementIdxInfo.ElementIdxStr)["value"]); } case "object": - if (elementIdxInfo.IsMultidimensional) - throw new InvalidOperationException($"Cannot apply indexing with [,] to an object of type '{type}'"); // ToDo: try to use the get_Item for string as well - if (type == "string") + if (!isMultidimensional && type == "string") { var eaExpressionFormatted = elementAccessStrExpression.Replace('.', '_'); // instance_str variableDefinitions.Add(new (eaExpressionFormatted, rootObject, ExpressionEvaluator.ConvertJSToCSharpLocalVariableAssignment(eaExpressionFormatted, rootObject))); @@ -428,7 +433,7 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, var variableDef = await ExpressionEvaluator.GetVariableDefinitions(this, variableDefinitions, invokeToStringInObject: false, token); return await ExpressionEvaluator.EvaluateSimpleExpression(this, eaFormatted, elementAccessStr, variableDef, logger, token); } - if (indexObject is null && elementIdxInfo.IndexingExpression is null) + if (elementIdxInfo.Indexers is null || elementIdxInfo.Indexers.Count == 0) throw new InternalErrorException($"Unable to write index parameter to invoke the method in the runtime."); var typeIds = await context.SdbAgent.GetTypeIdsForObject(objectId.Value, true, token); @@ -441,15 +446,13 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, { MethodInfoWithDebugInformation methodInfo = await context.SdbAgent.GetMethodInfo(methodIds[i], token); ParameterInfo[] paramInfo = methodInfo.GetParametersInfo(); - if (paramInfo.Length == 1) + if (paramInfo.Length == elementIdxInfo.DimensionsCount) { try { - if (indexObject != null && !CheckParametersCompatibility(paramInfo[0].TypeCode, indexObject)) + if (!CheckParametersCompatibility(paramInfo, elementIdxInfo.Indexers)) continue; - ArraySegment buffer = indexObject is null ? - await WriteLiteralExpressionAsIndex(objectId, elementIdxInfo.IndexingExpression, elementIdxInfo.ElementIdxStr) : - await WriteJObjectAsIndex(objectId, indexObject, elementIdxInfo.ElementIdxStr, paramInfo[0].TypeCode); + ArraySegment buffer = await WriteIndexObjectAsIndices(objectId, elementIdxInfo.Indexers, paramInfo); JObject getItemRetObj = await context.SdbAgent.InvokeMethod(buffer, methodIds[i], token); return (JObject)getItemRetObj["value"]; } @@ -470,31 +473,32 @@ await WriteLiteralExpressionAsIndex(objectId, elementIdxInfo.IndexingExpression, throw new ReturnAsErrorException($"Unable to evaluate element access '{elementAccess}': {ex.Message}", ex.GetType().Name); } - async Task GetElementIndexInfo() + async Task GetElementIndexInfo(List nestedIndexers) { - // e.g. x[a[0]], x[a[b[1]]] etc. - if (indexObject is not null) - return new ElementIndexInfo(ElementIdxStr: indexObject["value"].ToString() ); - if (elementAccess.ArgumentList is null) return null; - StringBuilder elementIdxStr = new StringBuilder(); - var multiDimensionalArray = false; + int dimCnt = elementAccess.ArgumentList.Arguments.Count; LiteralExpressionSyntax indexingExpression = null; - for (int i = 0; i < elementAccess.ArgumentList.Arguments.Count; i++) + StringBuilder elementIdxStr = new StringBuilder(); + List indexers = new(); + // nesting should be resolved in reverse order + int nestedIndexersCnt = nestedIndexers.Count - 1; + for (int i = 0; i < dimCnt; i++) { + JObject indexObject; var arg = elementAccess.ArgumentList.Arguments[i]; if (i != 0) { elementIdxStr.Append(", "); - multiDimensionalArray = true; } // e.g. x[1] if (arg.Expression is LiteralExpressionSyntax) { indexingExpression = arg.Expression as LiteralExpressionSyntax; - elementIdxStr.Append(indexingExpression.ToString()); + string expression = indexingExpression.ToString(); + elementIdxStr.Append(expression); + indexers.Add(indexingExpression); } // e.g. x[a] or x[a.b] @@ -508,6 +512,18 @@ async Task GetElementIndexInfo() // x[a] indexObject ??= await Resolve(argParm.Identifier.Text, token); elementIdxStr.Append(indexObject["value"].ToString()); + indexers.Add(indexObject); + } + // nested indexing, e.g. x[a[0]], x[a[b[1]]], x[a[0], b[1]] + else if (arg.Expression is ElementAccessExpressionSyntax) + { + if (nestedIndexers == null || nestedIndexersCnt < 0) + throw new InvalidOperationException($"Cannot resolve nested indexing"); + JObject nestedIndexObject = nestedIndexers[nestedIndexersCnt]; + nestedIndexers.RemoveAt(nestedIndexersCnt); + elementIdxStr.Append(nestedIndexObject["value"].ToString()); + indexers.Add(nestedIndexObject); + nestedIndexersCnt--; } // indexing with expressions, e.g. x[a + 1] else @@ -519,36 +535,57 @@ async Task GetElementIndexInfo() if (idxType != "number") throw new InvalidOperationException($"Cannot index with an object of type '{idxType}'"); elementIdxStr.Append(indexObject["value"].ToString()); + indexers.Add(indexObject); } } return new ElementIndexInfo( + DimensionsCount: dimCnt, ElementIdxStr: elementIdxStr.ToString(), - IsMultidimensional: multiDimensionalArray, - IndexingExpression: indexingExpression); + Indexers: indexers); } - async Task> WriteJObjectAsIndex(DotnetObjectId rootObjId, JObject indexObject, string elementIdxStr, ElementType? expectedType) + async Task> WriteIndexObjectAsIndices(DotnetObjectId rootObjId, List indexObjects, ParameterInfo[] paramInfo) { using var writer = new MonoBinaryWriter(); writer.WriteObj(rootObjId, context.SdbAgent); - writer.Write(1); // number of method args - if (!await writer.WriteJsonValue(indexObject, context.SdbAgent, expectedType, token)) - throw new InternalErrorException($"Parsing index of type {indexObject["type"].Value()} to write it into the buffer failed."); + writer.Write(indexObjects.Count); // number of method args + foreach ((ParameterInfo pi, object indexObject) in paramInfo.Zip(indexObjects)) + { + if (indexObject is JObject indexJObject) + { + // indexed by an identifier name syntax + if (!await writer.WriteJsonValue(indexJObject, context.SdbAgent, pi.TypeCode, token)) + throw new InternalErrorException($"Parsing index of type {indexJObject["type"].Value()} to write it into the buffer failed."); + } + else if (indexObject is LiteralExpressionSyntax expression) + { + // indexed by a literal expression syntax + if (!await writer.WriteConst(expression, context.SdbAgent, token)) + throw new InternalErrorException($"Parsing literal expression index = {expression} to write it into the buffer failed."); + } + else + { + throw new InternalErrorException($"Unexpected index type."); + } + } return writer.GetParameterBuffer(); } + } - async Task> WriteLiteralExpressionAsIndex(DotnetObjectId rootObjId, LiteralExpressionSyntax indexingExpression, string elementIdxStr) + private static bool CheckParametersCompatibility(ParameterInfo[] paramInfos, List indexObjects) + { + if (paramInfos.Length != indexObjects.Count) + return false; + foreach ((ParameterInfo paramInfo, object indexObj) in paramInfos.Zip(indexObjects)) { - using var writer = new MonoBinaryWriter(); - writer.WriteObj(rootObjId, context.SdbAgent); - writer.Write(1); // number of method args - if (!await writer.WriteConst(indexingExpression, context.SdbAgent, token)) - throw new InternalErrorException($"Parsing index of type {indexObject["type"].Value()} to write it into the buffer failed."); - return writer.GetParameterBuffer(); + // shouldn't we check LiteralExpressionSyntax for compatibility as well? + if (indexObj is JObject indexJObj && !CheckParameterCompatibility(paramInfo.TypeCode, indexJObj)) + return false; } + return true; } - private static bool CheckParametersCompatibility(ElementType? paramTypeCode, JObject value) + private static bool CheckParameterCompatibility(ElementType? paramTypeCode, JObject value) { if (!paramTypeCode.HasValue) return true; @@ -871,7 +908,8 @@ public JObject TryGetEvaluationResult(string id) private sealed record ElementIndexInfo( string ElementIdxStr, - bool IsMultidimensional = false, - LiteralExpressionSyntax IndexingExpression = null); + // keeps JObjects and LiteralExpressionSyntaxes: + List Indexers, + int DimensionsCount = 1); } } diff --git a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs index 051da33469ce2d..b1a79b28ceeefe 100644 --- a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs +++ b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs @@ -731,5 +731,23 @@ await CheckEvaluateFail(id, ("dt+1", "Cannot evaluate '(dt+1\n)': (2,9): error CS0019: Operator '+' cannot be applied to operands of type 'object' and 'int'") ); }); + + [Fact] + public async Task EvaluateObjectIndexingMultidimensional() => await CheckInspectLocalsAtBreakpointSite( + "DebuggerTests.EvaluateLocalsWithIndexingTests", "EvaluateLocals", 12, "DebuggerTests.EvaluateLocalsWithIndexingTests.EvaluateLocals", + "window.setTimeout(function() { invoke_static_method ('[debugger-test] DebuggerTests.EvaluateLocalsWithIndexingTests:EvaluateLocals'); })", + wait_for_event_fn: async (pause_location) => + { + var id = pause_location["callFrames"][0]["callFrameId"].Value(); + await EvaluateOnCallFrameAndCheck(id, + ("f[j, aDouble]", TNumber("3.34")), //only IdentifierNameSyntaxes + ("f[1, aDouble]", TNumber("3.34")), //IdentifierNameSyntax with LiteralExpressionSyntax + ("f[aChar, \"&\", longString]", TString("9-&-longString")), + ("f[f.numArray[j], aDouble]", TNumber("4.34")), //ElementAccessExpressionSyntax + ("f[f.numArray[j], f.numArray[0]]", TNumber("3")), //multiple ElementAccessExpressionSyntaxes + ("f[f.numArray[f.numList[0]], f.numArray[i]]", TNumber("3")), + ("f[f.numArray[f.numList[0]], f.numArray[f.numArray[i]]]", TNumber("4")) + ); + }); } } diff --git a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs index 2ff9bd26a28272..2d0fb87822758a 100644 --- a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs +++ b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs @@ -585,7 +585,7 @@ public async Task EvaluateIndexingNegative() => await CheckInspectLocalsAtBreakp Assert.Equal("Unable to evaluate element access 'f.idx0[2]': Cannot apply indexing with [] to a primitive object of type 'number'", res.Error["result"]?["description"]?.Value()); var exceptionDetailsStack = res.Error["exceptionDetails"]?["stackTrace"]?["callFrames"]?[0]; Assert.Equal("DebuggerTests.EvaluateLocalsWithIndexingTests.EvaluateLocals", exceptionDetailsStack?["functionName"]?.Value()); - Assert.Equal(556, exceptionDetailsStack?["lineNumber"]?.Value()); + Assert.Equal(558, exceptionDetailsStack?["lineNumber"]?.Value()); Assert.Equal(12, exceptionDetailsStack?["columnNumber"]?.Value()); (_, res) = await EvaluateOnCallFrame(id, "f[1]", expect_ok: false ); Assert.Equal( "Unable to evaluate element access 'f[1]': Cannot apply indexing with [] to an object of type 'DebuggerTests.EvaluateLocalsWithIndexingTests.TestEvaluate'", res.Error["result"]?["description"]?.Value()); @@ -722,7 +722,7 @@ public async Task EvaluateIndexingByExpressionNegative() => await CheckInspectLo Assert.Equal("Unable to evaluate element access 'f.numList[\"a\" + 1]': Cannot index with an object of type 'string'", res.Error["result"]?["description"]?.Value()); var exceptionDetailsStack = res.Error["exceptionDetails"]?["stackTrace"]?["callFrames"]?[0]; Assert.Equal("DebuggerTests.EvaluateLocalsWithIndexingTests.EvaluateLocals", exceptionDetailsStack?["functionName"]?.Value()); - Assert.Equal(556, exceptionDetailsStack?["lineNumber"]?.Value()); + Assert.Equal(558, exceptionDetailsStack?["lineNumber"]?.Value()); Assert.Equal(12, exceptionDetailsStack?["columnNumber"]?.Value()); }); diff --git a/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs b/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs index bfeb9ec9b4d5ea..ade5c6adc8bbd5 100644 --- a/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs +++ b/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs @@ -596,7 +596,8 @@ await SendCommandAndCheck (JObject.FromObject(new { }), "Debugger.resume", scrip await CheckProps (c, new { Field1 = TNumber(123), Field2 = TString("spqr"), - }, "c", num_fields: 2); + Field3 = TString(null), + }, "c", num_fields: 3); }); } diff --git a/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs b/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs index 379adbe266908e..cbde71233f4b45 100644 --- a/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs +++ b/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs @@ -22,6 +22,7 @@ public C() } public double Field1; public string Field2; + public string Field3; } } diff --git a/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs b/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs index 480dc30115c430..e46177ecc925b5 100644 --- a/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs +++ b/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs @@ -522,7 +522,6 @@ public class TestEvaluate public int idx0; public int idx1; - // ToDo: add 2d indexing - https://github.com/dotnet/runtime/issues/76062 public string this[char key] => "res_" + key; public string this[bool key] => key.ToString(); public bool this[string key] => key.Length > 3; @@ -530,11 +529,14 @@ public class TestEvaluate public int this[float key] => (int)key; public int this[decimal key] => (int)key; + public double this[int key1, double key2] => key1 + key2; + public string this[char key1, string key2, string key3] => $"{key1}-{key2}-{key3}"; + public void run() { numList = new List { 1, 2 }; textList = new List { "1", "2" }; - numArray = new int[] { 1, 2 }; + numArray = new int[] { 1, 2, 0 }; textArray = new string[] { "1", "2" }; numArrayOfArrays = new int[][] { numArray, numArray }; numListOfLists = new List> { numList, numList }; diff --git a/src/mono/wasm/features.md b/src/mono/wasm/features.md index 46aa26d85038e1..d3d85f2e570a9e 100644 --- a/src/mono/wasm/features.md +++ b/src/mono/wasm/features.md @@ -193,12 +193,13 @@ See also [fetch integrity on MDN](https://developer.mozilla.org/en-US/docs/Web/A ### Pre-fetching In order to start downloading application resources as soon as possible you can add HTML elements to `` of your page similar to: +Adding too many files into prefetch could be counterproductive. +Please benchmark your startup performance on real target devices and with realistic network conditions. ```html - ``` See also [link rel prefetch on MDN](https://developer.mozilla.org/en-US/docs/Web/HTML/Attributes/rel/prefetch) @@ -291,6 +292,24 @@ A WebAssembly application that works well on desktop PCs browser may take minute ### Shell environments - NodeJS & V8 While our primary target is web browsers, we have partial support for Node.JS v14 sufficient to pass most of our automated tests. We also have partial support for the D8 command-line shell, version 11 or higher, sufficient to pass most of our automated tests. Both of these environments may lack support for features that are available in the browser. +#### NodeJS < 20 +Until node version 20, you may need to pass these arguments when running the application `--experimental-wasm-simd --experimental-wasm-eh`. When you run the application using `dotnet run`, you can add these to the runtimeconfig template + +```json +"wasmHostProperties": { + "perHostConfig": [ + { + "name": "node", + ... + "host-args": [ + "--experimental-wasm-simd", // 👈 Enable SIMD support + "--experimental-wasm-eh" // 👈 Enable exception handling support + ] + } + ] +} +``` + ## Choosing the right platform target Every end user has different needs, so the right platform for every application may differ. diff --git a/src/mono/wasm/host/BrowserHost.cs b/src/mono/wasm/host/BrowserHost.cs index 5c16a420e76deb..cbe160eeb0214a 100644 --- a/src/mono/wasm/host/BrowserHost.cs +++ b/src/mono/wasm/host/BrowserHost.cs @@ -74,7 +74,7 @@ private async Task RunAsync(ILoggerFactory loggerFactory, CancellationToken toke debugging: _args.CommonConfig.Debugging); runArgsJson.Save(Path.Combine(_args.CommonConfig.AppPath, "runArgs.json")); - string[] urls = envVars.TryGetValue("ASPNETCORE_URLS", out string? aspnetUrls) + string[] urls = (envVars.TryGetValue("ASPNETCORE_URLS", out string? aspnetUrls) && aspnetUrls.Length > 0) ? aspnetUrls.Split(';', StringSplitOptions.RemoveEmptyEntries) : new string[] { $"http://127.0.0.1:{_args.CommonConfig.HostProperties.WebServerPort}", "https://127.0.0.1:0" }; @@ -167,7 +167,7 @@ private static DevServerOptions CreateDevServerOptions(BrowserArguments args, st devServerOptions = CreateDevServerOptions(urls, staticWebAssetsPath, onConsoleConnected); if (devServerOptions == null) - throw new CommandLineException("Please, provide mainAssembly in hostProperties of runtimeconfig"); + throw new CommandLineException($"Please, provide mainAssembly in hostProperties of runtimeconfig. Alternatively leave the static web assets manifest ('*{staticWebAssetsV2Extension}') in the build output directory '{appPath}' ."); } return devServerOptions; @@ -183,7 +183,7 @@ private static DevServerOptions CreateDevServerOptions(BrowserArguments args, st ); private static string? FindFirstFileWithExtension(string directory, string extension) - => Directory.EnumerateFiles(directory, "*" + extension).First(); + => Directory.EnumerateFiles(directory, "*" + extension).FirstOrDefault(); private async Task RunConsoleMessagesPump(WebSocket socket, WasmTestMessagesProcessor messagesProcessor, CancellationToken token) { diff --git a/src/mono/wasm/host/DevServer/DevServer.cs b/src/mono/wasm/host/DevServer/DevServer.cs index b1369deabcc0cc..9a5a079cee695e 100644 --- a/src/mono/wasm/host/DevServer/DevServer.cs +++ b/src/mono/wasm/host/DevServer/DevServer.cs @@ -46,7 +46,8 @@ internal static class DevServer services.AddSingleton(Options.Create(options)); services.AddSingleton(realUrlsAvailableTcs); services.AddRouting(); - }); + }) + .UseUrls(options.Urls); IWebHost? host = builder.Build(); @@ -70,8 +71,7 @@ private static IConfiguration ConfigureHostConfiguration(DevServerOptions option [WebHostDefaults.EnvironmentKey] = "Development", ["Logging:LogLevel:Microsoft"] = "Warning", ["Logging:LogLevel:Microsoft.Hosting.Lifetime"] = "Information", - [WebHostDefaults.StaticWebAssetsKey] = options.StaticWebAssetsPath, - ["ApplyCopHeaders"] = options.WebServerUseCrossOriginPolicy.ToString() + [WebHostDefaults.StaticWebAssetsKey] = options.StaticWebAssetsPath }; config.AddInMemoryCollection(inMemoryConfiguration); diff --git a/src/mono/wasm/host/DevServer/DevServerStartup.cs b/src/mono/wasm/host/DevServer/DevServerStartup.cs index f438caf4b4b7ae..0fdc45a1754fdf 100644 --- a/src/mono/wasm/host/DevServer/DevServerStartup.cs +++ b/src/mono/wasm/host/DevServer/DevServerStartup.cs @@ -2,13 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Net.WebSockets; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using Microsoft.WebAssembly.AppHost; namespace Microsoft.WebAssembly.AppHost.DevServer; @@ -27,16 +30,16 @@ public static void ConfigureServices(IServiceCollection services) services.AddRouting(); } - public static void Configure(IApplicationBuilder app, TaskCompletionSource realUrlsAvailableTcs, ILogger logger, IHostApplicationLifetime applicationLifetime, IConfiguration configuration) + public static void Configure(IApplicationBuilder app, IOptions optionsContainer, TaskCompletionSource realUrlsAvailableTcs, ILogger logger, IHostApplicationLifetime applicationLifetime, IConfiguration configuration) { app.UseDeveloperExceptionPage(); EnableConfiguredPathbase(app, configuration); app.UseWebAssemblyDebugging(); - bool applyCopHeaders = configuration.GetValue("ApplyCopHeaders"); + DevServerOptions options = optionsContainer.Value; - if (applyCopHeaders) + if (options.WebServerUseCrossOriginPolicy) { app.Use(async (ctx, next) => { @@ -63,6 +66,29 @@ public static void Configure(IApplicationBuilder app, TaskCompletionSource + { + if (ctx.Request.Path.StartsWithSegments("/console")) + { + if (!ctx.WebSockets.IsWebSocketRequest) + { + ctx.Response.StatusCode = 400; + return; + } + + using WebSocket socket = await ctx.WebSockets.AcceptWebSocketAsync(); + await options.OnConsoleConnected(socket); + } + else + { + await next(ctx); + } + }); + } app.UseEndpoints(endpoints => { @@ -70,7 +96,7 @@ public static void Configure(IApplicationBuilder app, TaskCompletionSource { - if (applyCopHeaders) + if (options.WebServerUseCrossOriginPolicy) { // Browser multi-threaded runtime requires cross-origin policy headers to enable SharedArrayBuffer. ApplyCrossOriginPolicyHeaders(fileContext.Context); diff --git a/src/mono/wasm/host/Program.cs b/src/mono/wasm/host/Program.cs index a5005dce9d0fc4..105eb59929ecab 100644 --- a/src/mono/wasm/host/Program.cs +++ b/src/mono/wasm/host/Program.cs @@ -30,6 +30,9 @@ public static async Task Main(string[] args) RegisterHostHandler(WasmHost.Wasmtime, WasiEngineHost.InvokeAsync); using CancellationTokenSource cts = new(); + + Console.CancelKeyPress += (object? sender, ConsoleCancelEventArgs e) => cts.Cancel(); + ILoggerFactory loggerFactory = LoggerFactory.Create(builder => builder .AddPassThroughConsole() diff --git a/src/mono/wasm/host/RuntimeConfigJson.cs b/src/mono/wasm/host/RuntimeConfigJson.cs index ed698ed8fb3725..3ad30dd88015ae 100644 --- a/src/mono/wasm/host/RuntimeConfigJson.cs +++ b/src/mono/wasm/host/RuntimeConfigJson.cs @@ -24,7 +24,7 @@ internal sealed record WasmHostProperties( int? FirefoxDebuggingPort, int? ChromeProxyPort, int? ChromeDebuggingPort, - int WebServerPort = 9000) + int WebServerPort = 0) { // using an explicit property because the deserializer doesn't like // extension data in the record constructor diff --git a/src/mono/wasm/runtime/CMakeLists.txt b/src/mono/wasm/runtime/CMakeLists.txt index 6b8ef873ec27da..6d939088d74314 100644 --- a/src/mono/wasm/runtime/CMakeLists.txt +++ b/src/mono/wasm/runtime/CMakeLists.txt @@ -25,6 +25,7 @@ target_link_libraries(dotnet.native ${MONO_ARTIFACTS_DIR}/libmonosgen-2.0.a ${MONO_ARTIFACTS_DIR}/libmono-icall-table.a ${MONO_ARTIFACTS_DIR}/libmono-wasm-eh-js.a + ${MONO_ARTIFACTS_DIR}/libmono-wasm-${CONFIGURATION_INTERPSIMDTABLES_LIB}.a ${MONO_ARTIFACTS_DIR}/libmono-profiler-aot.a ${MONO_ARTIFACTS_DIR}/libmono-profiler-browser.a ${NATIVE_BIN_DIR}/wasm-bundled-timezones.a diff --git a/src/mono/wasm/runtime/es6/dotnet.es6.pre.js b/src/mono/wasm/runtime/es6/dotnet.es6.pre.js index 9eb9b1c6b99e7d..490935d5ca0284 100644 --- a/src/mono/wasm/runtime/es6/dotnet.es6.pre.js +++ b/src/mono/wasm/runtime/es6/dotnet.es6.pre.js @@ -1,3 +1,5 @@ if (_nativeModuleLoaded) throw new Error("Native module already loaded"); _nativeModuleLoaded = true; -createDotnetRuntime = Module = createDotnetRuntime(Module); \ No newline at end of file +createDotnetRuntime = Module = createDotnetRuntime(Module); +Module["getWasmIndirectFunctionTable"] = function () { return wasmTable; } +Module["getMemory"] = function () { return wasmMemory; } diff --git a/src/mono/wasm/runtime/exports.ts b/src/mono/wasm/runtime/exports.ts index a29fd2da59d84f..6e50e26e164c0c 100644 --- a/src/mono/wasm/runtime/exports.ts +++ b/src/mono/wasm/runtime/exports.ts @@ -8,7 +8,7 @@ import type { RuntimeAPI } from "./types"; import { Module, linkerDisableLegacyJsInterop, exportedRuntimeAPI, passEmscriptenInternals, runtimeHelpers, setRuntimeGlobals, } from "./globals"; import { GlobalObjects, is_nullish } from "./types/internal"; -import { configureEmscriptenStartup, configureWorkerStartup } from "./startup"; +import { configureEmscriptenStartup, configureRuntimeStartup, configureWorkerStartup } from "./startup"; import { create_weak_ref } from "./weak-ref"; import { export_internal } from "./exports-internal"; @@ -143,5 +143,5 @@ class RuntimeList { // export external API export { - passEmscriptenInternals, initializeExports, initializeReplacements, configureEmscriptenStartup, configureWorkerStartup, setRuntimeGlobals + passEmscriptenInternals, initializeExports, initializeReplacements, configureRuntimeStartup, configureEmscriptenStartup, configureWorkerStartup, setRuntimeGlobals }; \ No newline at end of file diff --git a/src/mono/wasm/runtime/globals.ts b/src/mono/wasm/runtime/globals.ts index 5db69fc91bf619..88be75543ab4e5 100644 --- a/src/mono/wasm/runtime/globals.ts +++ b/src/mono/wasm/runtime/globals.ts @@ -15,8 +15,8 @@ export let Module: DotnetModuleInternal; export let INTERNAL: any; export const ENVIRONMENT_IS_NODE = typeof process == "object" && typeof process.versions == "object" && typeof process.versions.node == "string"; -export const ENVIRONMENT_IS_WEB = typeof window == "object"; export const ENVIRONMENT_IS_WORKER = typeof importScripts == "function"; +export const ENVIRONMENT_IS_WEB = typeof window == "object" || (ENVIRONMENT_IS_WORKER && !ENVIRONMENT_IS_NODE); export const ENVIRONMENT_IS_SHELL = !ENVIRONMENT_IS_WEB && !ENVIRONMENT_IS_NODE && !ENVIRONMENT_IS_WORKER; // these are imported and re-exported from emscripten internals export let ENVIRONMENT_IS_PTHREAD: boolean; @@ -59,7 +59,6 @@ export function setRuntimeGlobals(globalObjects: GlobalObjects) { gitHash, allAssetsInMemory: createPromiseController(), dotnetReady: createPromiseController(), - memorySnapshotSkippedOrDone: createPromiseController(), afterInstantiateWasm: createPromiseController(), beforePreInit: createPromiseController(), afterPreInit: createPromiseController(), @@ -67,6 +66,12 @@ export function setRuntimeGlobals(globalObjects: GlobalObjects) { beforeOnRuntimeInitialized: createPromiseController(), afterOnRuntimeInitialized: createPromiseController(), afterPostRun: createPromiseController(), + mono_wasm_exit: () => { + throw new Error("Mono shutdown"); + }, + abort: (reason: any) => { + throw reason; + } }); Object.assign(globalObjects.module.config!, {}) as any; diff --git a/src/mono/wasm/runtime/jiterpreter-jit-call.ts b/src/mono/wasm/runtime/jiterpreter-jit-call.ts index eba38843de57ef..af918bedb0cf9c 100644 --- a/src/mono/wasm/runtime/jiterpreter-jit-call.ts +++ b/src/mono/wasm/runtime/jiterpreter-jit-call.ts @@ -281,7 +281,7 @@ export function mono_jiterp_do_jit_call_indirect( jit_call_cb: jitCallCb, }, m: { - h: (Module).asm.memory + h: (Module).getMemory() }, }); const impl = instance.exports.do_jit_call_indirect; diff --git a/src/mono/wasm/runtime/jiterpreter-support.ts b/src/mono/wasm/runtime/jiterpreter-support.ts index 5ce5759b2a6a4b..0589c28d5bd06e 100644 --- a/src/mono/wasm/runtime/jiterpreter-support.ts +++ b/src/mono/wasm/runtime/jiterpreter-support.ts @@ -239,9 +239,12 @@ export class WasmBuilder { } getWasmImports(): WebAssembly.Imports { + const memory = (Module).getMemory(); + mono_assert(memory instanceof WebAssembly.Memory, () => `expected heap import to be WebAssembly.Memory but was ${memory}`); + const result: any = { c: this.getConstants(), - m: { h: (Module).asm.memory }, + m: { h: memory }, // f: { f: getWasmFunctionTable() }, }; @@ -1589,7 +1592,7 @@ export function copyIntoScratchBuffer(src: NativePointer, size: number): NativeP export function getWasmFunctionTable() { if (!wasmTable) - wasmTable = (Module)["asm"]["__indirect_function_table"]; + wasmTable = Module.getWasmIndirectFunctionTable(); if (!wasmTable) throw new Error("Module did not export the indirect function table"); return wasmTable; diff --git a/src/mono/wasm/runtime/loader/assets.ts b/src/mono/wasm/runtime/loader/assets.ts index ba18815a08db80..be7c0e4592938b 100644 --- a/src/mono/wasm/runtime/loader/assets.ts +++ b/src/mono/wasm/runtime/loader/assets.ts @@ -78,8 +78,6 @@ const containedInSnapshotByAssetTypes: { "pdb": true, "heap": true, "icu": true, - ...jsModulesAssetTypes, - "dotnetwasm": true, }; // these assets are instantiated differently than the main flow @@ -95,7 +93,7 @@ export function shouldLoadIcuAsset(asset: AssetEntryInternal): boolean { return !(asset.behavior == "icu" && asset.name != loaderHelpers.preferredIcuAsset); } -function convert_single_asset(modulesAssets: AssetEntryInternal[], resource: ResourceList | undefined, behavior: SingleAssetBehaviors): AssetEntryInternal { +function convert_single_asset(assetsCollection: AssetEntryInternal[], resource: ResourceList | undefined, behavior: SingleAssetBehaviors): AssetEntryInternal { const keys = Object.keys(resource || {}); mono_assert(keys.length == 1, `Expect to have one ${behavior} asset in resources`); @@ -110,7 +108,7 @@ function convert_single_asset(modulesAssets: AssetEntryInternal[], resource: Res set_single_asset(asset); // so that we can use it on the worker too - modulesAssets.push(asset); + assetsCollection.push(asset); return asset; } @@ -168,15 +166,12 @@ export async function mono_download_assets(): Promise { countAndStartDownload(asset); } - // continue after the dotnet.runtime.js was loaded - await loaderHelpers.runtimeModuleLoaded.promise; - // continue after we know if memory snapshot is available or not - await runtimeHelpers.memorySnapshotSkippedOrDone.promise; + await loaderHelpers.memorySnapshotSkippedOrDone.promise; // start fetching assets in parallel, only if memory snapshot is not available. for (const asset of containedInSnapshotAssets) { - if (!runtimeHelpers.loadedMemorySnapshot) { + if (!runtimeHelpers.loadedMemorySnapshotSize) { countAndStartDownload(asset); } else { // Otherwise cleanup in case we were given pending download. It would be even better if we could abort the download. @@ -193,6 +188,8 @@ export async function mono_download_assets(): Promise { } loaderHelpers.allDownloadsQueued.promise_control.resolve(); + + // continue after the dotnet.runtime.js was loaded await loaderHelpers.runtimeModuleLoaded.promise; const promises_of_asset_instantiation: Promise[] = []; @@ -211,7 +208,6 @@ export async function mono_download_assets(): Promise { // wait till after onRuntimeInitialized and after memory snapshot is loaded or skipped await runtimeHelpers.beforeOnRuntimeInitialized.promise; - await runtimeHelpers.memorySnapshotSkippedOrDone.promise; runtimeHelpers.instantiate_asset(asset, url, data); } } else { @@ -284,7 +280,7 @@ export function prepareAssets() { mono_assert(resources.jsModuleNative, "resources.jsModuleNative must be defined"); mono_assert(resources.jsModuleRuntime, "resources.jsModuleRuntime must be defined"); mono_assert(!MonoWasmThreads || resources.jsModuleWorker, "resources.jsModuleWorker must be defined"); - convert_single_asset(modulesAssets, resources.wasmNative, "dotnetwasm"); + convert_single_asset(alwaysLoadedAssets, resources.wasmNative, "dotnetwasm"); convert_single_asset(modulesAssets, resources.jsModuleNative, "js-module-native"); convert_single_asset(modulesAssets, resources.jsModuleRuntime, "js-module-runtime"); if (MonoWasmThreads) { diff --git a/src/mono/wasm/runtime/loader/config.ts b/src/mono/wasm/runtime/loader/config.ts index c35553ae593e8c..097f6d1ca2cb15 100644 --- a/src/mono/wasm/runtime/loader/config.ts +++ b/src/mono/wasm/runtime/loader/config.ts @@ -10,6 +10,7 @@ import { importLibraryInitializers, invokeLibraryInitializers } from "./libraryI import { mono_exit } from "./exit"; import { makeURLAbsoluteWithApplicationBase } from "./polyfills"; import { appendUniqueQuery } from "./assets"; +import { mono_assert } from "./globals"; export function deep_merge_config(target: MonoConfigInternal, source: MonoConfigInternal): MonoConfigInternal { // no need to merge the same object @@ -220,15 +221,12 @@ export async function mono_wasm_load_config(module: DotnetModuleInternal): Promi await loaderHelpers.afterConfigLoaded.promise; return; } - configLoaded = true; - if (!configFilePath) { - normalizeConfig(); - loaderHelpers.afterConfigLoaded.promise_control.resolve(loaderHelpers.config); - return; - } - mono_log_debug("mono_wasm_load_config"); try { - await loadBootConfig(module); + configLoaded = true; + if (configFilePath) { + mono_log_debug("mono_wasm_load_config"); + await loadBootConfig(module); + } normalizeConfig(); @@ -249,7 +247,12 @@ export async function mono_wasm_load_config(module: DotnetModuleInternal): Promi normalizeConfig(); + mono_assert(!loaderHelpers.config.startupMemoryCache || !module.instantiateWasm, "startupMemoryCache is not supported with Module.instantiateWasm"); + loaderHelpers.afterConfigLoaded.promise_control.resolve(loaderHelpers.config); + if (!loaderHelpers.config.startupMemoryCache) { + loaderHelpers.memorySnapshotSkippedOrDone.promise_control.resolve(); + } } catch (err) { const errMessage = `Failed to load config file ${configFilePath} ${err} ${(err as Error)?.stack}`; loaderHelpers.config = module.config = Object.assign(loaderHelpers.config, { message: errMessage, error: err, isError: true }); diff --git a/src/mono/wasm/runtime/loader/exit.ts b/src/mono/wasm/runtime/loader/exit.ts index e8cb2e42eb2ef1..6b6767e89672e6 100644 --- a/src/mono/wasm/runtime/loader/exit.ts +++ b/src/mono/wasm/runtime/loader/exit.ts @@ -122,9 +122,9 @@ function abort_promises(reason: any) { loaderHelpers.afterConfigLoaded.promise_control.reject(reason); loaderHelpers.wasmDownloadPromise.promise_control.reject(reason); loaderHelpers.runtimeModuleLoaded.promise_control.reject(reason); + loaderHelpers.memorySnapshotSkippedOrDone.promise_control.reject(reason); if (runtimeHelpers.dotnetReady) { runtimeHelpers.dotnetReady.promise_control.reject(reason); - runtimeHelpers.memorySnapshotSkippedOrDone.promise_control.reject(reason); runtimeHelpers.afterInstantiateWasm.promise_control.reject(reason); runtimeHelpers.beforePreInit.promise_control.reject(reason); runtimeHelpers.afterPreInit.promise_control.reject(reason); diff --git a/src/mono/wasm/runtime/loader/globals.ts b/src/mono/wasm/runtime/loader/globals.ts index a710ea0cb3b92d..88b0d472de3ca8 100644 --- a/src/mono/wasm/runtime/loader/globals.ts +++ b/src/mono/wasm/runtime/loader/globals.ts @@ -16,8 +16,8 @@ import { hasDebuggingEnabled } from "./config"; import { logDownloadStatsToConsole, purgeUnusedCacheEntriesAsync } from "./assetsCache"; export const ENVIRONMENT_IS_NODE = typeof process == "object" && typeof process.versions == "object" && typeof process.versions.node == "string"; -export const ENVIRONMENT_IS_WEB = typeof window == "object"; export const ENVIRONMENT_IS_WORKER = typeof importScripts == "function"; +export const ENVIRONMENT_IS_WEB = typeof window == "object" || (ENVIRONMENT_IS_WORKER && !ENVIRONMENT_IS_NODE); export const ENVIRONMENT_IS_SHELL = !ENVIRONMENT_IS_WEB && !ENVIRONMENT_IS_NODE && !ENVIRONMENT_IS_WORKER; export let runtimeHelpers: RuntimeHelpers = {} as any; @@ -87,6 +87,7 @@ export function setLoaderGlobals( allDownloadsQueued: createPromiseController(), wasmDownloadPromise: createPromiseController(), runtimeModuleLoaded: createPromiseController(), + memorySnapshotSkippedOrDone: createPromiseController(), is_exited, is_runtime_running, diff --git a/src/mono/wasm/runtime/loader/run.ts b/src/mono/wasm/runtime/loader/run.ts index a3cb7977a30f0c..b3437a5a81d61b 100644 --- a/src/mono/wasm/runtime/loader/run.ts +++ b/src/mono/wasm/runtime/loader/run.ts @@ -454,10 +454,11 @@ function importModules() { } async function initializeModules(es6Modules: [RuntimeModuleExportsInternal, NativeModuleExportsInternal]) { - const { initializeExports, initializeReplacements, configureEmscriptenStartup, configureWorkerStartup, setRuntimeGlobals, passEmscriptenInternals } = es6Modules[0]; + const { initializeExports, initializeReplacements, configureRuntimeStartup, configureEmscriptenStartup, configureWorkerStartup, setRuntimeGlobals, passEmscriptenInternals } = es6Modules[0]; const { default: emscriptenFactory } = es6Modules[1]; setRuntimeGlobals(globalObjectsRoot); initializeExports(globalObjectsRoot); + await configureRuntimeStartup(); loaderHelpers.runtimeModuleLoaded.promise_control.resolve(); emscriptenFactory((originalModule: EmscriptenModuleInternal) => { @@ -494,9 +495,8 @@ async function createEmscriptenMain(): Promise { mono_exit(1, err); }); - init_globalization(); - setTimeout(() => { + init_globalization(); mono_download_assets(); // intentionally not awaited }, 0); diff --git a/src/mono/wasm/runtime/marshal-to-cs.ts b/src/mono/wasm/runtime/marshal-to-cs.ts index a0f5b58a5cd332..14d6c7e3dd7479 100644 --- a/src/mono/wasm/runtime/marshal-to-cs.ts +++ b/src/mono/wasm/runtime/marshal-to-cs.ts @@ -24,6 +24,7 @@ import { TypedArray } from "./types/emscripten"; import { addUnsettledPromise, settleUnsettledPromise } from "./pthreads/shared/eventloop"; import { mono_log_warn } from "./logging"; +export const jsinteropDoc = "For more information see https://aka.ms/dotnet-wasm-jsinterop"; export function initialize_marshalers_to_cs(): void { if (js_to_cs_marshalers.size == 0) { @@ -389,7 +390,7 @@ export function marshal_js_object_to_cs(arg: JSMarshalerArgument, value: any): v } else { // if value was ManagedObject, it would be double proxied, but the C# signature requires that - mono_check(value[js_owned_gc_handle_symbol] === undefined, "JSObject proxy of ManagedObject proxy is not supported"); + mono_check(value[js_owned_gc_handle_symbol] === undefined, () => `JSObject proxy of ManagedObject proxy is not supported. ${jsinteropDoc}`); mono_check(typeof value === "function" || typeof value === "object", () => `JSObject proxy of ${typeof value} is not supported`); set_arg_type(arg, MarshalerType.JSObject); @@ -474,7 +475,7 @@ function _marshal_cs_object_to_cs(arg: JSMarshalerArgument, value: any): void { else { assert_not_disposed(value); if (value instanceof ArraySegment) { - throw new Error("NotImplementedException: ArraySegment"); + throw new Error("NotImplementedException: ArraySegment. " + jsinteropDoc); } else if (value instanceof ManagedError) { set_arg_type(arg, MarshalerType.Exception); @@ -484,7 +485,7 @@ function _marshal_cs_object_to_cs(arg: JSMarshalerArgument, value: any): void { set_arg_type(arg, MarshalerType.Object); set_gc_handle(arg, gc_handle); } else { - throw new Error("NotImplementedException " + js_type); + throw new Error("NotImplementedException " + js_type + ". " + jsinteropDoc); } } } diff --git a/src/mono/wasm/runtime/marshal-to-js.ts b/src/mono/wasm/runtime/marshal-to-js.ts index 9b5a931067fda6..1ad71dbc4ae784 100644 --- a/src/mono/wasm/runtime/marshal-to-js.ts +++ b/src/mono/wasm/runtime/marshal-to-js.ts @@ -18,7 +18,7 @@ import { import { monoStringToString } from "./strings"; import { JSHandleNull, GCHandleNull, JSMarshalerArgument, JSMarshalerArguments, JSMarshalerType, MarshalerToCs, MarshalerToJs, BoundMarshalerToJs, MarshalerType } from "./types/internal"; import { TypedArray } from "./types/emscripten"; -import { get_marshaler_to_cs_by_type } from "./marshal-to-cs"; +import { get_marshaler_to_cs_by_type, jsinteropDoc } from "./marshal-to-cs"; import { localHeapViewF64, localHeapViewI32, localHeapViewU8 } from "./memory"; export function initialize_marshalers_to_js(): void { @@ -85,7 +85,7 @@ export function get_marshaler_to_js_by_type(marshaler_type: MarshalerType): Mars return undefined; } const converter = cs_to_js_marshalers.get(marshaler_type); - mono_assert(converter && typeof converter === "function", () => `ERR41: Unknown converter for type ${marshaler_type}`); + mono_assert(converter && typeof converter === "function", () => `ERR41: Unknown converter for type ${marshaler_type}. ${jsinteropDoc}`); return converter; } @@ -224,7 +224,7 @@ export function marshal_task_to_js(arg: JSMarshalerArgument, _?: MarshalerType, // when we arrived here from _marshal_cs_object_to_js res_converter = cs_to_js_marshalers.get(type); } - mono_assert(res_converter, () => `Unknown sub_converter for type ${MarshalerType[type]} `); + mono_assert(res_converter, () => `Unknown sub_converter for type ${MarshalerType[type]}. ${jsinteropDoc}`); // this is already resolved const val = res_converter(arg); @@ -256,7 +256,7 @@ export function marshal_task_to_js(arg: JSMarshalerArgument, _?: MarshalerType, // when we arrived here from _marshal_cs_object_to_js res_converter = cs_to_js_marshalers.get(type); } - mono_assert(res_converter, () => `Unknown sub_converter for type ${MarshalerType[type]}`); + mono_assert(res_converter, () => `Unknown sub_converter for type ${MarshalerType[type]}. ${jsinteropDoc}`); const js_value = res_converter!(argInner); orig_resolve(js_value); @@ -291,7 +291,7 @@ export function mono_wasm_marshal_promise(args: JSMarshalerArguments): void { else if (value_type !== MarshalerType.Task) { // this is already resolved task const sub_converter = cs_to_js_marshalers.get(value_type); - mono_assert(sub_converter, () => `Unknown sub_converter for type ${MarshalerType[value_type]} `); + mono_assert(sub_converter, () => `Unknown sub_converter for type ${MarshalerType[value_type]}. ${jsinteropDoc}`); const data = sub_converter(arg_value); promise_control.resolve(data); } @@ -406,7 +406,7 @@ function _marshal_cs_object_to_js(arg: JSMarshalerArgument): any { // other types const converter = cs_to_js_marshalers.get(marshaler_type); - mono_assert(converter, () => `Unknown converter for type ${MarshalerType[marshaler_type]}`); + mono_assert(converter, () => `Unknown converter for type ${MarshalerType[marshaler_type]}. ${jsinteropDoc}`); return converter(arg); } @@ -461,7 +461,7 @@ function _marshal_array_to_js_impl(arg: JSMarshalerArgument, element_type: Marsh result = sourceView.slice();//copy } else { - throw new Error(`NotImplementedException ${MarshalerType[element_type]} `); + throw new Error(`NotImplementedException ${MarshalerType[element_type]}. ${jsinteropDoc}`); } Module._free(buffer_ptr); return result; @@ -483,7 +483,7 @@ function _marshal_span_to_js(arg: JSMarshalerArgument, element_type?: MarshalerT result = new Span(buffer_ptr, length, MemoryViewType.Double); } else { - throw new Error(`NotImplementedException ${MarshalerType[element_type]} `); + throw new Error(`NotImplementedException ${MarshalerType[element_type]}. ${jsinteropDoc}`); } return result; } @@ -504,7 +504,7 @@ function _marshal_array_segment_to_js(arg: JSMarshalerArgument, element_type?: M result = new ArraySegment(buffer_ptr, length, MemoryViewType.Double); } else { - throw new Error(`NotImplementedException ${MarshalerType[element_type]} `); + throw new Error(`NotImplementedException ${MarshalerType[element_type]}. ${jsinteropDoc}`); } const gc_handle = get_arg_gc_handle(arg); if (BuildConfiguration === "Debug") { diff --git a/src/mono/wasm/runtime/polyfills.ts b/src/mono/wasm/runtime/polyfills.ts index fbf3a400f136c4..2ce605955db512 100644 --- a/src/mono/wasm/runtime/polyfills.ts +++ b/src/mono/wasm/runtime/polyfills.ts @@ -4,7 +4,7 @@ import MonoWasmThreads from "consts:monoWasmThreads"; import type { EmscriptenReplacements } from "./types/internal"; import type { TypedArray } from "./types/emscripten"; -import { ENVIRONMENT_IS_NODE, ENVIRONMENT_IS_WEB, INTERNAL, Module, loaderHelpers, runtimeHelpers } from "./globals"; +import { ENVIRONMENT_IS_NODE, ENVIRONMENT_IS_PTHREAD, ENVIRONMENT_IS_WEB, INTERNAL, Module, loaderHelpers, runtimeHelpers } from "./globals"; import { replaceEmscriptenPThreadLibrary } from "./pthreads/shared/emscripten-replacements"; const dummyPerformance = { @@ -30,7 +30,7 @@ export function initializeReplacements(replacements: EmscriptenReplacements): vo replacements.fetch = loaderHelpers.fetch_like; // misc - replacements.noExitRuntime = ENVIRONMENT_IS_WEB; + replacements.noExitRuntime = ENVIRONMENT_IS_WEB && !ENVIRONMENT_IS_PTHREAD; // threads if (MonoWasmThreads) { diff --git a/src/mono/wasm/runtime/snapshot.ts b/src/mono/wasm/runtime/snapshot.ts index 1d1df904643f51..c23422353a7dd3 100644 --- a/src/mono/wasm/runtime/snapshot.ts +++ b/src/mono/wasm/runtime/snapshot.ts @@ -44,22 +44,35 @@ async function openCache(): Promise { } } -export async function getMemorySnapshotSize(): Promise { +export async function checkMemorySnapshotSize(): Promise { try { + if (!runtimeHelpers.config.startupMemoryCache) { + // we could start downloading DLLs because snapshot is disabled + return; + } + const cacheKey = await getCacheKey(); if (!cacheKey) { - return undefined; + return; } const cache = await openCache(); if (!cache) { - return undefined; + return; } const res = await cache.match(cacheKey); const contentLength = res?.headers.get("content-length"); - return contentLength ? parseInt(contentLength) : undefined; + const memorySize = contentLength ? parseInt(contentLength) : undefined; + + runtimeHelpers.loadedMemorySnapshotSize = memorySize; + runtimeHelpers.storeMemorySnapshotPending = !memorySize; } catch (ex) { mono_log_warn("Failed find memory snapshot in the cache", ex); - return undefined; + } + finally { + if (!runtimeHelpers.loadedMemorySnapshotSize) { + // we could start downloading DLLs because there is no snapshot yet + loaderHelpers.memorySnapshotSkippedOrDone.promise_control.resolve(); + } } } diff --git a/src/mono/wasm/runtime/startup.ts b/src/mono/wasm/runtime/startup.ts index f67a8b7c18ba90..c385844dfd30bb 100644 --- a/src/mono/wasm/runtime/startup.ts +++ b/src/mono/wasm/runtime/startup.ts @@ -5,7 +5,7 @@ import MonoWasmThreads from "consts:monoWasmThreads"; import WasmEnableLegacyJsInterop from "consts:wasmEnableLegacyJsInterop"; import { DotnetModuleInternal, CharPtrNull } from "./types/internal"; -import { linkerDisableLegacyJsInterop, ENVIRONMENT_IS_PTHREAD, exportedRuntimeAPI, INTERNAL, loaderHelpers, Module, runtimeHelpers, createPromiseController, mono_assert, linkerWasmEnableSIMD, linkerWasmEnableEH } from "./globals"; +import { linkerDisableLegacyJsInterop, ENVIRONMENT_IS_PTHREAD, exportedRuntimeAPI, INTERNAL, loaderHelpers, Module, runtimeHelpers, createPromiseController, mono_assert, linkerWasmEnableSIMD, linkerWasmEnableEH, ENVIRONMENT_IS_NODE, ENVIRONMENT_IS_WORKER } from "./globals"; import cwraps, { init_c_exports } from "./cwraps"; import { mono_wasm_raise_debug_event, mono_wasm_runtime_ready } from "./debug"; import { toBase64StringImpl } from "./base64"; @@ -21,7 +21,7 @@ import { instantiate_wasm_asset, wait_for_all_assets } from "./assets"; import { mono_wasm_init_diagnostics } from "./diagnostics"; import { replace_linker_placeholders } from "./exports-binding"; import { endMeasure, MeasuredBlock, startMeasure } from "./profiler"; -import { getMemorySnapshot, storeMemorySnapshot, getMemorySnapshotSize } from "./snapshot"; +import { checkMemorySnapshotSize, getMemorySnapshot, storeMemorySnapshot } from "./snapshot"; import { mono_log_debug, mono_log_error, mono_log_warn, mono_set_thread_id } from "./logging"; // threads @@ -39,6 +39,11 @@ import { assertNoProxies } from "./gc-handles"; // default size if MonoConfig.pthreadPoolSize is undefined const MONO_PTHREAD_POOL_SIZE = 4; +export async function configureRuntimeStartup(): Promise { + await init_polyfills_async(); + await checkMemorySnapshotSize(); +} + // we are making emscripten startup async friendly // emscripten is executing the events without awaiting it and so we need to block progress via PromiseControllers above export function configureEmscriptenStartup(module: DotnetModuleInternal): void { @@ -117,8 +122,6 @@ function instantiateWasm( const mark = startMeasure(); if (userInstantiateWasm) { - // user wasm instantiation doesn't support memory snapshots - runtimeHelpers.memorySnapshotSkippedOrDone.promise_control.resolve(); const exports = userInstantiateWasm(imports, (instance: WebAssembly.Instance, module: WebAssembly.Module | undefined) => { endMeasure(mark, MeasuredBlock.instantiateWasm); runtimeHelpers.afterInstantiateWasm.promise_control.resolve(); @@ -229,6 +232,15 @@ async function onRuntimeInitializedAsync(userOnRuntimeInitialized: () => void) { // wait for previous stage await runtimeHelpers.afterPreRun.promise; mono_log_debug("onRuntimeInitialized"); + + runtimeHelpers.mono_wasm_exit = cwraps.mono_wasm_exit; + runtimeHelpers.abort = (reason: any) => { + if (!loaderHelpers.is_exited()) { + cwraps.mono_wasm_abort(); + } + throw reason; + }; + const mark = startMeasure(); // signal this stage, this will allow pending assets to allocate memory runtimeHelpers.beforeOnRuntimeInitialized.promise_control.resolve(); @@ -261,6 +273,10 @@ async function onRuntimeInitializedAsync(userOnRuntimeInitialized: () => void) { bindings_init(); runtimeHelpers.runtimeReady = true; + if (ENVIRONMENT_IS_NODE && !ENVIRONMENT_IS_WORKER) { + Module.runtimeKeepalivePush(); + } + if (MonoWasmThreads) { runtimeHelpers.javaScriptExports.install_synchronization_context(); runtimeHelpers.jsSynchronizationContextInstalled = true; @@ -350,13 +366,6 @@ function mono_wasm_pre_init_essential(isWorker: boolean): void { } init_c_exports(); - runtimeHelpers.mono_wasm_exit = cwraps.mono_wasm_exit; - runtimeHelpers.abort = (reason: any) => { - if (!loaderHelpers.is_exited()) { - cwraps.mono_wasm_abort(); - } - throw reason; - }; cwraps_internal(INTERNAL); if (WasmEnableLegacyJsInterop && !linkerDisableLegacyJsInterop) { cwraps_mono_api(MONO); @@ -375,15 +384,6 @@ async function mono_wasm_pre_init_essential_async(): Promise { mono_log_debug("mono_wasm_pre_init_essential_async"); Module.addRunDependency("mono_wasm_pre_init_essential_async"); - if (linkerWasmEnableSIMD) { - mono_assert(await loaderHelpers.simd(), "This browser/engine doesn't support WASM SIMD. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); - } - if (linkerWasmEnableEH) { - mono_assert(await loaderHelpers.exceptions(), "This browser/engine doesn't support WASM exception handling. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); - } - - await init_polyfills_async(); - if (MonoWasmThreads) { preAllocatePThreadWorkerPool(MONO_PTHREAD_POOL_SIZE, runtimeHelpers.config); } @@ -457,25 +457,18 @@ async function instantiate_wasm_module( ): Promise { // this is called so early that even Module exports like addRunDependency don't exist yet try { - let memorySize: number | undefined = undefined; await loaderHelpers.afterConfigLoaded; mono_log_debug("instantiate_wasm_module"); - if (runtimeHelpers.config.startupMemoryCache) { - memorySize = await getMemorySnapshotSize(); - runtimeHelpers.loadedMemorySnapshot = !!memorySize; - runtimeHelpers.storeMemorySnapshotPending = !runtimeHelpers.loadedMemorySnapshot; - } - if (!runtimeHelpers.loadedMemorySnapshot) { - // we should start downloading DLLs etc as they are not in the snapshot - runtimeHelpers.memorySnapshotSkippedOrDone.promise_control.resolve(); - } - await runtimeHelpers.beforePreInit.promise; Module.addRunDependency("instantiate_wasm_module"); + const wasmFeaturePromise = ensureUsedWasmFeatures(); + replace_linker_placeholders(imports); const assetToLoad = await loaderHelpers.wasmDownloadPromise.promise; + + await wasmFeaturePromise; await instantiate_wasm_asset(assetToLoad, imports, successCallback); assetToLoad.pendingDownloadInternal = null as any; // GC assetToLoad.pendingDownload = null as any; // GC @@ -484,19 +477,19 @@ async function instantiate_wasm_module( mono_log_debug("instantiate_wasm_module done"); - if (runtimeHelpers.loadedMemorySnapshot) { + if (runtimeHelpers.loadedMemorySnapshotSize) { try { const wasmMemory = (Module.asm?.memory || Module.wasmMemory)!; // .grow() takes a delta compared to the previous size - wasmMemory.grow((memorySize! - wasmMemory.buffer.byteLength + 65535) >>> 16); + wasmMemory.grow((runtimeHelpers.loadedMemorySnapshotSize! - wasmMemory.buffer.byteLength + 65535) >>> 16); runtimeHelpers.updateMemoryViews(); } catch (err) { mono_log_warn("failed to resize memory for the snapshot", err); - runtimeHelpers.loadedMemorySnapshot = false; + runtimeHelpers.loadedMemorySnapshotSize = undefined; } // now we know if the loading of memory succeeded or not, we can start loading the rest of the assets - runtimeHelpers.memorySnapshotSkippedOrDone.promise_control.resolve(); + loaderHelpers.memorySnapshotSkippedOrDone.promise_control.resolve(); } runtimeHelpers.afterInstantiateWasm.promise_control.resolve(); } catch (err) { @@ -507,9 +500,18 @@ async function instantiate_wasm_module( Module.removeRunDependency("instantiate_wasm_module"); } +async function ensureUsedWasmFeatures() { + if (linkerWasmEnableSIMD) { + mono_assert(await loaderHelpers.simd(), "This browser/engine doesn't support WASM SIMD. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); + } + if (linkerWasmEnableEH) { + mono_assert(await loaderHelpers.exceptions(), "This browser/engine doesn't support WASM exception handling. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); + } +} + async function mono_wasm_before_memory_snapshot() { const mark = startMeasure(); - if (runtimeHelpers.loadedMemorySnapshot) { + if (runtimeHelpers.loadedMemorySnapshotSize) { // get the bytes after we re-sized the memory, so that we don't have too much memory in use at the same time const memoryBytes = await getMemorySnapshot(); const heapU8 = localHeapViewU8(); diff --git a/src/mono/wasm/runtime/types/internal.ts b/src/mono/wasm/runtime/types/internal.ts index fcbeb03eadefdd..a91b21973c8d5e 100644 --- a/src/mono/wasm/runtime/types/internal.ts +++ b/src/mono/wasm/runtime/types/internal.ts @@ -128,6 +128,7 @@ export type LoaderHelpers = { allDownloadsQueued: PromiseAndController, wasmDownloadPromise: PromiseAndController, runtimeModuleLoaded: PromiseAndController, + memorySnapshotSkippedOrDone: PromiseAndController, is_exited: () => boolean, is_runtime_running: () => boolean, @@ -176,7 +177,7 @@ export type RuntimeHelpers = { mono_wasm_runtime_is_ready: boolean; mono_wasm_bindings_is_ready: boolean; - loadedMemorySnapshot: boolean, + loadedMemorySnapshotSize?: number, enablePerfMeasure: boolean; waitForDebugger?: number; ExitStatus: ExitStatusError; @@ -194,7 +195,6 @@ export type RuntimeHelpers = { allAssetsInMemory: PromiseAndController, dotnetReady: PromiseAndController, - memorySnapshotSkippedOrDone: PromiseAndController, afterInstantiateWasm: PromiseAndController, beforePreInit: PromiseAndController, afterPreInit: PromiseAndController, @@ -454,6 +454,8 @@ export declare interface EmscriptenModuleInternal { ready: Promise; asm: { memory?: WebAssembly.Memory }; wasmMemory?: WebAssembly.Memory; + getWasmIndirectFunctionTable: any; + getMemory: WebAssembly.Memory; getWasmTableEntry(index: number): any; removeRunDependency(id: string): void; addRunDependency(id: string): void; @@ -490,6 +492,7 @@ export type setGlobalObjectsType = (globalObjects: GlobalObjects) => void; export type initializeExportsType = (globalObjects: GlobalObjects) => RuntimeAPI; export type initializeReplacementsType = (replacements: EmscriptenReplacements) => void; export type configureEmscriptenStartupType = (module: DotnetModuleInternal) => void; +export type configureRuntimeStartupType = () => Promise; export type configureWorkerStartupType = (module: DotnetModuleInternal) => Promise @@ -497,6 +500,7 @@ export type RuntimeModuleExportsInternal = { setRuntimeGlobals: setGlobalObjectsType, initializeExports: initializeExportsType, initializeReplacements: initializeReplacementsType, + configureRuntimeStartup: configureRuntimeStartupType, configureEmscriptenStartup: configureEmscriptenStartupType, configureWorkerStartup: configureWorkerStartupType, passEmscriptenInternals: passEmscriptenInternalsType, diff --git a/src/mono/wasm/templates/templates/browser/README.md b/src/mono/wasm/templates/templates/browser/README.md deleted file mode 100644 index 7ddf4fd1bce9bd..00000000000000 --- a/src/mono/wasm/templates/templates/browser/README.md +++ /dev/null @@ -1,26 +0,0 @@ -## .NET WebAssembly Browser app - -## Build - -You can build the app from Visual Studio or from the command-line: - -``` -dotnet build -c Debug/Release -``` - -After building the app, the result is in the `bin/$(Configuration)/net7.0/browser-wasm/AppBundle` directory. - -## Run - -You can build the app from Visual Studio or the command-line: - -``` -dotnet run -c Debug/Release -``` - -Or you can start any static file server from the AppBundle directory: - -``` -dotnet tool install dotnet-serve -dotnet serve -d:bin/$(Configuration)/net7.0/browser-wasm/AppBundle -``` \ No newline at end of file diff --git a/src/mono/wasm/templates/templates/browser/browser.0.csproj b/src/mono/wasm/templates/templates/browser/browser.0.csproj index 401bdae24fab86..588c5219582125 100644 --- a/src/mono/wasm/templates/templates/browser/browser.0.csproj +++ b/src/mono/wasm/templates/templates/browser/browser.0.csproj @@ -1,13 +1,6 @@ - + - net7.0 - browser-wasm - Exe + net8.0 true - - - - - diff --git a/src/mono/wasm/templates/templates/browser/runtimeconfig.template.json b/src/mono/wasm/templates/templates/browser/runtimeconfig.template.json index 8f0557352c6ed3..b96a94320ba5ee 100644 --- a/src/mono/wasm/templates/templates/browser/runtimeconfig.template.json +++ b/src/mono/wasm/templates/templates/browser/runtimeconfig.template.json @@ -3,9 +3,8 @@ "perHostConfig": [ { "name": "browser", - "html-path": "index.html", - "Host": "browser" + "host": "browser" } ] } -} +} \ No newline at end of file diff --git a/src/mono/wasm/templates/templates/browser/index.html b/src/mono/wasm/templates/templates/browser/wwwroot/index.html similarity index 100% rename from src/mono/wasm/templates/templates/browser/index.html rename to src/mono/wasm/templates/templates/browser/wwwroot/index.html diff --git a/src/mono/wasm/templates/templates/browser/main.js b/src/mono/wasm/templates/templates/browser/wwwroot/main.js similarity index 100% rename from src/mono/wasm/templates/templates/browser/main.js rename to src/mono/wasm/templates/templates/browser/wwwroot/main.js diff --git a/src/mono/wasm/templates/templates/console/runtimeconfig.template.json b/src/mono/wasm/templates/templates/console/runtimeconfig.template.json index 49721faa0baa4d..a056e67f11cf9a 100644 --- a/src/mono/wasm/templates/templates/console/runtimeconfig.template.json +++ b/src/mono/wasm/templates/templates/console/runtimeconfig.template.json @@ -4,12 +4,8 @@ { "name": "node", "js-path": "main.mjs", - "Host": "nodejs", - "host-args": [ - "--experimental-wasm-simd", - "--experimental-wasm-eh" - ] + "host": "nodejs" } ] } -} +} \ No newline at end of file diff --git a/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj b/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj index 33858b9a6a755f..761ac6354ce861 100644 --- a/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj +++ b/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj @@ -4,6 +4,7 @@ browser-wasm Exe true + true diff --git a/src/mono/wasm/wasm.proj b/src/mono/wasm/wasm.proj index c46c92ab70db70..3b75fa9feeb1f0 100644 --- a/src/mono/wasm/wasm.proj +++ b/src/mono/wasm/wasm.proj @@ -33,7 +33,6 @@ <_EmccDefaultsRspPath>$(NativeBinDir)src\emcc-default.rsp <_EmccCompileRspPath>$(NativeBinDir)src\emcc-compile.rsp <_EmccLinkRspPath>$(NativeBinDir)src\emcc-link.rsp - false $(EMSDK_PATH)\upstream\bin\llvm-ar $(EmSdkLLVMAr).exe @@ -388,7 +387,8 @@ $(CMakeBuildRuntimeConfigureCmd) -DICU_LIB_DIR="$(ICULibDir.TrimEnd('\/').Replace('\','/'))" $(CMakeBuildRuntimeConfigureCmd) -DMONO_ARTIFACTS_DIR="$(MonoArtifactsPath.TrimEnd('\/').Replace('\','/'))" $(CMakeBuildRuntimeConfigureCmd) -DNATIVE_BIN_DIR="$(NativeBinDir.TrimEnd('\/').Replace('\','/'))" - $(CMakeBuildRuntimeConfigureCmd) -DCONFIGURATION_COMPILE_OPTIONS="-msimd128" + $(CMakeBuildRuntimeConfigureCmd) -DCONFIGURATION_COMPILE_OPTIONS="-msimd128" -DCONFIGURATION_INTERPSIMDTABLES_LIB="simd" + $(CMakeBuildRuntimeConfigureCmd) -DCONFIGURATION_INTERPSIMDTABLES_LIB="nosimd" $(CMakeBuildRuntimeConfigureCmd) -DDISABLE_THREADS=0 $(CMakeBuildRuntimeConfigureCmd) -DDISABLE_LEGACY_JS_INTEROP=1 $(CMakeBuildRuntimeConfigureCmd) $(CMakeConfigurationEmsdkPath) diff --git a/src/native/corehost/fxr/command_line.cpp b/src/native/corehost/fxr/command_line.cpp index 7f0a4d5a929594..3c85d3669ec7b3 100644 --- a/src/native/corehost/fxr/command_line.cpp +++ b/src/native/corehost/fxr/command_line.cpp @@ -280,17 +280,18 @@ int command_line::parse_args_for_sdk_command( return parse_args(host_info, 1, argc, argv, false, host_mode_t::muxer, new_argoff, app_candidate, opts); } -void command_line::print_muxer_info(const pal::string_t &dotnet_root, const pal::string_t &global_json_path) +void command_line::print_muxer_info(const pal::string_t &dotnet_root, const pal::string_t &global_json_path, bool skip_sdk_info_output) { pal::string_t commit = _STRINGIFY(REPO_COMMIT_HASH); trace::println(_X("\n") _X("Host:\n") _X(" Version: ") _STRINGIFY(HOST_VERSION) _X("\n") _X(" Architecture: ") _STRINGIFY(CURRENT_ARCH_NAME) _X("\n") - _X(" Commit: %s\n") - _X(" RID: %s"), - commit.substr(0, 10).c_str(), - get_runtime_id().c_str()); + _X(" Commit: %s"), + commit.substr(0, 10).c_str()); + + if (!skip_sdk_info_output) + trace::println(_X(" RID: %s"), get_runtime_id().c_str()); trace::println(_X("\n") _X(".NET SDKs installed:")); diff --git a/src/native/corehost/fxr/command_line.h b/src/native/corehost/fxr/command_line.h index 4880ea60f0e45c..935e66800870e5 100644 --- a/src/native/corehost/fxr/command_line.h +++ b/src/native/corehost/fxr/command_line.h @@ -57,7 +57,9 @@ namespace command_line /*out*/ pal::string_t &app_candidate, /*out*/ opt_map_t &opts); - void print_muxer_info(const pal::string_t &dotnet_root, const pal::string_t &global_json_path); + // skip_sdk_info_output indicates whether or not to skip any information that the SDK would have + // already printed. Related: https://github.com/dotnet/sdk/issues/33697 + void print_muxer_info(const pal::string_t &dotnet_root, const pal::string_t &global_json_path, bool skip_sdk_info_output); void print_muxer_usage(bool is_sdk_present); }; diff --git a/src/native/corehost/fxr/fx_muxer.cpp b/src/native/corehost/fxr/fx_muxer.cpp index 53e454c9803052..cbaf90aa69cba2 100644 --- a/src/native/corehost/fxr/fx_muxer.cpp +++ b/src/native/corehost/fxr/fx_muxer.cpp @@ -1055,7 +1055,7 @@ int fx_muxer_t::handle_cli( } else if (pal::strcasecmp(_X("--info"), argv[1]) == 0) { - command_line::print_muxer_info(host_info.dotnet_root, resolver.global_file_path()); + command_line::print_muxer_info(host_info.dotnet_root, resolver.global_file_path(), false /*skip_sdk_info_output*/); return StatusCode::Success; } @@ -1107,7 +1107,7 @@ int fx_muxer_t::handle_cli( if (pal::strcasecmp(_X("--info"), argv[1]) == 0) { - command_line::print_muxer_info(host_info.dotnet_root, resolver.global_file_path()); + command_line::print_muxer_info(host_info.dotnet_root, resolver.global_file_path(), result == 0 /*skip_sdk_info_output*/); } return result; diff --git a/src/native/corehost/fxr/fx_ver.cpp b/src/native/corehost/fxr/fx_ver.cpp index 254f408effe9b8..7a857cdff473b0 100644 --- a/src/native/corehost/fxr/fx_ver.cpp +++ b/src/native/corehost/fxr/fx_ver.cpp @@ -70,31 +70,18 @@ bool fx_ver_t::operator >=(const fx_ver_t& b) const pal::string_t fx_ver_t::as_str() const { - pal::stringstream_t stream; - stream << m_major << _X(".") << m_minor << _X(".") << m_patch; + pal::string_t version = pal::to_string(m_major); + version += _X('.'); + version += pal::to_string(m_minor); + version += _X('.'); + version += pal::to_string(m_patch); if (!m_pre.empty()) - { - stream << m_pre; - } - if (!m_build.empty()) - { - stream << m_build; - } - return stream.str(); -} + version += m_pre; -pal::string_t fx_ver_t::prerelease_glob() const -{ - pal::stringstream_t stream; - stream << m_major << _X(".") << m_minor << _X(".") << m_patch << _X("-*"); - return stream.str(); -} + if (!m_build.empty()) + version += m_build; -pal::string_t fx_ver_t::patch_glob() const -{ - pal::stringstream_t stream; - stream << m_major << _X(".") << m_minor << _X(".*"); - return stream.str(); + return version; } static pal::string_t getId(const pal::string_t &ids, size_t idStart) diff --git a/src/native/corehost/fxr/fx_ver.h b/src/native/corehost/fxr/fx_ver.h index 5f5348386897b0..29f010876e04ae 100644 --- a/src/native/corehost/fxr/fx_ver.h +++ b/src/native/corehost/fxr/fx_ver.h @@ -26,8 +26,6 @@ struct fx_ver_t bool is_empty() const { return m_major == -1; } pal::string_t as_str() const; - pal::string_t prerelease_glob() const; - pal::string_t patch_glob() const; bool operator ==(const fx_ver_t& b) const; bool operator !=(const fx_ver_t& b) const; diff --git a/src/native/corehost/hostmisc/pal.unix.cpp b/src/native/corehost/hostmisc/pal.unix.cpp index 46ffaf951adfb9..34520aefd7365a 100644 --- a/src/native/corehost/hostmisc/pal.unix.cpp +++ b/src/native/corehost/hostmisc/pal.unix.cpp @@ -16,7 +16,6 @@ #include #include #include -#include #include #include "config.h" #include diff --git a/src/native/corehost/hostmisc/pal.windows.cpp b/src/native/corehost/hostmisc/pal.windows.cpp index 98e4efbed72f07..b11610492d3214 100644 --- a/src/native/corehost/hostmisc/pal.windows.cpp +++ b/src/native/corehost/hostmisc/pal.windows.cpp @@ -7,7 +7,6 @@ #include "longfile.h" #include -#include #include #include diff --git a/src/native/corehost/hostpolicy/version.cpp b/src/native/corehost/hostpolicy/version.cpp index ea643605ac88f2..c08316fe4b5a87 100644 --- a/src/native/corehost/hostpolicy/version.cpp +++ b/src/native/corehost/hostpolicy/version.cpp @@ -51,29 +51,31 @@ bool version_t::operator >=(const version_t& b) const pal::string_t version_t::as_str() const { - pal::stringstream_t stream; - + pal::string_t version; if (m_major >= 0) { - stream << m_major; + version += pal::to_string(m_major); if (m_minor >= 0) { - stream << _X(".") << m_minor; + version += _X('.'); + version += pal::to_string(m_minor); if (m_build >= 0) { - stream << _X(".") << m_build; + version += _X('.'); + version += pal::to_string(m_build); if (m_revision >= 0) { - stream << _X(".") << m_revision; + version += _X('.'); + version += pal::to_string(m_revision); } } } } - return stream.str(); + return version; } /*static*/ int version_t::compare(const version_t&a, const version_t& b) diff --git a/src/native/eventpipe/ep-session.c b/src/native/eventpipe/ep-session.c index 959214d1375963..ee134be65059d7 100644 --- a/src/native/eventpipe/ep-session.c +++ b/src/native/eventpipe/ep-session.c @@ -329,24 +329,16 @@ ep_session_enable_rundown (EventPipeSession *session) const uint64_t keywords = 0x80020139; const EventPipeEventLevel verbose_logging_level = EP_EVENT_LEVEL_VERBOSE; - EventPipeProviderConfiguration rundown_providers [2]; - uint32_t rundown_providers_len = (uint32_t)ARRAY_SIZE (rundown_providers); + EventPipeProviderConfiguration rundown_provider; + ep_provider_config_init (&rundown_provider, ep_config_get_rundown_provider_name_utf8 (), keywords, verbose_logging_level, NULL); // Rundown provider. - ep_provider_config_init (&rundown_providers [0], ep_config_get_public_provider_name_utf8 (), keywords, verbose_logging_level, NULL); // Public provider. - ep_provider_config_init (&rundown_providers [1], ep_config_get_rundown_provider_name_utf8 (), keywords, verbose_logging_level, NULL); // Rundown provider. + EventPipeSessionProvider *session_provider = ep_session_provider_alloc ( + ep_provider_config_get_provider_name (&rundown_provider), + ep_provider_config_get_keywords (&rundown_provider), + ep_provider_config_get_logging_level (&rundown_provider), + ep_provider_config_get_filter_data (&rundown_provider)); - // Update provider list with rundown configuration. - for (uint32_t i = 0; i < rundown_providers_len; ++i) { - const EventPipeProviderConfiguration *config = &rundown_providers [i]; - - EventPipeSessionProvider *session_provider = ep_session_provider_alloc ( - ep_provider_config_get_provider_name (config), - ep_provider_config_get_keywords (config), - ep_provider_config_get_logging_level (config), - ep_provider_config_get_filter_data (config)); - - ep_raise_error_if_nok (ep_session_add_session_provider (session, session_provider)); - } + ep_raise_error_if_nok (ep_session_add_session_provider (session, session_provider)); ep_session_set_rundown_enabled (session, true); result = true; diff --git a/src/native/external/patches/zlib-intel/0001-Make-zlib-intel-compile-clean-against-C4244.patch b/src/native/external/patches/zlib-intel/0001-Make-zlib-intel-compile-clean-against-C4244.patch new file mode 100644 index 00000000000000..1ecb02be92ec05 --- /dev/null +++ b/src/native/external/patches/zlib-intel/0001-Make-zlib-intel-compile-clean-against-C4244.patch @@ -0,0 +1,75 @@ +From edabaf799fd071a328e0adb743a98628df6649f0 Mon Sep 17 00:00:00 2001 +From: Levi Broderick +Date: Mon, 28 Aug 2023 15:26:38 -0700 +Subject: [PATCH] Make zlib-intel compile clean against C4244 clang equivalent + is "implicit-int-conversion" warning + +The change to deflate.c is legal because 'len' has an upper bound of +MAX_STORED, which means it fits cleanly into a 16-bit integer. So +writing out 2x 8-bit values will not result in data loss. + +The change to trees.c is legal because within this loop, 'count' is +intended to have an upper bound of 138, with the target assignment +only executing if 'count' is bounded by 4. Neither the 'count' local +in isolation nor the addition that's part of the target line is +expected to result in integer overflow. But even if it did, that's a +matter for a different warning code and doesn't impact the correctness +of the narrowing cast being considered here. + +The change to slide_sse.c is legal because 'w_size' is limited to +1 << 15 (see deflateInit2_ in deflate.c), so this fits cleanly into +a 16-bit value. +--- + src/native/external/zlib-intel/deflate.c | 8 ++++---- + src/native/external/zlib-intel/slide_sse.c | 2 +- + src/native/external/zlib-intel/trees.c | 2 +- + 3 files changed, 6 insertions(+), 6 deletions(-) + +diff --git a/src/native/external/zlib-intel/deflate.c b/src/native/external/zlib-intel/deflate.c +index bd5e95774a6..108b1a187af 100644 +--- a/src/native/external/zlib-intel/deflate.c ++++ b/src/native/external/zlib-intel/deflate.c +@@ -1553,10 +1553,10 @@ local block_state deflate_stored(s, flush) + _tr_stored_block(s, (char *)0, 0L, last); + + /* Replace the lengths in the dummy stored block with len. */ +- s->pending_buf[s->pending - 4] = len; +- s->pending_buf[s->pending - 3] = len >> 8; +- s->pending_buf[s->pending - 2] = ~len; +- s->pending_buf[s->pending - 1] = ~len >> 8; ++ s->pending_buf[s->pending - 4] = (Bytef)len; ++ s->pending_buf[s->pending - 3] = (Bytef)(len >> 8); ++ s->pending_buf[s->pending - 2] = (Bytef)~len; ++ s->pending_buf[s->pending - 1] = (Bytef)(~len >> 8); + + /* Write the stored block header bytes. */ + flush_pending(s->strm); +diff --git a/src/native/external/zlib-intel/slide_sse.c b/src/native/external/zlib-intel/slide_sse.c +index 342fd562dd1..eb74202c5a0 100644 +--- a/src/native/external/zlib-intel/slide_sse.c ++++ b/src/native/external/zlib-intel/slide_sse.c +@@ -18,7 +18,7 @@ void slide_hash_sse(deflate_state *s) + unsigned n; + Posf *p; + uInt wsize = s->w_size; +- z_const __m128i xmm_wsize = _mm_set1_epi16(s->w_size); ++ z_const __m128i xmm_wsize = _mm_set1_epi16((short)s->w_size); + + n = s->hash_size; + p = &s->head[n] - 8; +diff --git a/src/native/external/zlib-intel/trees.c b/src/native/external/zlib-intel/trees.c +index 35462a1313a..f78b7d8c63e 100644 +--- a/src/native/external/zlib-intel/trees.c ++++ b/src/native/external/zlib-intel/trees.c +@@ -717,7 +717,7 @@ local void scan_tree(s, tree, max_code) + if (++count < max_count && curlen == nextlen) { + continue; + } else if (count < min_count) { +- s->bl_tree[curlen].Freq += count; ++ s->bl_tree[curlen].Freq += (ush)count; + } else if (curlen != 0) { + if (curlen != prevlen) s->bl_tree[curlen].Freq++; + s->bl_tree[REP_3_6].Freq++; +-- +2.42.0.windows.1 + diff --git a/src/native/external/patches/zlib/0001-Make-zlib-compile-clean-against-C4244.patch b/src/native/external/patches/zlib/0001-Make-zlib-compile-clean-against-C4244.patch new file mode 100644 index 00000000000000..c2a26b3202b196 --- /dev/null +++ b/src/native/external/patches/zlib/0001-Make-zlib-compile-clean-against-C4244.patch @@ -0,0 +1,57 @@ +From 86d96652ddd60f61dc7b0c94b601f6d156d34632 Mon Sep 17 00:00:00 2001 +From: Levi Broderick +Date: Mon, 28 Aug 2023 15:26:38 -0700 +Subject: [PATCH] Make zlib compile clean against C4244 clang equivalent is + "implicit-int-conversion" warning + +The change to deflate.c is legal because 'len' has an upper bound of +MAX_STORED, which means it fits cleanly into a 16-bit integer. So +writing out 2x 8-bit values will not result in data loss. + +The change to trees.c is legal because within this loop, 'count' is +intended to have an upper bound of 138, with the target assignment +only executing if 'count' is bounded by 4. Neither the 'count' local +in isolation nor the addition that's part of the target line is +expected to result in integer overflow. But even if it did, that's a +matter for a different warning code and doesn't impact the correctness +of the narrowing cast being considered here. +--- + src/native/external/zlib/deflate.c | 8 ++++---- + src/native/external/zlib/trees.c | 2 +- + 2 files changed, 5 insertions(+), 5 deletions(-) + +diff --git a/src/native/external/zlib/deflate.c b/src/native/external/zlib/deflate.c +index d2e1106ef5d..b7636639754 100644 +--- a/src/native/external/zlib/deflate.c ++++ b/src/native/external/zlib/deflate.c +@@ -1738,10 +1738,10 @@ local block_state deflate_stored(s, flush) + _tr_stored_block(s, (char *)0, 0L, last); + + /* Replace the lengths in the dummy stored block with len. */ +- s->pending_buf[s->pending - 4] = len; +- s->pending_buf[s->pending - 3] = len >> 8; +- s->pending_buf[s->pending - 2] = ~len; +- s->pending_buf[s->pending - 1] = ~len >> 8; ++ s->pending_buf[s->pending - 4] = (Bytef)len; ++ s->pending_buf[s->pending - 3] = (Bytef)(len >> 8); ++ s->pending_buf[s->pending - 2] = (Bytef)~len; ++ s->pending_buf[s->pending - 1] = (Bytef)(~len >> 8); + + /* Write the stored block header bytes. */ + flush_pending(s->strm); +diff --git a/src/native/external/zlib/trees.c b/src/native/external/zlib/trees.c +index 5f305c47221..8a3eec559e5 100644 +--- a/src/native/external/zlib/trees.c ++++ b/src/native/external/zlib/trees.c +@@ -721,7 +721,7 @@ local void scan_tree(s, tree, max_code) + if (++count < max_count && curlen == nextlen) { + continue; + } else if (count < min_count) { +- s->bl_tree[curlen].Freq += count; ++ s->bl_tree[curlen].Freq += (ush)count; + } else if (curlen != 0) { + if (curlen != prevlen) s->bl_tree[curlen].Freq++; + s->bl_tree[REP_3_6].Freq++; +-- +2.42.0.windows.1 + diff --git a/src/native/external/zlib-intel-version.txt b/src/native/external/zlib-intel-version.txt index d406ffbcc459b3..287c9e92512bf1 100644 --- a/src/native/external/zlib-intel-version.txt +++ b/src/native/external/zlib-intel-version.txt @@ -12,3 +12,5 @@ copied into our repo in order to build. Since then, we've just updated only those files in-place, ignoring other files in the intel/zlib repo. If new files are introduced which are necessary for building the product, feel free to bring those down as well. + +We have also applied the custom patches under the patches/zlib-intel folder. diff --git a/src/native/external/zlib-intel.cmake b/src/native/external/zlib-intel.cmake index 1b6fa0cb4765bf..e8f198271d865b 100644 --- a/src/native/external/zlib-intel.cmake +++ b/src/native/external/zlib-intel.cmake @@ -1,9 +1,3 @@ -if(MSVC) - add_compile_options($<$:/wd4244>) # conversion from 'type1' to 'type2', possible loss of data -else(CMAKE_C_COMPILER_ID MATCHES "Clang") - add_compile_options($<$:-Wno-implicit-int-conversion>) -endif() - set(ZLIB_SOURCES_BASE adler32.c compress.c diff --git a/src/native/external/zlib-intel/deflate.c b/src/native/external/zlib-intel/deflate.c index bd5e95774a689a..108b1a187af4d3 100644 --- a/src/native/external/zlib-intel/deflate.c +++ b/src/native/external/zlib-intel/deflate.c @@ -1553,10 +1553,10 @@ local block_state deflate_stored(s, flush) _tr_stored_block(s, (char *)0, 0L, last); /* Replace the lengths in the dummy stored block with len. */ - s->pending_buf[s->pending - 4] = len; - s->pending_buf[s->pending - 3] = len >> 8; - s->pending_buf[s->pending - 2] = ~len; - s->pending_buf[s->pending - 1] = ~len >> 8; + s->pending_buf[s->pending - 4] = (Bytef)len; + s->pending_buf[s->pending - 3] = (Bytef)(len >> 8); + s->pending_buf[s->pending - 2] = (Bytef)~len; + s->pending_buf[s->pending - 1] = (Bytef)(~len >> 8); /* Write the stored block header bytes. */ flush_pending(s->strm); diff --git a/src/native/external/zlib-intel/slide_sse.c b/src/native/external/zlib-intel/slide_sse.c index 342fd562dd1152..eb74202c5a04a8 100644 --- a/src/native/external/zlib-intel/slide_sse.c +++ b/src/native/external/zlib-intel/slide_sse.c @@ -18,7 +18,7 @@ void slide_hash_sse(deflate_state *s) unsigned n; Posf *p; uInt wsize = s->w_size; - z_const __m128i xmm_wsize = _mm_set1_epi16(s->w_size); + z_const __m128i xmm_wsize = _mm_set1_epi16((short)s->w_size); n = s->hash_size; p = &s->head[n] - 8; diff --git a/src/native/external/zlib-intel/trees.c b/src/native/external/zlib-intel/trees.c index 35462a1313aa83..f78b7d8c63eaed 100644 --- a/src/native/external/zlib-intel/trees.c +++ b/src/native/external/zlib-intel/trees.c @@ -717,7 +717,7 @@ local void scan_tree(s, tree, max_code) if (++count < max_count && curlen == nextlen) { continue; } else if (count < min_count) { - s->bl_tree[curlen].Freq += count; + s->bl_tree[curlen].Freq += (ush)count; } else if (curlen != 0) { if (curlen != prevlen) s->bl_tree[curlen].Freq++; s->bl_tree[REP_3_6].Freq++; diff --git a/src/native/external/zlib-version.txt b/src/native/external/zlib-version.txt index 00ce7fdbb2cd32..fcac66cc4645f4 100644 --- a/src/native/external/zlib-version.txt +++ b/src/native/external/zlib-version.txt @@ -13,3 +13,5 @@ We have also cherry-picked into our local copy: This patch only affects memLevel 9 compression. .NET doesn't currently use this memLevel, but we'll take this patch out of an abundance of caution just in case we enable this functionality in a future release. + +We have also applied the custom patches under the patches/zlib folder. diff --git a/src/native/external/zlib.cmake b/src/native/external/zlib.cmake index 862e10118cd72c..730bfc4bd14020 100644 --- a/src/native/external/zlib.cmake +++ b/src/native/external/zlib.cmake @@ -1,9 +1,3 @@ -if(MSVC) - add_compile_options($<$:/wd4244>) # conversion from 'type1' to 'type2', possible loss of data -else(CMAKE_C_COMPILER_ID MATCHES "Clang") - add_compile_options($<$:-Wno-implicit-int-conversion>) -endif() - set(ZLIB_SOURCES_BASE adler32.c compress.c diff --git a/src/native/external/zlib/deflate.c b/src/native/external/zlib/deflate.c index d2e1106ef5d07d..b763663975458c 100644 --- a/src/native/external/zlib/deflate.c +++ b/src/native/external/zlib/deflate.c @@ -1738,10 +1738,10 @@ local block_state deflate_stored(s, flush) _tr_stored_block(s, (char *)0, 0L, last); /* Replace the lengths in the dummy stored block with len. */ - s->pending_buf[s->pending - 4] = len; - s->pending_buf[s->pending - 3] = len >> 8; - s->pending_buf[s->pending - 2] = ~len; - s->pending_buf[s->pending - 1] = ~len >> 8; + s->pending_buf[s->pending - 4] = (Bytef)len; + s->pending_buf[s->pending - 3] = (Bytef)(len >> 8); + s->pending_buf[s->pending - 2] = (Bytef)~len; + s->pending_buf[s->pending - 1] = (Bytef)(~len >> 8); /* Write the stored block header bytes. */ flush_pending(s->strm); diff --git a/src/native/external/zlib/trees.c b/src/native/external/zlib/trees.c index 5f305c47221e90..8a3eec559e55bc 100644 --- a/src/native/external/zlib/trees.c +++ b/src/native/external/zlib/trees.c @@ -721,7 +721,7 @@ local void scan_tree(s, tree, max_code) if (++count < max_count && curlen == nextlen) { continue; } else if (count < min_count) { - s->bl_tree[curlen].Freq += count; + s->bl_tree[curlen].Freq += (ush)count; } else if (curlen != 0) { if (curlen != prevlen) s->bl_tree[curlen].Freq++; s->bl_tree[REP_3_6].Freq++; diff --git a/src/native/libs/System.Globalization.Native/entrypoints.c b/src/native/libs/System.Globalization.Native/entrypoints.c index cffad72a023721..84d2177d558841 100644 --- a/src/native/libs/System.Globalization.Native/entrypoints.c +++ b/src/native/libs/System.Globalization.Native/entrypoints.c @@ -69,6 +69,7 @@ static const Entry s_globalizationNative[] = DllImportEntry(GlobalizationNative_GetLocaleInfoSecondaryGroupingSizeNative) DllImportEntry(GlobalizationNative_GetLocaleInfoStringNative) DllImportEntry(GlobalizationNative_GetLocaleNameNative) + DllImportEntry(GlobalizationNative_GetLocalesNative) DllImportEntry(GlobalizationNative_GetLocaleTimeFormatNative) DllImportEntry(GlobalizationNative_IndexOfNative) DllImportEntry(GlobalizationNative_StartsWithNative) diff --git a/src/native/libs/System.Globalization.Native/pal_locale.h b/src/native/libs/System.Globalization.Native/pal_locale.h index 7fe89f667f2132..4a1fe0768e4fda 100644 --- a/src/native/libs/System.Globalization.Native/pal_locale.h +++ b/src/native/libs/System.Globalization.Native/pal_locale.h @@ -21,4 +21,6 @@ PALEXPORT int32_t GlobalizationNative_GetLocaleTimeFormat(const UChar* localeNam PALEXPORT const char* GlobalizationNative_GetLocaleNameNative(const char* localeName); PALEXPORT const char* GlobalizationNative_GetLocaleTimeFormatNative(const char* localeName, int shortFormat); + +PALEXPORT int32_t GlobalizationNative_GetLocalesNative(UChar* locales, int32_t length); #endif diff --git a/src/native/libs/System.Globalization.Native/pal_locale.m b/src/native/libs/System.Globalization.Native/pal_locale.m index d8ab7da1fbee0c..4789ac89691da2 100644 --- a/src/native/libs/System.Globalization.Native/pal_locale.m +++ b/src/native/libs/System.Globalization.Native/pal_locale.m @@ -97,7 +97,7 @@ static void GetParent(const char* localeID, char* parent, int32_t parentCapacity { @autoreleasepool { - const char* value; + NSString *value; NSString *locName = [NSString stringWithFormat:@"%s", localeName]; NSLocale *currentLocale = [[NSLocale alloc] initWithLocaleIdentifier:locName]; NSNumberFormatter *numberFormatter = [[NSNumberFormatter alloc] init]; @@ -112,35 +112,35 @@ static void GetParent(const char* localeID, char* parent, int32_t parentCapacity case LocaleString_LocalizedDisplayName: /// Display name (language + country usually) in English, eg "German (Germany)" (corresponds to LOCALE_SENGLISHDISPLAYNAME) case LocaleString_EnglishDisplayName: - value = [[gbLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier] UTF8String]; - break; + value = [gbLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier]; + break; /// Display name in native locale language, eg "Deutsch (Deutschland) (corresponds to LOCALE_SNATIVEDISPLAYNAME) case LocaleString_NativeDisplayName: - value = [[currentLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier] UTF8String]; + value = [currentLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier]; break; /// Language Display Name for a language, eg "German" in UI language (corresponds to LOCALE_SLOCALIZEDLANGUAGENAME) case LocaleString_LocalizedLanguageName: /// English name of language, eg "German" (corresponds to LOCALE_SENGLISHLANGUAGENAME) case LocaleString_EnglishLanguageName: - value = [[gbLocale localizedStringForLanguageCode:currentLocale.languageCode] UTF8String]; + value = [gbLocale localizedStringForLanguageCode:currentLocale.languageCode]; break; /// native name of language, eg "Deutsch" (corresponds to LOCALE_SNATIVELANGUAGENAME) case LocaleString_NativeLanguageName: - value = [[currentLocale localizedStringForLanguageCode:currentLocale.languageCode] UTF8String]; + value = [currentLocale localizedStringForLanguageCode:currentLocale.languageCode]; break; /// English name of country, eg "Germany" (corresponds to LOCALE_SENGLISHCOUNTRYNAME) case LocaleString_EnglishCountryName: - value = [[gbLocale localizedStringForCountryCode:currentLocale.countryCode] UTF8String]; + value = [gbLocale localizedStringForCountryCode:currentLocale.countryCode]; break; /// native name of country, eg "Deutschland" (corresponds to LOCALE_SNATIVECOUNTRYNAME) case LocaleString_NativeCountryName: - value = [[currentLocale localizedStringForCountryCode:currentLocale.countryCode] UTF8String]; + value = [currentLocale localizedStringForCountryCode:currentLocale.countryCode]; break; case LocaleString_ThousandSeparator: - value = [currentLocale.groupingSeparator UTF8String]; + value = currentLocale.groupingSeparator; break; case LocaleString_DecimalSeparator: - value = [currentLocale.decimalSeparator UTF8String]; + value = currentLocale.decimalSeparator; // or value = [[currentLocale objectForKey:NSLocaleDecimalSeparator] UTF8String]; break; case LocaleString_Digits: @@ -150,87 +150,84 @@ static void GetParent(const char* localeID, char* parent, int32_t parentCapacity [nf1 setLocale:currentLocale]; NSNumber *newNum = [nf1 numberFromString:digitsString]; - value = [[newNum stringValue] UTF8String]; + value = [newNum stringValue]; break; } case LocaleString_MonetarySymbol: - value = [currentLocale.currencySymbol UTF8String]; + value = currentLocale.currencySymbol; break; case LocaleString_Iso4217MonetarySymbol: // check if this is correct, check currencyISOCode - value = [currentLocale.currencySymbol UTF8String]; + value = currentLocale.currencyCode; break; case LocaleString_CurrencyEnglishName: - value = [[gbLocale localizedStringForCurrencyCode:currentLocale.currencyCode] UTF8String]; + value = [gbLocale localizedStringForCurrencyCode:currentLocale.currencyCode]; break; case LocaleString_CurrencyNativeName: - value = [[currentLocale localizedStringForCurrencyCode:currentLocale.currencyCode] UTF8String]; + value = [currentLocale localizedStringForCurrencyCode:currentLocale.currencyCode]; break; case LocaleString_MonetaryDecimalSeparator: - value = [numberFormatter.currencyDecimalSeparator UTF8String]; + value = numberFormatter.currencyDecimalSeparator; break; case LocaleString_MonetaryThousandSeparator: - value = [numberFormatter.currencyGroupingSeparator UTF8String]; + value = numberFormatter.currencyGroupingSeparator; break; case LocaleString_AMDesignator: - value = [dateFormatter.AMSymbol UTF8String]; + value = dateFormatter.AMSymbol; break; case LocaleString_PMDesignator: - value = [dateFormatter.PMSymbol UTF8String]; + value = dateFormatter.PMSymbol; break; case LocaleString_PositiveSign: - value = [numberFormatter.plusSign UTF8String]; + value = numberFormatter.plusSign; break; case LocaleString_NegativeSign: - value = [numberFormatter.minusSign UTF8String]; + value = numberFormatter.minusSign; break; case LocaleString_Iso639LanguageTwoLetterName: - value = [[currentLocale objectForKey:NSLocaleLanguageCode] UTF8String]; + value = [currentLocale objectForKey:NSLocaleLanguageCode]; break; case LocaleString_Iso639LanguageThreeLetterName: { NSString *iso639_2 = [currentLocale objectForKey:NSLocaleLanguageCode]; - value = uloc_getISO3LanguageByLangCode([iso639_2 UTF8String]); - break; + return iso639_2 == nil ? strdup("") : strdup(uloc_getISO3LanguageByLangCode([iso639_2 UTF8String])); } case LocaleString_Iso3166CountryName: - value = [[currentLocale objectForKey:NSLocaleCountryCode] UTF8String]; + value = [currentLocale objectForKey:NSLocaleCountryCode]; break; case LocaleString_Iso3166CountryName2: { - const char *countryCode = strdup([[currentLocale objectForKey:NSLocaleCountryCode] UTF8String]); - value = uloc_getISO3CountryByCountryCode(countryCode); - break; + NSString* countryCode = [currentLocale objectForKey:NSLocaleCountryCode]; + return countryCode == nil ? strdup("") : strdup(uloc_getISO3CountryByCountryCode([countryCode UTF8String])); } case LocaleString_NaNSymbol: - value = [numberFormatter.notANumberSymbol UTF8String]; + value = numberFormatter.notANumberSymbol; break; case LocaleString_PositiveInfinitySymbol: - value = [numberFormatter.positiveInfinitySymbol UTF8String]; + value = numberFormatter.positiveInfinitySymbol; break; case LocaleString_NegativeInfinitySymbol: - value = [numberFormatter.negativeInfinitySymbol UTF8String]; + value = numberFormatter.negativeInfinitySymbol; break; case LocaleString_PercentSymbol: - value = [numberFormatter.percentSymbol UTF8String]; + value = numberFormatter.percentSymbol; break; case LocaleString_PerMilleSymbol: - value = [numberFormatter.perMillSymbol UTF8String]; + value = numberFormatter.perMillSymbol; break; case LocaleString_ParentName: { char localeNameTemp[FULLNAME_CAPACITY]; const char* lName = [currentLocale.localeIdentifier UTF8String]; GetParent(lName, localeNameTemp, FULLNAME_CAPACITY); - value = strdup(localeNameTemp); - break; + return strdup(localeNameTemp); } default: - value = ""; + value = nil; break; } - return value ? strdup(value) : ""; + return value == nil ? strdup("") : strdup([value UTF8String]); } } @@ -667,6 +664,54 @@ Returns time format information (in native format, it needs to be converted to . } } +// GlobalizationNative_GetLocalesNative gets all locale names and store it in the value buffer +// in case of success, it returns the count of the characters stored in value buffer +// in case of failure, it returns negative number. +// if the input value buffer is null, it returns the length needed to store the +// locale names list. +// if the value is not null, it fills the value with locale names separated by the length +// of each name. +int32_t GlobalizationNative_GetLocalesNative(UChar* value, int32_t length) +{ + @autoreleasepool + { + NSArray* availableLocaleIdentifiers = [NSLocale availableLocaleIdentifiers]; + int32_t index = 0; + int32_t totalLength = 0; + int32_t availableLength = (int32_t)[availableLocaleIdentifiers count]; + + if (availableLength <= 0) + return -1; // failed + + for (NSInteger i = 0; i < availableLength; i++) + { + NSString *localeIdentifier = availableLocaleIdentifiers[i]; + int32_t localeNameLength = localeIdentifier.length; + totalLength += localeNameLength + 1; // add 1 for the name length + if (value != NULL) + { + if (totalLength > length) + return -3; + + value[index++] = (UChar) localeNameLength; + + for (int j = 0; j < localeNameLength; j++) + { + if ((UChar)[localeIdentifier characterAtIndex:j] == '_') + { + value[index++] = (UChar) '-'; + } + else + { + value[index++] = (UChar) [localeIdentifier characterAtIndex:j]; + } + } + } + } + return totalLength; + } +} + #endif #if defined(TARGET_MACCATALYST) || defined(TARGET_IOS) || defined(TARGET_TVOS) diff --git a/src/native/minipal/cpufeatures.c b/src/native/minipal/cpufeatures.c index f637e58c10f14e..be0aabe081c9fc 100644 --- a/src/native/minipal/cpufeatures.c +++ b/src/native/minipal/cpufeatures.c @@ -8,11 +8,11 @@ #include "cpufeatures.h" #include "cpuid.h" -#if TARGET_WINDOWS +#if HOST_WINDOWS #include -#else // TARGET_WINDOWS +#else // HOST_WINDOWS #include "minipalconfig.h" @@ -38,10 +38,10 @@ #include #endif -#endif // !TARGET_WINDOWS +#endif // !HOST_WINDOWS -#if defined(TARGET_UNIX) -#if defined(TARGET_X86) || defined(TARGET_AMD64) +#if defined(HOST_UNIX) +#if defined(HOST_X86) || defined(HOST_AMD64) static uint32_t xmmYmmStateSupport() { @@ -61,7 +61,7 @@ static uint32_t xmmYmmStateSupport() static uint32_t avx512StateSupport() { -#if defined(TARGET_APPLE) +#if defined(HOST_APPLE) // MacOS has specialized behavior where it reports AVX512 support but doesnt // actually enable AVX512 until the first instruction is executed and does so // on a per thread basis. It does this by catching the faulting instruction and @@ -95,11 +95,11 @@ static bool IsAvx512Enabled() { return true; } -#endif // defined(TARGET_X86) || defined(TARGET_AMD64) -#endif // TARGET_UNIX +#endif // defined(HOST_X86) || defined(HOST_AMD64) +#endif // HOST_UNIX -#if defined(TARGET_WINDOWS) -#if defined(TARGET_X86) || defined(TARGET_AMD64) +#if defined(HOST_WINDOWS) +#if defined(HOST_X86) || defined(HOST_AMD64) static uint32_t xmmYmmStateSupport() { // check OS has enabled both XMM and YMM state support @@ -124,14 +124,14 @@ static bool IsAvx512Enabled() return ((FeatureMask & XSTATE_MASK_AVX512) != 0); } -#endif // defined(TARGET_X86) || defined(TARGET_AMD64) -#endif // TARGET_WINDOWS +#endif // defined(HOST_X86) || defined(HOST_AMD64) +#endif // HOST_WINDOWS int minipal_getcpufeatures(void) { int result = 0; -#if defined(TARGET_X86) || defined(TARGET_AMD64) +#if defined(HOST_X86) || defined(HOST_AMD64) int cpuidInfo[4]; @@ -315,10 +315,10 @@ int minipal_getcpufeatures(void) } } -#endif // TARGET_X86 || TARGET_AMD64 +#endif // HOST_X86 || HOST_AMD64 -#if defined(TARGET_ARM64) -#if defined(TARGET_UNIX) +#if defined(HOST_ARM64) +#if defined(HOST_UNIX) #if HAVE_AUXV_HWCAP_H unsigned long hwCap = getauxval(AT_HWCAP); @@ -386,9 +386,9 @@ int minipal_getcpufeatures(void) result |= ARM64IntrinsicConstants_AdvSimd | ARM64IntrinsicConstants_VectorT128; #endif // HAVE_AUXV_HWCAP_H -#endif // TARGET_UNIX +#endif // HOST_UNIX -#if defined(TARGET_WINDOWS) +#if defined(HOST_WINDOWS) // FP and SIMD support are enabled by default result |= ARM64IntrinsicConstants_AdvSimd | ARM64IntrinsicConstants_VectorT128; @@ -418,9 +418,12 @@ int minipal_getcpufeatures(void) { result |= ARM64IntrinsicConstants_Rcpc; } -#endif // TARGET_WINDOWS -#endif // TARGET_ARM64 + // TODO: IsProcessorFeaturePresent doesn't support LRCPC2 yet. + +#endif // HOST_WINDOWS + +#endif // HOST_ARM64 return result; } diff --git a/src/native/minipal/cpufeatures.h b/src/native/minipal/cpufeatures.h index 312bee84ace2bd..73d151f1e2d80f 100644 --- a/src/native/minipal/cpufeatures.h +++ b/src/native/minipal/cpufeatures.h @@ -8,7 +8,7 @@ // Should match the constants defined in the compiler in HardwareIntrinsicHelpers.cs // -#if defined(TARGET_X86) || defined(TARGET_AMD64) +#if defined(HOST_X86) || defined(HOST_AMD64) enum XArchIntrinsicConstants { XArchIntrinsicConstants_Aes = 0x0001, @@ -41,9 +41,9 @@ enum XArchIntrinsicConstants XArchIntrinsicConstants_VectorT256 = 0x8000000, XArchIntrinsicConstants_VectorT512 = 0x10000000, }; -#endif // TARGET_X86 || TARGET_AMD64 +#endif // HOST_X86 || HOST_AMD64 -#if defined(TARGET_ARM64) +#if defined(HOST_ARM64) enum ARM64IntrinsicConstants { ARM64IntrinsicConstants_AdvSimd = 0x0001, @@ -64,7 +64,7 @@ enum ARM64IntrinsicConstants #define ARM64_ATOMICS_FEATURE_FLAG_BIT 7 static_assert((1 << ARM64_ATOMICS_FEATURE_FLAG_BIT) == ARM64IntrinsicConstants_Atomics, "ARM64_ATOMICS_FEATURE_FLAG_BIT must match with ARM64IntrinsicConstants_Atomics"); -#endif // TARGET_ARM64 +#endif // HOST_ARM64 #ifdef __cplusplus extern "C" diff --git a/src/tasks/Crossgen2Tasks/Microsoft.NET.CrossGen.targets b/src/tasks/Crossgen2Tasks/Microsoft.NET.CrossGen.targets index a2912cc0822849..c539d330e8add8 100644 --- a/src/tasks/Crossgen2Tasks/Microsoft.NET.CrossGen.targets +++ b/src/tasks/Crossgen2Tasks/Microsoft.NET.CrossGen.targets @@ -443,7 +443,7 @@ Copyright (c) .NET Foundation. All rights reserved. + diff --git a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs index a38af7270a2dad..13c34bde4b8ea1 100644 --- a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs +++ b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers.Binary; using System.IO; using System.Collections.Immutable; using System.Reflection.PortableExecutable; @@ -174,16 +175,23 @@ public unsafe void GatherInfo(PEReader peReader, out WCFileInfo wcInfo, out PEFi SectionStart: firstWCSection); } - private static void WriteHeader(Stream s, WebcilHeader header) + private static void WriteHeader(Stream s, WebcilHeader webcilHeader) { - WriteStructure(s, header); + if (!BitConverter.IsLittleEndian) + { + webcilHeader.version_major = BinaryPrimitives.ReverseEndianness(webcilHeader.version_major); + webcilHeader.version_minor = BinaryPrimitives.ReverseEndianness(webcilHeader.version_minor); + webcilHeader.coff_sections = BinaryPrimitives.ReverseEndianness(webcilHeader.coff_sections); + webcilHeader.pe_cli_header_rva = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_cli_header_rva); + webcilHeader.pe_cli_header_size = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_cli_header_size); + webcilHeader.pe_debug_rva = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_debug_rva); + webcilHeader.pe_debug_size = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_debug_size); + } + WriteStructure(s, webcilHeader); } private static void WriteSectionHeaders(Stream s, ImmutableArray sectionsHeaders) { - // FIXME: fixup endianness - if (!BitConverter.IsLittleEndian) - throw new NotImplementedException(); foreach (var sectionHeader in sectionsHeaders) { WriteSectionHeader(s, sectionHeader); @@ -192,6 +200,16 @@ private static void WriteSectionHeaders(Stream s, ImmutableArray(Stream s, T structure) where T : unmanaged { - // FIXME: fixup endianness - if (!BitConverter.IsLittleEndian) - throw new NotImplementedException(); unsafe { byte* p = (byte*)&structure; @@ -212,9 +227,6 @@ private static void WriteStructure(Stream s, T structure) private static void WriteStructure(Stream s, T structure) where T : unmanaged { - // FIXME: fixup endianness - if (!BitConverter.IsLittleEndian) - throw new NotImplementedException(); int size = Marshal.SizeOf(); byte[] buffer = new byte[size]; IntPtr ptr = IntPtr.Zero; diff --git a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs index 4f42f827986643..ac4f9d86095a90 100644 --- a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs +++ b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs @@ -6,7 +6,7 @@ using System.IO; using System.Reflection; using System.Runtime.InteropServices; - +using System.Buffers.Binary; using System.Reflection.Metadata; using System.Reflection.PortableExecutable; @@ -63,14 +63,20 @@ private unsafe bool ReadHeader() { return false; } - if (!BitConverter.IsLittleEndian) - { - throw new NotImplementedException("TODO: implement big endian support"); - } fixed (byte* p = buffer) { header = *(WebcilHeader*)p; } + if (!BitConverter.IsLittleEndian) + { + header.version_major = BinaryPrimitives.ReverseEndianness(header.version_major); + header.version_minor = BinaryPrimitives.ReverseEndianness(header.version_minor); + header.coff_sections = BinaryPrimitives.ReverseEndianness(header.coff_sections); + header.pe_cli_header_rva = BinaryPrimitives.ReverseEndianness(header.pe_cli_header_rva); + header.pe_cli_header_size = BinaryPrimitives.ReverseEndianness(header.pe_cli_header_size); + header.pe_debug_rva = BinaryPrimitives.ReverseEndianness(header.pe_debug_rva); + header.pe_debug_rva = BinaryPrimitives.ReverseEndianness(header.pe_debug_size); + } if (header.id[0] != 'W' || header.id[1] != 'b' || header.id[2] != 'I' || header.id[3] != 'L' || header.version_major != Internal.Constants.WC_VERSION_MAJOR @@ -346,6 +352,7 @@ private long TranslateRVA(uint rva) private unsafe ImmutableArray ReadSections() { + WebcilSectionHeader secheader; var sections = ImmutableArray.CreateBuilder(_header.coff_sections); var buffer = new byte[Marshal.SizeOf()]; _stream.Seek(SectionDirectoryOffset + _webcilInWasmOffset, SeekOrigin.Begin); @@ -357,8 +364,24 @@ private unsafe ImmutableArray ReadSections() } fixed (byte* p = buffer) { - // FIXME endianness - sections.Add(*(WebcilSectionHeader*)p); + secheader = (*(WebcilSectionHeader*)p); + } + if (!BitConverter.IsLittleEndian) + { + sections.Add + ( + new WebcilSectionHeader + ( + virtualSize: BinaryPrimitives.ReverseEndianness(secheader.VirtualSize), + virtualAddress: BinaryPrimitives.ReverseEndianness(secheader.VirtualAddress), + sizeOfRawData: BinaryPrimitives.ReverseEndianness(secheader.SizeOfRawData), + pointerToRawData: BinaryPrimitives.ReverseEndianness(secheader.PointerToRawData) + ) + ); + } + else + { + sections.Add(secheader); } } return sections.MoveToImmutable(); diff --git a/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs b/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs index e520057d5b3bdf..21170ea2152843 100644 --- a/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs +++ b/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs @@ -49,7 +49,7 @@ public partial class InstallWorkloadFromArtifacts : Task private string _tempDir = string.Empty; private string _nugetCachePath = string.Empty; - [GeneratedRegex(@"^\d+\.\d+\.\d+(-[A-z]*\.*\d*)?")] + [GeneratedRegex(@"^\d+\.\d+\.\d+(-rtm|-[A-z]*\.*\d*)?")] private static partial Regex bandVersionRegex(); public override bool Execute() @@ -215,7 +215,7 @@ private bool InstallPacks(InstallWorkloadRequest req, string nugetConfigContents (int exitCode, string output) = Utils.TryRunProcess( Log, Path.Combine(req.TargetPath, "dotnet"), - $"workload install --skip-manifest-update --configfile \"{nugetConfigPath}\" --temp-dir \"{_tempDir}/workload-install-temp\" {req.WorkloadId}", + $"workload install --skip-manifest-update --skip-sign-check --configfile \"{nugetConfigPath}\" --temp-dir \"{_tempDir}/workload-install-temp\" {req.WorkloadId}", workingDir: _tempDir, envVars: new Dictionary () { ["NUGET_PACKAGES"] = _nugetCachePath @@ -301,8 +301,8 @@ private bool InstallWorkloadManifest(ITaskItem workloadId, string name, string v string packagePreleaseVersion = bandVersionRegex().Match(version).Groups[1].Value; string bandPreleaseVersion = bandVersionRegex().Match(bandVersion).Groups[1].Value; - if (packagePreleaseVersion != bandPreleaseVersion && packagePreleaseVersion != "-dev" && packagePreleaseVersion != "-ci") - bandVersion = bandVersion.Replace (bandPreleaseVersion, packagePreleaseVersion); + if (packagePreleaseVersion != bandPreleaseVersion && packagePreleaseVersion != "-dev" && packagePreleaseVersion != "-ci" && bandPreleaseVersion != "") + bandVersion = bandVersion.Replace(bandPreleaseVersion, packagePreleaseVersion); PackageReference pkgRef = new(Name: $"{name}.Manifest-{bandVersion}", Version: version, diff --git a/src/tests/Common/external/external.csproj b/src/tests/Common/external/external.csproj index d2541b5ae4835b..71aa7a42b5661c 100644 --- a/src/tests/Common/external/external.csproj +++ b/src/tests/Common/external/external.csproj @@ -13,7 +13,7 @@ --> $(TargetingPackPath) $(NetCoreAppToolCurrent) - win7-x86;win7-x64 + win-x86;win-x64 SharedLibrary false false diff --git a/src/tests/Interop/COM/NETClients/IDispatch/Program.cs b/src/tests/Interop/COM/NETClients/IDispatch/Program.cs index 71839d9afffb8f..fc9144575f0d91 100644 --- a/src/tests/Interop/COM/NETClients/IDispatch/Program.cs +++ b/src/tests/Interop/COM/NETClients/IDispatch/Program.cs @@ -139,6 +139,21 @@ static void Validate_Exception() Assert.Equal(GetErrorCodeFromHResult(e.HResult), errorCode); // Failing HRESULT exceptions contain CLR generated messages } + + // Calling methods through IDispatch::Invoke() (i.e., late-bound) doesn't + // propagate the HRESULT when marked with PreserveSig. It is always 0. + { + Console.WriteLine($"Calling {nameof(DispatchTesting.TriggerException)} (PreserveSig) with {nameof(IDispatchTesting_Exception.Int)} {errorCode}..."); + var dispatchTesting2 = (IDispatchTestingPreserveSig1)dispatchTesting; + Assert.Equal(0, dispatchTesting2.TriggerException(IDispatchTesting_Exception.Int, errorCode)); + } + + { + // Validate the HRESULT as a value type construct works for IDispatch. + Console.WriteLine($"Calling {nameof(DispatchTesting.TriggerException)} (PreserveSig, ValueType) with {nameof(IDispatchTesting_Exception.Int)} {errorCode}..."); + var dispatchTesting3 = (IDispatchTestingPreserveSig2)dispatchTesting; + Assert.Equal(0, dispatchTesting3.TriggerException(IDispatchTesting_Exception.Int, errorCode).Value); + } } static void Validate_StructNotSupported() diff --git a/src/tests/Interop/COM/NETServer/DispatchTesting.cs b/src/tests/Interop/COM/NETServer/DispatchTesting.cs index 477e5751f69e73..66461b8c7e47f2 100644 --- a/src/tests/Interop/COM/NETServer/DispatchTesting.cs +++ b/src/tests/Interop/COM/NETServer/DispatchTesting.cs @@ -57,6 +57,7 @@ public void TriggerException(IDispatchTesting_Exception excep, int errorCode) case IDispatchTesting_Exception.Disp: throw new Exception(); case IDispatchTesting_Exception.HResult: + case IDispatchTesting_Exception.Int: throw new System.ComponentModel.Win32Exception(errorCode); } } diff --git a/src/tests/Interop/COM/NativeServer/DispatchTesting.h b/src/tests/Interop/COM/NativeServer/DispatchTesting.h index 927439fe03dc48..fbe7db6c1bad7f 100644 --- a/src/tests/Interop/COM/NativeServer/DispatchTesting.h +++ b/src/tests/Interop/COM/NativeServer/DispatchTesting.h @@ -243,6 +243,8 @@ class DispatchTesting : public UnknownImpl, public IDispatchTesting return DISP_E_EXCEPTION; case IDispatchTesting_Exception_HResult: return HRESULT_FROM_WIN32(errorCode); + case IDispatchTesting_Exception_Int: + return errorCode; default: return S_FALSE; // Return a success case to indicate failure to trigger a failure. } diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs index 758c200acaabae..0bac21e66ee17e 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs @@ -209,6 +209,7 @@ public enum IDispatchTesting_Exception { Disp, HResult, + Int, } [StructLayout(LayoutKind.Sequential)] @@ -220,6 +221,12 @@ public struct HFA_4 public float w; } + [StructLayout(LayoutKind.Sequential)] + public struct HRESULT + { + public int Value; + } + [ComVisible(true)] [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")] [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] @@ -257,6 +264,32 @@ void DoubleNumeric_ReturnByRef ( System.Collections.IEnumerator GetEnumerator(); } + [ComVisible(true)] + [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")] + [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] + public interface IDispatchTestingPreserveSig1 + { + void Reserved1(); + void Reserved2(); + void Reserved3(); + + [PreserveSig] + int TriggerException(IDispatchTesting_Exception excep, int errorCode); + } + + [ComVisible(true)] + [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")] + [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] + public interface IDispatchTestingPreserveSig2 + { + void Reserved1(); + void Reserved2(); + void Reserved3(); + + [PreserveSig] + HRESULT TriggerException(IDispatchTesting_Exception excep, int errorCode); + } + [ComVisible(true)] [Guid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")] [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h index 3c9a1fcb06cbe1..1eb0528aae4b78 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h @@ -385,6 +385,7 @@ enum IDispatchTesting_Exception { IDispatchTesting_Exception_Disp, IDispatchTesting_Exception_HResult, + IDispatchTesting_Exception_Int, }; struct __declspec(uuid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")) diff --git a/src/tests/Interop/StructMarshalling/PInvoke/GameControllerButtonBind.cs b/src/tests/Interop/StructMarshalling/PInvoke/GameControllerButtonBind.cs new file mode 100644 index 00000000000000..a65994f2307a3f --- /dev/null +++ b/src/tests/Interop/StructMarshalling/PInvoke/GameControllerButtonBind.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.CompilerServices; +using System.Text; + +public unsafe partial struct GameControllerButtonBind +{ + public GameControllerButtonBind + ( + GameControllerBindType? bindType = null, + GameControllerButtonBindValue? value = null + ) : this() + { + if (bindType is not null) + { + BindType = bindType.Value; + } + + if (value is not null) + { + Value = value.Value; + } + } + + public GameControllerBindType BindType; + + public GameControllerButtonBindValue Value; +} + +public enum GameControllerBindType : int +{ + ControllerBindtypeNone = 0x0, + ControllerBindtypeButton = 0x1, + ControllerBindtypeAxis = 0x2, + ControllerBindtypeHat = 0x3, + None = 0x0, + Button = 0x1, + Axis = 0x2, + Hat = 0x3, +} + +[StructLayout(LayoutKind.Explicit)] +public unsafe partial struct GameControllerButtonBindValue +{ + [FieldOffset(0)] + public int Button; + + [FieldOffset(0)] + public int Axis; + + [FieldOffset(0)] + public GameControllerButtonBindValueHat Hat; +} + +public unsafe partial struct GameControllerButtonBindValueHat +{ + public int Hat; + + public int HatMask; +} diff --git a/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.cpp b/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.cpp index 9278650d20dce8..e19eec7feb0737 100644 --- a/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.cpp +++ b/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.cpp @@ -1297,3 +1297,8 @@ extern "C" DLL_EXPORT Int32CLongStruct STDMETHODCALLTYPE AddCLongs(Int32CLongStr { return { lhs.i + rhs.i, lhs.l + rhs.l }; } + +extern "C" DLL_EXPORT SDL_GameControllerBindType STDMETHODCALLTYPE getBindType(SDL_GameControllerButtonBind button) +{ + return button.bindType; +} diff --git a/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h b/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h index c85b1e6f62dc05..17844b8c365263 100644 --- a/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h +++ b/src/tests/Interop/StructMarshalling/PInvoke/MarshalStructAsParamDLL.h @@ -974,3 +974,26 @@ struct Int32CLongStruct int32_t i; long l; }; + +typedef enum +{ + SDL_CONTROLLER_BINDTYPE_NONE = 0, + SDL_CONTROLLER_BINDTYPE_BUTTON, + SDL_CONTROLLER_BINDTYPE_AXIS, + SDL_CONTROLLER_BINDTYPE_HAT +} SDL_GameControllerBindType; + +typedef struct SDL_GameControllerButtonBind +{ + SDL_GameControllerBindType bindType; + union + { + int button; + int axis; + struct { + int hat; + int hat_mask; + } hat; + } value; + +} SDL_GameControllerButtonBind; diff --git a/src/tests/Interop/StructMarshalling/PInvoke/NestedStruct.cs b/src/tests/Interop/StructMarshalling/PInvoke/NestedStruct.cs new file mode 100644 index 00000000000000..870e75bff62525 --- /dev/null +++ b/src/tests/Interop/StructMarshalling/PInvoke/NestedStruct.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using Xunit; + +public class Managed +{ + [DllImport("MarshalStructAsParam")] + static extern GameControllerBindType getBindType (GameControllerButtonBind button); + + public static int Main() + { + GameControllerButtonBind button = new GameControllerButtonBind(GameControllerBindType.ControllerBindtypeAxis, null); + if (getBindType(button) == GameControllerBindType.ControllerBindtypeAxis) + { + Console.WriteLine("\nTEST PASSED!"); + return 100; + } + else + { + Console.WriteLine("\nTEST FAILED!"); + return 1; + } + } +} diff --git a/src/tests/Interop/StructMarshalling/PInvoke/NestedStruct.csproj b/src/tests/Interop/StructMarshalling/PInvoke/NestedStruct.csproj new file mode 100644 index 00000000000000..0b0d4aa0836435 --- /dev/null +++ b/src/tests/Interop/StructMarshalling/PInvoke/NestedStruct.csproj @@ -0,0 +1,14 @@ + + + exe + true + + + + + + + + + + diff --git a/src/tests/JIT/HardwareIntrinsics/General/Shared/VectorImmBinaryOperatorTest.template b/src/tests/JIT/HardwareIntrinsics/General/Shared/VectorImmBinaryOperatorTest.template index ff62ffcfc91335..d9b2108472675b 100644 --- a/src/tests/JIT/HardwareIntrinsics/General/Shared/VectorImmBinaryOperatorTest.template +++ b/src/tests/JIT/HardwareIntrinsics/General/Shared/VectorImmBinaryOperatorTest.template @@ -19,7 +19,6 @@ namespace JIT.HardwareIntrinsics.General public static partial class Program { [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/89938", TestRuntimes.Mono)] public static void {Method}{RetBaseType}{Imm}() { var test = new VectorImmBinaryOpTest__{Method}{RetBaseType}{Imm}(); diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_81725/Runtime_81725.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_81725/Runtime_81725.csproj index b9493ef3d767cd..8d9ddba76b933b 100644 --- a/src/tests/JIT/Regression/JitBlue/Runtime_81725/Runtime_81725.csproj +++ b/src/tests/JIT/Regression/JitBlue/Runtime_81725/Runtime_81725.csproj @@ -7,8 +7,6 @@ - - diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.cs b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.cs new file mode 100644 index 00000000000000..cb70acd1677573 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_83387 +{ + [MethodImpl(MethodImplOptions.NoOptimization)] + [Fact] + public static int TestEntryPoint() + { + (ushort A, ushort R) c = (1, 65535); + Vector128 v1 = Vector128.Create((uint)100); + v1 = v1 * c.A; + return (int)v1.ToScalar(); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_85088/Runtime_85088.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_85088/Runtime_85088.csproj index 85f04c1ebc8f71..d0c820dcec6ea8 100644 --- a/src/tests/JIT/Regression/JitBlue/Runtime_85088/Runtime_85088.csproj +++ b/src/tests/JIT/Regression/JitBlue/Runtime_85088/Runtime_85088.csproj @@ -1,5 +1,7 @@ + + true True diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_85765/Runtime_85765.cs b/src/tests/JIT/Regression/JitBlue/Runtime_85765/Runtime_85765.cs new file mode 100644 index 00000000000000..89c22ae90dc821 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_85765/Runtime_85765.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using Xunit; + +public class Runtime_85765 +{ + public struct S0 + { + public S0(bool f1): this() + { + } + } + + public struct S1 + { + public byte F0; + public bool F1; + public bool F2; + } + + [Fact] + public static void Test() + { + S1 vr2 = M4(); + vr2.F2 |= vr2.F1; + Assert.False(Consume(vr2.F2)); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static S1 M4() + { + S1 var1 = default(S1); + var vr0 = new S0(false); + return var1; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static bool Consume(bool value) + { + return value; + } + + // ------ + + [Fact] + public unsafe static void Test2() + { + byte* bytes = stackalloc byte[1024]; + bytes[0x1A] = 1; + bytes[0x1B] = 2; + int sum = Foo(bytes); + Assert.True(sum == 515); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public unsafe static int Foo(byte* b) + { + return Unsafe.ReadUnaligned(ref b[0x1A]) + Unsafe.ReadUnaligned(ref b[0x1B]); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_85765/Runtime_85765.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_85765/Runtime_85765.csproj new file mode 100644 index 00000000000000..a4cc9d0594f93e --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_85765/Runtime_85765.csproj @@ -0,0 +1,9 @@ + + + True + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91056/Runtime_91056.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91056/Runtime_91056.cs new file mode 100644 index 00000000000000..a0c78804a7ee52 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91056/Runtime_91056.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; +using System.Runtime.CompilerServices; + +public class Runtime_91056 +{ + [Fact] + public static void TestEntryPoint() + { + S s = default; + if (False()) + { + s.A = 1234; + } + + Foo(0, 0, s, s); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static bool False() => false; + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void Foo(int a, int b, S s1, S s2) + { + } + + public struct S + { + public int A; + } +} \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91056/Runtime_91056.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91056/Runtime_91056.csproj new file mode 100644 index 00000000000000..444d119c04fe97 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91056/Runtime_91056.csproj @@ -0,0 +1,12 @@ + + + + true + True + + + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91062/Runtime_91062.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91062/Runtime_91062.cs new file mode 100644 index 00000000000000..562d2029ff85c5 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91062/Runtime_91062.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Numerics; +using Xunit; + +public class Runtime_91062 +{ + [Fact] + public static void TestEntryPoint() + { + Foo(default, default); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static Vector2 Foo(Vector128 v1, Vector128 v2) + { + return Vector2.Lerp(default, default, Vector128.Dot(v1, v2)); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91062/Runtime_91062.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91062/Runtime_91062.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91062/Runtime_91062.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91170/Runtime_91170.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91170/Runtime_91170.cs new file mode 100644 index 00000000000000..7f3e9293eba974 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91170/Runtime_91170.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Found by Antigen + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +using System.Numerics; +using Xunit; + +public class TestClass +{ + public struct S1 + { + } + + static short s_short_8 = 5; + static int s_int_9 = -2; + long long_59 = 4; + uint uint_64 = 1; + Vector256 v256_int_90 = Vector256.Create(2, -5, 4, 4, 5, 0, -1, 5); + S1 s1_99 = new S1(); + + private uint Method4(out short p_short_161, S1 p_s1_162, bool p_bool_163, ref int p_int_164) + { + unchecked + { + p_short_161 = 15|4; + if ((long_59 *= 15>>4)!= (long_59 |= 15^4)) + { + } + else + { + Vector128.CreateScalarUnsafe(Vector256.Sum(v256_int_90)); + } + return 15|4; + } + } + + private void Method0() + { + unchecked + { + uint_64 = Method4(out s_short_8, s1_99, 15<4, ref s_int_9); + return; + } + } + + [Fact] + public static void TestEntryPoint() + { + new TestClass().Method0(); + } +} +/* + +Assert failure(PID 34336 [0x00008620], Thread: 38576 [0x96b0]): Assertion failed '!childNode->isContainableHWIntrinsic()' in 'TestClass:Method4(byref,TestClass+S1,bool,byref):uint:this' during 'Lowering nodeinfo' (IL size 63; hash 0xa4e6dede; Tier0) + File: D:\git\runtime\src\coreclr\jit\lowerxarch.cpp Line: 8201 + Image: D:\git\runtime\artifacts\tests\coreclr\windows.x64.Checked\tests\Core_Root\CoreRun.exe +*/ diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91170/Runtime_91170.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91170/Runtime_91170.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91170/Runtime_91170.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91174/Runtime_91174.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91174/Runtime_91174.cs new file mode 100644 index 00000000000000..3ae45f998542b4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91174/Runtime_91174.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_91174 +{ + [MethodImpl(MethodImplOptions.NoInlining)] + public static int Foo(ref Vector256 v1, ref Vector256 v2) + { + if (Vector256.ToScalar(v1) < Vector256.ToScalar(v2)) + { + Console.WriteLine("FAIL"); + return 101; + } + + return 100; + } + + [Fact] + public static int TestEntryPoint() + { + Vector256 v1 = Vector256.Create(20); + Vector256 v2 = Vector256.Create(10); + return Foo(ref v1, ref v2); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91174/Runtime_91174.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91174/Runtime_91174.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91174/Runtime_91174.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91214/Runtime_91214.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91214/Runtime_91214.cs new file mode 100644 index 00000000000000..d30adf6e60cf11 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91214/Runtime_91214.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Found by Antigen + +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Numerics; +using Xunit; + +public class Runtime_91214 +{ + [Fact] + public static void TestEntryPoint() + { + Method0(); + } + + struct S + { + public Vector3 v3; + public bool b; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static S Method2() + { + return default; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Method0() + { + S s = Method2(); + Log(null, s.v3); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void Log(object a, object b) { } +} \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91214/Runtime_91214.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91214/Runtime_91214.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91214/Runtime_91214.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.cs new file mode 100644 index 00000000000000..d4035f3de978fc --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// +// This test verifies if we correctly value number the operation of +// x ^ x to zero. +// +// Found by Antigen + +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Issue_91252 +{ + static Vector64 s_v64_int_22 = Vector64.Create(-5); + Vector64 v64_int_72 = Vector64.Create(-1); + + [MethodImpl(MethodImplOptions.NoInlining)] + public int Repro() + { + s_v64_int_22 = v64_int_72; + return Check(v64_int_72 ^ v64_int_72); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public int Check(Vector64 a) + { + return (a == Vector64.Zero) ? 100 : 101; + } + + [Fact] + public static int EntryPoint() + { + var obj = new Issue_91252(); + return obj.Repro(); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91443/Runtime_91443.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91443/Runtime_91443.cs new file mode 100644 index 00000000000000..d3844b77271dd7 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91443/Runtime_91443.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Numerics; +using System.Runtime.CompilerServices; +using Xunit; + +public class Runtime_91443 +{ + [Fact] + public static void TestEntryPoint() + { + new Runtime_91443().Method0(); + } + + static Vector3 s; + + [MethodImpl(MethodImplOptions.NoInlining)] + private void Method0() + { + Vector3.Cross(s, s); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91443/Runtime_91443.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91443/Runtime_91443.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91443/Runtime_91443.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91576/Runtime_91576.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91576/Runtime_91576.cs new file mode 100644 index 00000000000000..23ae4b41c76a28 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91576/Runtime_91576.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Generated by Fuzzlyn v1.6 on 2023-09-03 15:59:01 +// Run on X64 Windows +// Seed: 11520325105937570553 +// Reduced from 294.5 KiB to 0.7 KiB in 00:04:32 +// Debug: Outputs False +// Release: Outputs True +using System; +using System.Runtime.CompilerServices; +using Xunit; + +public class Runtime_91576 +{ + [Fact] + public static int TestEntryPoint() + { + Assert.Throws(() => + { + Run(new int[1]); + Run(null); + }); + + return s_result; + } + + static int s_result; + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void Run(int[] l) + { + bool b = false; + try + { + int result = l[0]; + b = true; + } + finally + { + Check(ref b); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void Check(ref bool b) + { + s_result = b ? 101 : 100; + } +} + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91576/Runtime_91576.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91576/Runtime_91576.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91576/Runtime_91576.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91862/Runtime_91862.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91862/Runtime_91862.cs new file mode 100644 index 00000000000000..541e2658f2d6a3 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91862/Runtime_91862.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_91862 +{ + [Fact] + public static int TestEntryPoint() + { + return Foo(default); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static int Foo(Vector128 v) + { + int result = 101; + // This tree results in a BOUNDS_CHECK for Bar(...) & 3 + float x = Vector128.GetElement(v, Bar(ref result) & 3); + + if (result != 100) + { + Console.WriteLine("FAIL"); + } + + // After inlining x is DCE'd, which will extract side effects of its assignment above. + // That results IR amenable to forward sub, and we end up with a BOUNDS_CHECK + // with a complex index expression that we can still prove is within bounds. + Baz(x); + return result; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static int Bar(ref int result) + { + result = 100; + return 0; + } + + private static void Baz(float x) + { + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91862/Runtime_91862.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91862/Runtime_91862.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91862/Runtime_91862.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.cs new file mode 100644 index 00000000000000..9b4696e31fc16c --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Threading; +using Xunit; + +public struct MutableStruct +{ + private long _internalValue; + + public long InternalValue + { + get => Volatile.Read(ref _internalValue); + private set => Volatile.Write(ref _internalValue, value); + } + + public void Add(long value) => AddInternal(value); + private void AddInternal(long value) => InternalValue += value; + public MutableStruct(long value) => InternalValue = value; +} + +public static class Runtime_92218 +{ + [Fact] + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public static void Problem() + { + var test = new MutableStruct(420); + var from = new MutableStruct(42); + + var wrapper = -new TimeSpan(3); + + while (test.InternalValue >= from.InternalValue) + { + test.Add(wrapper.Ticks); + } + } +} \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.cs new file mode 100644 index 00000000000000..5de0a28895b268 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.Intrinsics.X86; +using System.Runtime.Intrinsics; +using System.Runtime.CompilerServices; +using System.Threading; +using Xunit; + +public static class Runtime_92349 +{ + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + unsafe static void Test(byte* pValue) + { + *pValue = (byte)Sse2.ConvertToInt32(Vector128.Create(-10, 0, 0, 0)); + } + + [Fact] + public unsafe static void EntryPoint() + { + if (Sse2.IsSupported) + { + ulong value = 0; + Test((byte*)Unsafe.AsPointer(ref value)); + Assert.True(value == 246); + } + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.csproj new file mode 100644 index 00000000000000..6bb210527e0797 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.csproj @@ -0,0 +1,9 @@ + + + True + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.cs new file mode 100644 index 00000000000000..4704441bacce6c --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +using Xunit; + +public static class Runtime_92357 +{ + [Fact] + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public static void Problem() + { + if (!Avx2.IsSupported) + { + return; + } + + int y1 = 5; + + Vector256 actual1 = Test1(Vector256.Create((short)1), ref y1); + Vector256 expected1 = Vector256.Create(10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0); + + Assert.Equal(expected1, actual1); + + long y2 = 5; + + Vector256 actual2 = Test2(Vector256.Create((int)1), ref y2); + Vector256 expected2 = Vector256.Create(10, 0, 10, 0, 10, 0, 10, 0); + + Assert.Equal(expected2, actual2); + } + + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)] + public static Vector256 Test1(Vector256 x, ref int y) + { + return Avx2.MultiplyLow(x + x, Vector256.Create(y).AsInt16()); + } + + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)] + public static Vector256 Test2(Vector256 x, ref long y) + { + return Avx2.MultiplyLow(x + x, Vector256.Create(y).AsInt32()); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.cs new file mode 100644 index 00000000000000..99a5ef2ee5d18d --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_92590 +{ + [Fact] + public static void TestEntryPoint() + { + Span bytes = stackalloc byte[4]; + bytes.Fill(0xff); + TestByteByte(ref bytes[0], 0, Vector256.Create((byte)1)); + + Assert.True(bytes.SequenceEqual(stackalloc byte[] { 0x2, 0xff, 0xff, 0xff })); + + bytes.Fill(0xff); + TestByteInt(ref bytes[0], 0, Vector256.Create(1)); + + Assert.True(bytes.SequenceEqual(stackalloc byte[] { 0x2, 0xff, 0xff, 0xff })); + + int i = int.MaxValue; + TestIntByte(ref i, 0, Vector256.Create((byte)1)); + + Assert.Equal(2, i); + + i = int.MaxValue; + TestIntInt(ref i, 0, Vector256.Create(1)); + + Assert.Equal(2, i); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestByteByte(ref byte b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = v[3]; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestByteInt(ref byte b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = (byte)v[3]; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestIntByte(ref int b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = v[3]; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestIntInt(ref int b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = v[3]; + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.cs b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.cs new file mode 100644 index 00000000000000..6a68d7f5650797 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_93342 +{ + private int foo; + private int bar; + private int baz; + + [Fact] + public static void TestEntryPoint() + { + new Runtime_93342().Run(); + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + private void Run() + { + if (foo == 1) + { + bar += 11; + baz += 11; + } + if (foo == 2) + bar += 12; + if (foo == 3) + bar += 13; + if (foo == 4) + bar += 14; + if (foo == 5) + bar += 15; + if (foo == 6) + bar += 16; + if (foo == 7) + bar += 17; + if (foo == 8) + bar += 18; + if (foo == 9) + bar += 19; + if (foo == 10) + bar += 20; + if (foo == 11) + bar += 21; + if (foo == 12) + bar += 22; + if (foo == 13) + bar += 23; + if (foo == 14) + bar += 24; + if (foo == 15) + bar += 25; + if (foo == 16) + bar += 26; + if (foo == 17) + bar += 27; + if (foo == 18) + bar += 28; + if (foo == 19) + bar += 29; + if (foo == 20) + bar += 30; + if (foo == 21) + bar += 31; + if (foo == 22) + bar += 32; + if (foo == 23) + bar += 33; + if (foo == 24) + bar += 34; + if (foo == 25) + bar += 35; + if (foo == 26) + bar += 36; + if (foo == 27) + bar += 37; + if (foo == 28) + bar += 38; + if (foo == 29) + bar += 39; + if (foo == 30) + bar += 40; + if (foo == 31) + bar += 41; + if (foo == 32) + bar += 42; + if (foo == 33) + bar += 43; + if (foo == 34) + bar += 44; + if (foo == 35) + bar += 45; + if (foo == 36) + bar += 46; + if (foo == 37) + bar += 47; + + bar = baz; + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.cs b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.cs new file mode 100644 index 00000000000000..b424dd6c2f3c98 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Text; +using Xunit; + +public struct Holder +{ + internal StringBuilder.AppendInterpolatedStringHandler _h; + public Holder() => _h = new(0, 0, new()); + + internal StringBuilder GetBuilder() => Unsafe.As(ref _h); +} + +public static class Runtime_93650 +{ + static int N = 1; + + [Fact] + public static int Problem() + { + var sb = new Holder(); + for (int i = 0; i < N; i++) + { + var s = Bind(ref sb); + if (s.Length != 0) + { + Console.WriteLine("FAILED: StringBuilder.ToString() returned: " + s); + return -1; + } + } + + return 100; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static string Bind(ref Holder parameters) => GetString(parameters.GetBuilder()); + + public static string GetString(StringBuilder sb) => sb.ToString(); +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.csproj b/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.csproj index c6636e39772f44..ef9e4a7a6abc7c 100644 --- a/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.csproj +++ b/src/tests/JIT/opt/Devirtualization/Comparer_get_Default.csproj @@ -1,6 +1,9 @@ True + + + true diff --git a/src/tests/Loader/CustomAttributes/DynamicObjects.cs b/src/tests/Loader/CustomAttributes/DynamicObjects.cs new file mode 100644 index 00000000000000..5b232d83eba3bb --- /dev/null +++ b/src/tests/Loader/CustomAttributes/DynamicObjects.cs @@ -0,0 +1,107 @@ +using System; +using System.Resources; +using System.Reflection; +using System.Reflection.Emit; +using System.ComponentModel.DataAnnotations; +using System.Linq; + +using Xunit; + +#nullable disable + +namespace DynamicObjects { + public class M { + public const string ObjectRequiredMessage = "some string"; + public static int Main() { + var instance = createObject(); + var attrs = instance.GetType().GetProperty("prop1").GetCustomAttributes(); + + Assert.True(attrs.Count() == 2); + Assert.Equal(attrs.ElementAt(0).ToString(), "System.ComponentModel.DataAnnotations.DisplayAttribute"); + Assert.Equal(attrs.ElementAt(1).ToString(), "System.ComponentModel.DataAnnotations.RequiredAttribute"); + Assert.Equal(typeof(RequiredAttribute), attrs.ElementAt(1).GetType()); + Assert.Equal(ObjectRequiredMessage, ((RequiredAttribute)attrs.ElementAt(1)).FormatErrorMessage("abc")); + + Console.WriteLine("Success"); + return 100; + } + + public static object createObject () { + var an = new AssemblyName { Name = "TempAssembly" ,Version = new Version(1, 0, 0, 0) }; + var assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly(an, AssemblyBuilderAccess.Run); + var moduleBuilder = assemblyBuilder.DefineDynamicModule("TempWorkflowAssembly.dll"); + var tb = moduleBuilder.DefineType("namespace.myclass" + , TypeAttributes.Public | + TypeAttributes.Class | + TypeAttributes.AnsiClass | + TypeAttributes.BeforeFieldInit + , typeof(object)); + + FieldBuilder fb = tb.DefineField("_prop1", + typeof(string), + FieldAttributes.Private); + + var pb = tb.DefineProperty("prop1", PropertyAttributes.HasDefault, typeof(string), null); + MethodAttributes getSetAttr = + MethodAttributes.Public | MethodAttributes.SpecialName | + MethodAttributes.HideBySig; + + // Define the "get" accessor method for prop1. + MethodBuilder custNameGetPropMthdBldr = + tb.DefineMethod("get_prop1", + getSetAttr, + typeof(string), + Type.EmptyTypes); + + ILGenerator custNameGetIL = custNameGetPropMthdBldr.GetILGenerator(); + + custNameGetIL.Emit(OpCodes.Ldarg_0); + custNameGetIL.Emit(OpCodes.Ldfld, fb); + custNameGetIL.Emit(OpCodes.Ret); + + // Define the "set" accessor method for prop1. + MethodBuilder custNameSetPropMthdBldr = + tb.DefineMethod("set_prop1", + getSetAttr, + null, + new Type[] { typeof(string) }); + + ILGenerator custNameSetIL = custNameSetPropMthdBldr.GetILGenerator(); + + custNameSetIL.Emit(OpCodes.Ldarg_0); + custNameSetIL.Emit(OpCodes.Ldarg_1); + custNameSetIL.Emit(OpCodes.Stfld, fb); + custNameSetIL.Emit(OpCodes.Ret); + + // Last, we must map the two methods created above to our PropertyBuilder to + // their corresponding behaviors, "get" and "set" respectively. + pb.SetGetMethod(custNameGetPropMthdBldr); + pb.SetSetMethod(custNameSetPropMthdBldr); + + + ///create display attribute + var dat = typeof(DisplayAttribute); + CustomAttributeBuilder CAB = new CustomAttributeBuilder(dat.GetConstructor(new Type[0]), + new object[0], + new PropertyInfo[1] { dat.GetProperty(nameof(DisplayAttribute.Name))}, + new object[] { "property 1"}); + pb.SetCustomAttribute(CAB); + + // //create required attribute + var rat = typeof(RequiredAttribute); + CustomAttributeBuilder CABR = new CustomAttributeBuilder(rat.GetConstructor(new Type[0]), + new object[0], + new PropertyInfo[2] { rat.GetProperty(nameof(RequiredAttribute.ErrorMessageResourceType)),rat.GetProperty(nameof(RequiredAttribute.ErrorMessageResourceName))}, + new object[] {typeof(ValidationErrors), "ObjectRequired" }); + pb.SetCustomAttribute(CABR); + + var objectType = tb.CreateType(); + return Activator.CreateInstance(objectType); + } + } + + public class ValidationErrors { + public static string ObjectRequired => M.ObjectRequiredMessage; + } + +} diff --git a/src/tests/Loader/CustomAttributes/DynamicObjects.csproj b/src/tests/Loader/CustomAttributes/DynamicObjects.csproj new file mode 100644 index 00000000000000..3e62bfd677e86a --- /dev/null +++ b/src/tests/Loader/CustomAttributes/DynamicObjects.csproj @@ -0,0 +1,9 @@ + + + Exe + + + + + + diff --git a/src/tests/Loader/classloader/InlineArray/InlineArrayInvalid.csproj b/src/tests/Loader/classloader/InlineArray/InlineArrayInvalid.csproj index c9c4d3b28342f2..2b17f8cac36a5b 100644 --- a/src/tests/Loader/classloader/InlineArray/InlineArrayInvalid.csproj +++ b/src/tests/Loader/classloader/InlineArray/InlineArrayInvalid.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/tests/Loader/classloader/InlineArray/InlineArrayValid.csproj b/src/tests/Loader/classloader/InlineArray/InlineArrayValid.csproj index 125a6c3f5432e3..2ed1ccb70a0bbc 100644 --- a/src/tests/Loader/classloader/InlineArray/InlineArrayValid.csproj +++ b/src/tests/Loader/classloader/InlineArray/InlineArrayValid.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/tests/Loader/classloader/InlineArray/InvalidCSharp.il b/src/tests/Loader/classloader/InlineArray/InvalidCSharpInlineArray.il similarity index 98% rename from src/tests/Loader/classloader/InlineArray/InvalidCSharp.il rename to src/tests/Loader/classloader/InlineArray/InvalidCSharpInlineArray.il index ff8d65c09be001..af16b27db42d62 100644 --- a/src/tests/Loader/classloader/InlineArray/InvalidCSharp.il +++ b/src/tests/Loader/classloader/InlineArray/InvalidCSharpInlineArray.il @@ -3,7 +3,7 @@ .assembly extern System.Runtime { .publickeytoken = (B0 3F 5F 7F 11 D5 0A 3A ) } -.assembly InvalidCSharp { } +.assembly InvalidCSharpInlineArray { } .class public explicit ansi sealed beforefieldinit Explicit extends [System.Runtime]System.ValueType diff --git a/src/tests/Loader/classloader/InlineArray/InvalidCSharp.ilproj b/src/tests/Loader/classloader/InlineArray/InvalidCSharpInlineArray.ilproj similarity index 73% rename from src/tests/Loader/classloader/InlineArray/InvalidCSharp.ilproj rename to src/tests/Loader/classloader/InlineArray/InvalidCSharpInlineArray.ilproj index d577c8f9c7a1a1..2a82d64591e9d5 100644 --- a/src/tests/Loader/classloader/InlineArray/InvalidCSharp.ilproj +++ b/src/tests/Loader/classloader/InlineArray/InvalidCSharpInlineArray.ilproj @@ -3,6 +3,6 @@ Library - + diff --git a/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj b/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj index 397521a4d40047..294b3c3a66827e 100644 --- a/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj +++ b/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj @@ -4,6 +4,9 @@ false + + + true Full diff --git a/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.cs b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.cs new file mode 100644 index 00000000000000..122ec91663b37b --- /dev/null +++ b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +public class ReproGH93597 { + public static int Main() { + var expected = new int[] {5,4,3,2,1}; + + const int LowerBound = 5; + + var expectedNzlba = NonZeroLowerBoundArray(expected, LowerBound); + + return Helper(expectedNzlba); + return 100; + } + [MethodImpl(MethodImplOptions.NoInlining)] + private static int Helper(Array a) { + IEnumerable ie = null; + try { + ie = (IEnumerable)a; + } catch (InvalidCastException) { + Console.WriteLine ("caught ICE, good"); + return 100; + } + ie.GetEnumerator(); // mono crashes here + return 101; + } + + + private static Array NonZeroLowerBoundArray(Array szArrayContents, int lowerBound) + { + Array array = Array.CreateInstance(szArrayContents.GetType().GetElementType(), new int[] { szArrayContents.Length }, new int[] { lowerBound }); + for (int i = 0; i < szArrayContents.Length; i++) + { + array.SetValue(szArrayContents.GetValue(i), i + lowerBound); + } + return array; + } + +} + diff --git a/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.csproj b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.csproj new file mode 100644 index 00000000000000..a6b761d37bc58b --- /dev/null +++ b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.csproj @@ -0,0 +1,5 @@ + + + + + diff --git a/src/tests/build.proj b/src/tests/build.proj index 090f11df3acd5a..e49ddc0b51b2f8 100644 --- a/src/tests/build.proj +++ b/src/tests/build.proj @@ -595,6 +595,7 @@ $(GroupBuildCmd) "/p:CrossBuild=true" $(GroupBuildCmd) "/p:DefaultBuildAllTarget=BuildNativeAot" $(GroupBuildCmd) "/p:IlcMultiModule=true" + $(GroupBuildCmd) "/p:IlcUseServerGc=false" $(GroupBuildCmd) "/p:BuildNativeAotFrameworkObjects=true" diff --git a/src/tests/issues.targets b/src/tests/issues.targets index 09e50e73c7dd39..5c9868e334b088 100644 --- a/src/tests/issues.targets +++ b/src/tests/issues.targets @@ -63,14 +63,17 @@ https://github.com/dotnet/runtime/issues/57786 + + https://github.com/dotnet/runtime/issues/90580 + CoreCLR does not implement the mono embedding API https://github.com/dotnet/runtime/issues/78899 - - https://github.com/dotnet/runtime/issues/89585 + + https://github.com/dotnet/runtime/issues/88586 @@ -691,6 +694,9 @@ + + https://github.com/dotnet/runtime/issues/90848 + https://github.com/dotnet/runtime/issues/89157 @@ -1156,6 +1162,7 @@ + @@ -1174,6 +1181,10 @@ + + + Dynamic code generation is not supported on this platform + @@ -165,7 +165,7 @@ - - - - <_MajorVersion>$([System.Version]::Parse('$(AssemblyVersion)').Major) - <_MinorVersion>$([System.Version]::Parse('$(AssemblyVersion)').Minor) - <_PatchVersion>$([System.Version]::Parse('$(AssemblyVersion)').Build) - <_BuildNumber>$([System.Version]::Parse('$(AssemblyVersion)').Revision) - - - 5 @@ -256,9 +247,9 @@