diff --git a/src/Dapper.AOT.Analyzers/InGeneration/DapperHelpers.cs b/src/Dapper.AOT.Analyzers/InGeneration/DapperHelpers.cs index 217d9da8..dbc9e5c0 100644 --- a/src/Dapper.AOT.Analyzers/InGeneration/DapperHelpers.cs +++ b/src/Dapper.AOT.Analyzers/InGeneration/DapperHelpers.cs @@ -6,7 +6,7 @@ #if !DAPPERAOT_INTERNAL file #endif - static class DbStringHelpers + static partial class DbStringHelpers { public static void ConfigureDbStringDbParameter( global::System.Data.Common.DbParameter dbParameter, diff --git a/test/Dapper.AOT.Test.Integration/Dapper.AOT.Test.Integration.csproj b/test/Dapper.AOT.Test.Integration/Dapper.AOT.Test.Integration.csproj index 755afdfa..d1013c01 100644 --- a/test/Dapper.AOT.Test.Integration/Dapper.AOT.Test.Integration.csproj +++ b/test/Dapper.AOT.Test.Integration/Dapper.AOT.Test.Integration.csproj @@ -7,16 +7,17 @@ - - - - - - - - - + + + + + + + + + + @@ -49,6 +50,7 @@ + diff --git a/test/Dapper.AOT.Test.Integration/DbStringTests.cs b/test/Dapper.AOT.Test.Integration/DbStringTests.cs index 0e185aa0..829782e5 100644 --- a/test/Dapper.AOT.Test.Integration/DbStringTests.cs +++ b/test/Dapper.AOT.Test.Integration/DbStringTests.cs @@ -1,30 +1,32 @@ -using System.Linq; +using System.Data; +using System.Threading.Tasks; using Dapper.AOT.Test.Integration.Setup; using Xunit; +using Xunit.Abstractions; namespace Dapper.AOT.Test.Integration; [Collection(SharedPostgresqlClient.Collection)] -public class DbStringTests +public class DbStringTests : InterceptedCodeExecutionTestsBase { - private PostgresqlFixture _fixture; - - public DbStringTests(PostgresqlFixture fixture) + public DbStringTests(PostgresqlFixture fixture, ITestOutputHelper log) : base(fixture, log) { - _fixture = fixture; - fixture.NpgsqlConnection.Execute(""" + Fixture.NpgsqlConnection.Execute(""" CREATE TABLE IF NOT EXISTS dbStringTable( id integer PRIMARY KEY, name varchar(40) NOT NULL CHECK (name <> '') ); TRUNCATE dbStringTable; - """ - ); + """); } [Fact] - public void ExecuteMulti() + [DapperAot] + public async Task Test() { - + var sourceCode = PrepareSourceCodeFromFile("DbString"); + var executionResults = BuildAndExecuteInterceptedUserCode(sourceCode, methodName: "ExecuteAsync"); + + // TODO DO THE CHECK HERE } } \ No newline at end of file diff --git a/test/Dapper.AOT.Test.Integration/InterceptionExecutables/DbString.cs b/test/Dapper.AOT.Test.Integration/InterceptionExecutables/DbString.cs new file mode 100644 index 00000000..aae57a2f --- /dev/null +++ b/test/Dapper.AOT.Test.Integration/InterceptionExecutables/DbString.cs @@ -0,0 +1,20 @@ +using System.Data; +using System.Data.SqlClient; +using System.Linq; + +namespace InterceptionExecutables +{ + using System; + using System.IO; + using Dapper; + using System.Threading.Tasks; + + public static class Program + { + public static async Task ExecuteAsync(IDbConnection dbConnection) + { + var res = await dbConnection.QueryAsync("SELECT count(*) FROM dbStringTable"); + return res.First(); + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test.Integration/InterceptionExecutables/_Template.cs b/test/Dapper.AOT.Test.Integration/InterceptionExecutables/_Template.cs new file mode 100644 index 00000000..9df40238 --- /dev/null +++ b/test/Dapper.AOT.Test.Integration/InterceptionExecutables/_Template.cs @@ -0,0 +1,18 @@ +namespace InterceptionExecutables +{ + using System; + using System.IO; + using Dapper; + using System.Threading.Tasks; + + // this is just a sample for easy test-writing + public static class Program + { + public static async Task ExecuteAsync(IDbConnection dbConnection) + { + + } + } + + +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test.Integration/Setup/InterceptedCodeExecutionTestsBase.cs b/test/Dapper.AOT.Test.Integration/Setup/InterceptedCodeExecutionTestsBase.cs new file mode 100644 index 00000000..1b5408a3 --- /dev/null +++ b/test/Dapper.AOT.Test.Integration/Setup/InterceptedCodeExecutionTestsBase.cs @@ -0,0 +1,140 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using Dapper.CodeAnalysis; +using Dapper.TestCommon; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Emit; +using Microsoft.CodeAnalysis.Text; +using Xunit; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace Dapper.AOT.Test.Integration.Setup; + +public abstract class InterceptedCodeExecutionTestsBase : GeneratorTestBase +{ + protected readonly PostgresqlFixture Fixture; + + protected InterceptedCodeExecutionTestsBase(PostgresqlFixture fixture, ITestOutputHelper? log) : base(log) + { + Fixture = fixture; + } + + protected static string PrepareSourceCodeFromFile(string inputFileName, string extension = ".cs") + { + var fullPath = Path.Combine("InterceptionExecutables", inputFileName + extension); + if (!File.Exists(fullPath)) + { + throw new FileNotFoundException(fullPath); + } + + using var sr = new StreamReader(fullPath); + return sr.ReadToEnd(); + } + + protected T BuildAndExecuteInterceptedUserCode( + string userSourceCode, + string className = "Program", + string methodName = "ExecuteAsync") + { + var inputCompilation = RoslynTestHelpers.CreateCompilation("Assembly", syntaxTrees: [ + BuildInterceptorSupportedSyntaxTree(filename: "Program.cs", userSourceCode) + ]); + + var diagnosticsOutputStringBuilder = new StringBuilder(); + var (compilation, generatorDriverRunResult, diagnostics, errorCount) = Execute(inputCompilation, diagnosticsOutputStringBuilder, initializer: g => + { + g.Log += message => Log(message); + }); + + var results = Assert.Single(generatorDriverRunResult.Results); + Assert.NotNull(compilation); + Assert.True(errorCount == 0, "User code should not report errors"); + + var assembly = Compile(compilation!); + var type = assembly.GetTypes().Single(t => t.FullName == $"InterceptionExecutables.{className}"); + var mainMethod = type.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static); + Assert.NotNull(mainMethod); + + var result = mainMethod!.Invoke(obj: null, [ Fixture.NpgsqlConnection ]); + Assert.NotNull(result); + if (result is not Task taskResult) + { + throw new XunitException($"expected execution result is '{typeof(Task)}' but got {result!.GetType()}"); + } + + return taskResult.GetAwaiter().GetResult(); + } + + SyntaxTree BuildInterceptorSupportedSyntaxTree(string filename, string text) + { + var options = new CSharpParseOptions(LanguageVersion.Preview) + .WithFeatures(new [] + { + new KeyValuePair("InterceptorsPreviewNamespaces", "$(InterceptorsPreviewNamespaces);ProgramNamespace;Dapper.AOT"), + new KeyValuePair("Features", "InterceptorsPreview"), + new KeyValuePair("LangVersion", "preview"), + }); + + var stringText = SourceText.From(text, Encoding.UTF8); + return SyntaxFactory.ParseSyntaxTree(stringText, options, filename); + } + + static Assembly Compile(Compilation compilation) + { + using var peStream = new MemoryStream(); + using var pdbstream = new MemoryStream(); + + var dbg = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? DebugInformationFormat.Pdb : DebugInformationFormat.PortablePdb; + var emitResult = compilation.Emit(peStream, pdbstream, null, null, null, new EmitOptions(false, dbg)); + if (!emitResult.Success) + { + TryThrowErrors(emitResult.Diagnostics); + } + + peStream.Position = pdbstream.Position = 0; + return Assembly.Load(peStream.ToArray(), pdbstream.ToArray()); + } + + static void TryThrowErrors(IEnumerable items) + { + var errors = new List(); + foreach (var item in items) + { + if (item.Severity == DiagnosticSeverity.Error) + { + errors.Add(item.GetMessage(CultureInfo.InvariantCulture)); + } + } + + if (errors.Count > 0) + { + throw new CompilationException(errors); + } + } + + class CompilationException : Exception + { + public IEnumerable Errors { get; private set; } + + public CompilationException(IEnumerable errors) + : base(string.Join(Environment.NewLine, errors)) + { + this.Errors = errors; + } + + public CompilationException(params string[] errors) + : base(string.Join(Environment.NewLine, errors)) + { + this.Errors = errors; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/GeneratorTestBase.cs b/test/Dapper.AOT.Test/GeneratorTestBase.cs index bf5d86f1..acec3af9 100644 --- a/test/Dapper.AOT.Test/GeneratorTestBase.cs +++ b/test/Dapper.AOT.Test/GeneratorTestBase.cs @@ -31,13 +31,27 @@ protected GeneratorTestBase(ITestOutputHelper? log) protected static string? GetOriginCodeLocation([CallerFilePath] string? path = null) => path; // input from https://github.com/dotnet/roslyn/blob/main/docs/features/source-generators.cookbook.md#unit-testing-of-generators - + protected (Compilation? Compilation, GeneratorDriverRunResult Result, ImmutableArray Diagnostics, int ErrorCount) Execute(string source, StringBuilder? diagnosticsTo = null, [CallerMemberName] string? name = null, string? fileName = null, Action? initializer = null ) where T : class, IIncrementalGenerator, new() + { + // Create the 'input' compilation that the generator will act on + if (string.IsNullOrWhiteSpace(name)) name = "compilation"; + if (string.IsNullOrWhiteSpace(fileName)) fileName = "input.cs"; + var inputCompilation = RoslynTestHelpers.CreateCompilation(source, name!, fileName!); + + return Execute(inputCompilation, diagnosticsTo, initializer); + } + + protected (Compilation? Compilation, GeneratorDriverRunResult Result, ImmutableArray Diagnostics, int ErrorCount) Execute( + Compilation inputCompilation, + StringBuilder? diagnosticsTo = null, + Action? initializer = null + ) where T : class, IIncrementalGenerator, new() { void OutputDiagnostic(Diagnostic d) { @@ -54,18 +68,13 @@ void Output(string message, bool force = false) diagnosticsTo?.AppendLine(message.Replace('\\', '/')); // need to normalize paths } } - // Create the 'input' compilation that the generator will act on - if (string.IsNullOrWhiteSpace(name)) name = "compilation"; - if (string.IsNullOrWhiteSpace(fileName)) fileName = "input.cs"; - Compilation inputCompilation = RoslynTestHelpers.CreateCompilation(source, name!, fileName!); - // directly create an instance of the generator // (Note: in the compiler this is loaded from an assembly, and created via reflection at runtime) T generator = new(); initializer?.Invoke(generator); ShowDiagnostics("Input code", inputCompilation, diagnosticsTo, "CS8795", "CS1701", "CS1702"); - + // Create the driver that will control the generation, passing in our generator GeneratorDriver driver = CSharpGeneratorDriver.Create(new[] { generator.AsSourceGenerator() }, parseOptions: RoslynTestHelpers.ParseOptionsLatestLangVer); @@ -73,7 +82,7 @@ void Output(string message, bool force = false) // (Note: the generator driver itself is immutable, and all calls return an updated version of the driver that you should use for subsequent calls) driver = driver.RunGeneratorsAndUpdateCompilation(inputCompilation, out var outputCompilation, out var diagnostics); var runResult = driver.GetRunResult(); - + foreach (var result in runResult.Results) { if (result.Exception is not null) throw result.Exception; diff --git a/test/Dapper.AOT.Test/TestCommon/RoslynTestHelpers.cs b/test/Dapper.AOT.Test/TestCommon/RoslynTestHelpers.cs index 02172345..abb8d53a 100644 --- a/test/Dapper.AOT.Test/TestCommon/RoslynTestHelpers.cs +++ b/test/Dapper.AOT.Test/TestCommon/RoslynTestHelpers.cs @@ -16,7 +16,7 @@ namespace Dapper.TestCommon; -internal static class RoslynTestHelpers +public static class RoslynTestHelpers { internal static readonly CSharpParseOptions ParseOptionsLatestLangVer = CSharpParseOptions.Default .WithLanguageVersion(LanguageVersion.Latest) @@ -45,8 +45,41 @@ internal static class RoslynTestHelpers }) .WithFeatures(new[] { DapperInterceptorGenerator.FeatureKeys.InterceptorsPreviewNamespacePair }); - public static Compilation CreateCompilation(string source, string name, string fileName) - => CSharpCompilation.Create(name, + public static Compilation CreateCompilation(string assemblyName, SyntaxTree[] syntaxTrees, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary) + => CSharpCompilation.Create(assemblyName, + syntaxTrees: syntaxTrees, + references: new[] { + MetadataReference.CreateFromFile(typeof(Binder).Assembly.Location), +#if !NET48 + MetadataReference.CreateFromFile(Assembly.Load("System.Runtime").Location), + MetadataReference.CreateFromFile(Assembly.Load("System.Data").Location), + MetadataReference.CreateFromFile(Assembly.Load("netstandard").Location), + MetadataReference.CreateFromFile(Assembly.Load("System.Collections").Location), + MetadataReference.CreateFromFile(typeof(System.ComponentModel.DataAnnotations.Schema.ColumnAttribute).Assembly.Location), +#endif + MetadataReference.CreateFromFile(typeof(Console).Assembly.Location), + MetadataReference.CreateFromFile(typeof(DbConnection).Assembly.Location), + MetadataReference.CreateFromFile(typeof(System.Data.SqlClient.SqlConnection).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Microsoft.Data.SqlClient.SqlConnection).Assembly.Location), + MetadataReference.CreateFromFile(typeof(OracleConnection).Assembly.Location), + MetadataReference.CreateFromFile(typeof(ValueTask).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Component).Assembly.Location), + MetadataReference.CreateFromFile(typeof(DapperAotExtensions).Assembly.Location), + MetadataReference.CreateFromFile(typeof(SqlMapper).Assembly.Location), + MetadataReference.CreateFromFile(typeof(ImmutableList).Assembly.Location), + MetadataReference.CreateFromFile(typeof(ImmutableArray).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IAsyncEnumerable).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Span).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IgnoreDataMemberAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(SqlMapper).Assembly.Location), + MetadataReference.CreateFromFile(typeof(DynamicAttribute).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IValidatableObject).Assembly.Location), + }, + options: new CSharpCompilationOptions(outputKind, allowUnsafe: true)); + + public static Compilation CreateCompilation(string source, string assemblyName, string fileName) + => CSharpCompilation.Create(assemblyName, syntaxTrees: new[] { CSharpSyntaxTree.ParseText(source, ParseOptionsLatestLangVer).WithFilePath(fileName) }, references: new[] { MetadataReference.CreateFromFile(typeof(Binder).Assembly.Location),