diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..445a5a3 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,12 @@ +language: csharp +mono: + - latest +solution: Rezoom.sln +install: + - nuget restore Rezoom.sln + - nuget install NUnit.Runners -Version 3.2.0 -OutputDirectory testrunner +script: + - xbuild /p:Configuration=Release Rezoom.sln + - mono ./testrunner/NUnit.ConsoleRunner.3.2.0/tools/nunit3-console.exe ./Rezoom.Test/bin/Release/Rezoom.Test.dll + - mono ./testrunner/NUnit.ConsoleRunner.3.2.0/tools/nunit3-console.exe ./Rezoom.SQL.Test/bin/Release/Rezoom.SQL.Test.dll + diff --git a/README.md b/README.md index 95cfce5..f99c897 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![Build Status](https://travis-ci.org/rspeele/Rezoom.svg?branch=master)](https://travis-ci.org/rspeele/Rezoom) + # Work in progress This is a library implementing a resumption monad for .NET, which @@ -11,26 +13,25 @@ This is similar in effect to Haskell's Haxl. ## What's done so far -* Core library Data.Resumption +* Core library Rezoom - * The fundamental types `Plan` and `IDataEnumerable`. + * The fundamental types `Plan<'a>` and `Errand<'a>`. * Operations like monadic bind, apply, map, etc. on those types. - * `ExecutionContext` for running `Plan`s as `Task`s with - all the caching/deduplication goodness. + * Exception handling for plans. + + * The `plan` computation expression builder that allows writing plans like plain F# code. * Mostly XML-commented, but no "big picture" documentation yet. -* F# library Data.Resumption.Workflows + * Compatibility wrappers to help implement errands from C#. - * F# idiomatic wrappers around the extension methods in Data.Resumption. + * Execution layer to run a `Plan` as a `System.Threading.Task` with batching/caching/deduplication. - * Computation expression builders `datatask` and `dataseq`. These - are by far the best way to write `Plan`s and - `IDataEnumerable`s. + * Note: errand API is still unstable. -* Test library Data.Resumption.Test +* Test library Rezoom.Test * Unit tests against a hypothetical data source that just echos strings. @@ -40,27 +41,47 @@ This is similar in effect to Haskell's Haxl. * Very few tests so far. Contributions in this area would of course be welcomed, even in this early stage of the project. -* Example integration: Data.Resumption.IPGeo +* Example integration: Rezoom.IPGeo * C# library demonstrating how to integrate an existing API - (ip-api.com) with DataRes. + (ip-api.com) with Rezoom. -* Example integration test: Data.Resumption.IPGeo.Test +* Example integration test: Rezoom.IPGeo.Test * Demonstrates how the example integration can be used. -* Example integration: Data.Resumption.EF +* Micro-ORM StaticQL.Mapping + + * Not dependent or inherently integrated with the rest of Rezoom, but designed to work well with it. + + * Automatically materializes the results of a SQL query (`IDataReader`) as CLR objects e.g. `User list`. + + * Works with F# record types and other immutable (constructor-initialized) types. + + * Generates IL for fast object construction. - * C# library demonstrating how to integrate with Entity Framework. + * Uses column naming convention to materialize nested structures, e.g. a list of Groups each with a nested list of Users. - * Uses EntityFramework.Extended to support batching queries. +## What's in progress -## What's still to come +* SQL type provider StaticQL.Provider -* SQL integration with a micro-ORM + * Understands StaticQL, a "lowest common denominator" dialect of SQL (basically SQLite's syntax with a few extensions) - * Should be comparable to Dapper/OrmLite/pals but work with DataRes - for batching. + * Can output it to various backends - SQLite, T-SQL, Postgres, etc. + Goal is not to let you change backends transparently but just to + let you use the same _syntax_ to write queries for any of them. + For example, the built-in functions available to you will depend + on which backend you select. + + * Reads migration scripts from a "model" folder to build a virtual + model of your database. Validates queries against this model. + + * Infers the types of parameters used in your query. + + * Generates statically typed query objects with materialization handled by StaticQL.Mapping. + +## What's on the horizon * Documentation @@ -74,11 +95,6 @@ This is similar in effect to Haskell's Haxl. * Something like a TODO list app that runs in Azure. - * Ideally, would have two or three versions of the backend - implementation: - - * A naive version (no batching) - - * A version with manually coded batching + * Can demo with two execution implementations to show the difference + in round-trips achieved by automatic batching/caching. - * A DataRes version diff --git a/Rezoom.ADO.Test.Internals/Properties/AssemblyInfo.cs b/Rezoom.ADO.Test.Internals/Properties/AssemblyInfo.cs deleted file mode 100644 index 853e56e..0000000 --- a/Rezoom.ADO.Test.Internals/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; - -// General Information about an assembly is controlled through the following -// set of attributes. Change these attribute values to modify the information -// associated with an assembly. -[assembly: AssemblyTitle("Rezoom.ADO.Test.Internals")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("Rezoom.ADO.Test.Internals")] -[assembly: AssemblyCopyright("Copyright © 2016")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// Setting ComVisible to false makes the types in this assembly not visible -// to COM components. If you need to access a type in this assembly from -// COM, set the ComVisible attribute to true on that type. -[assembly: ComVisible(false)] - -// The following GUID is for the ID of the typelib if this project is exposed to COM -[assembly: Guid("3ea0244a-e97c-47b6-96c3-c83315674caa")] - -// Version information for an assembly consists of the following four values: -// -// Major Version -// Minor Version -// Build Number -// Revision -// -// You can specify all the values or you can default the Build and Revision Numbers -// by using the '*' as shown below: -// [assembly: AssemblyVersion("1.0.*")] -[assembly: AssemblyVersion("1.0.0.0")] -[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/Rezoom.ADO.Test.Internals/Rezoom.ADO.Test.Internals.csproj b/Rezoom.ADO.Test.Internals/Rezoom.ADO.Test.Internals.csproj deleted file mode 100644 index b76e8df..0000000 --- a/Rezoom.ADO.Test.Internals/Rezoom.ADO.Test.Internals.csproj +++ /dev/null @@ -1,89 +0,0 @@ - - - - Debug - AnyCPU - {3EA0244A-E97C-47B6-96C3-C83315674CAA} - Library - Properties - Rezoom.ADO.Test.Internals - Rezoom.ADO.Test.Internals - v4.6 - 512 - {3AC096D0-A1C2-E12C-1390-A8335801FDAB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC} - 10.0 - $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) - $(ProgramFiles)\Common Files\microsoft shared\VSTT\$(VisualStudioVersion)\UITestExtensionPackages - False - UnitTest - - - true - full - false - bin\Debug\ - DEBUG;TRACE - prompt - 4 - - - pdbonly - true - bin\Release\ - TRACE - prompt - 4 - - - - - - - - - - - - - - - - - - - - - - - {13bb08a8-8135-4630-beab-1f35d660b52b} - Rezoom.ADO - - - - - - - False - - - False - - - False - - - False - - - - - - - - \ No newline at end of file diff --git a/Rezoom.ADO.Test.Internals/TestMaterialization.cs b/Rezoom.ADO.Test.Internals/TestMaterialization.cs deleted file mode 100644 index f010b7f..0000000 --- a/Rezoom.ADO.Test.Internals/TestMaterialization.cs +++ /dev/null @@ -1,229 +0,0 @@ -using System.Collections.Generic; -using Rezoom.ADO.Materialization; -using Microsoft.VisualStudio.TestTools.UnitTesting; - -namespace Rezoom.ADO.Test.Internals -{ - [TestClass] - public class TestMaterialization - { - public class ConstructorPoint - { - public ConstructorPoint(int x, int y) - { - X = x; - Y = y; - } - - public int X { get; } - public int Y { get; } - } - - [TestMethod] - public void TestSimpleConstructor() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - reader.ProcessColumnMap(ColumnMap.Parse(new[] { "X", "Y" })); - reader.ProcessRow(new object[] { 3, 5 }); - var point = reader.ToEntity(); - Assert.AreEqual(3, point.X); - Assert.AreEqual(5, point.Y); - } - - public class Point { public int X { get; set; } public int Y { get; set; } } - [TestMethod] - public void TestSimplePropertyAssignment() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - reader.ProcessColumnMap(ColumnMap.Parse(new[] { "X", "Y" })); - reader.ProcessRow(new object[] { 3, 5 }); - var point = reader.ToEntity(); - Assert.AreEqual(3, point.X); - Assert.AreEqual(5, point.Y); - } - - public class User - { - public int Id { get; set; } - public string Name { get; set; } - public Group[] Groups { get; set; } - - public class Group - { - public int Id { get; set; } - public string Name { get; set; } - } - } - - [TestMethod] - public void TestArrayNavProperty() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - var columnMap = ColumnMap.Parse(new[] { "Id", "Name", "Groups$Id", "Name" }); - reader.ProcessColumnMap(columnMap); - reader.ProcessRow(new object[] { 1, "bob", 2, "developers" }); - reader.ProcessRow(new object[] { 1, "bob", 3, "testers" }); - var user = reader.ToEntity(); - Assert.AreEqual(1, user.Id); - Assert.AreEqual("bob", user.Name); - Assert.AreEqual(2, user.Groups.Length); - Assert.AreEqual(2, user.Groups[0].Id); - Assert.AreEqual("developers", user.Groups[0].Name); - Assert.AreEqual(3, user.Groups[1].Id); - Assert.AreEqual("testers", user.Groups[1].Name); - } - - [TestMethod] - public void TestManyArrayNavProperty() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - var columnMap = ColumnMap.Parse(new[] { "Id", "Name", "Groups$Id", "Name" }); - reader.ProcessColumnMap(columnMap); - reader.ProcessRow(new object[] { 1, "bob", 2, "developers" }); - reader.ProcessRow(new object[] { 1, "bob", 3, "testers" }); - reader.ProcessRow(new object[] { 2, "jim", 2, "developers" }); - reader.ProcessRow(new object[] { 2, "jim", 4, "slackers" }); - var users = reader.ToEntity(); - Assert.AreEqual(2, users.Length); - - Assert.AreEqual(1, users[0].Id); - Assert.AreEqual("bob", users[0].Name); - Assert.AreEqual(2, users[0].Groups.Length); - Assert.AreEqual(2, users[0].Groups[0].Id); - Assert.AreEqual("developers", users[0].Groups[0].Name); - Assert.AreEqual(3, users[0].Groups[1].Id); - Assert.AreEqual("testers", users[0].Groups[1].Name); - - Assert.AreEqual(2, users[1].Id); - Assert.AreEqual("jim", users[1].Name); - Assert.AreEqual(2, users[1].Groups.Length); - Assert.AreEqual(2, users[1].Groups[0].Id); - Assert.AreEqual("developers", users[1].Groups[0].Name); - Assert.AreEqual(4, users[1].Groups[1].Id); - Assert.AreEqual("slackers", users[1].Groups[1].Name); - } - - public class NestUser - { - public int Id { get; set; } - public string Name { get; set; } - public Group[] Groups { get; set; } - - public class Group - { - public int Id { get; set; } - public string Name { get; set; } - public Tag[] Tags { get; set; } - } - public class Tag - { - public int Id { get; set; } - public string Name { get; set; } - } - } - - [TestMethod] - public void TestNestedArrayNavProperty() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - var columnMap = ColumnMap.Parse(new[] { "Id", "Name", "Groups$Id", "Name", "Groups$Tags$Id", "Name" }); - reader.ProcessColumnMap(columnMap); - reader.ProcessRow(new object[] { 1, "bob", 2, "developers", 4, "t1" }); - reader.ProcessRow(new object[] { 1, "bob", 2, "developers", 5, "t2" }); - reader.ProcessRow(new object[] { 1, "bob", 2, "developers", 6, "t3" }); - reader.ProcessRow(new object[] { 1, "bob", 3, "testers", 4, "t1" }); - reader.ProcessRow(new object[] { 1, "bob", 3, "testers", 5, "t2" }); - reader.ProcessRow(new object[] { 1, "bob", 3, "testers", 7, "t4" }); - var user = reader.ToEntity(); - Assert.AreEqual(1, user.Id); - Assert.AreEqual("bob", user.Name); - Assert.AreEqual(2, user.Groups.Length); - Assert.AreEqual(2, user.Groups[0].Id); - Assert.AreEqual("developers", user.Groups[0].Name); - Assert.AreEqual(3, user.Groups[0].Tags.Length); - Assert.AreEqual("t2", user.Groups[0].Tags[1].Name); - Assert.AreEqual(3, user.Groups[1].Id); - Assert.AreEqual("testers", user.Groups[1].Name); - Assert.AreEqual(3, user.Groups[1].Tags.Length); - Assert.AreEqual("t4", user.Groups[1].Tags[2].Name); - } - - public class Folder - { - public int Id { get; set; } - public string Name { get; set; } - public List Children { get; set; } - } - - [TestMethod] - public void TestRecursiveArrayNavProperties() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - var columnMap = ColumnMap.Parse(new[] { "Id", "Name", "Children$Id", "Name", "Children$Children$Id", "Name" }); - reader.ProcessColumnMap(columnMap); - reader.ProcessRow(new object[] { 1, "f1", 2, "f1.1", 3, "f1.1.1" }); - reader.ProcessRow(new object[] { 1, "f1", 4, "f1.2", 5, "f1.2.1" }); - var folder = reader.ToEntity(); - Assert.AreEqual(1, folder.Id); - Assert.AreEqual("f1", folder.Name); - - Assert.AreEqual(2, folder.Children[0].Id); - Assert.AreEqual("f1.1", folder.Children[0].Name); - Assert.AreEqual(3, folder.Children[0].Children[0].Id); - Assert.AreEqual("f1.1.1", folder.Children[0].Children[0].Name); - - Assert.AreEqual(4, folder.Children[1].Id); - Assert.AreEqual("f1.2", folder.Children[1].Name); - Assert.AreEqual(5, folder.Children[1].Children[0].Id); - Assert.AreEqual("f1.2.1", folder.Children[1].Children[0].Name); - } - - public class Zoo - { - public int Id { get; set; } - public string Name { get; set; } - } - public class Animal - { - public Zoo Zoo { get; set; } - public int Id { get; set; } - public string Name { get; set; } - } - - [TestMethod] - public void TestSingleNavProperty() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - var columnMap = ColumnMap.Parse(new[] { "Id", "Name", "Zoo$Id", "Name" }); - reader.ProcessColumnMap(columnMap); - reader.ProcessRow(new object[] { 1, "zebra", 2, "NC zoo" }); - var animal = reader.ToEntity(); - Assert.AreEqual(1, animal.Id); - Assert.AreEqual("zebra", animal.Name); - Assert.IsNotNull(animal.Zoo); - Assert.AreEqual(2, animal.Zoo.Id); - Assert.AreEqual("NC zoo", animal.Zoo.Name); - } - - [TestMethod] - public void TestSingleNullNavProperty() - { - var template = RowReaderTemplate.Template; - var reader = template.CreateReader(); - var columnMap = ColumnMap.Parse(new[] { "Id", "Name", "Zoo$Id", "Name" }); - reader.ProcessColumnMap(columnMap); - reader.ProcessRow(new object[] { 1, "zebra", null, null }); - var animal = reader.ToEntity(); - Assert.AreEqual(1, animal.Id); - Assert.AreEqual("zebra", animal.Name); - Assert.IsNull(animal.Zoo); - } - } -} diff --git a/Rezoom.ADO.Test/Environment.fs b/Rezoom.ADO.Test/Environment.fs deleted file mode 100644 index bfeea80..0000000 --- a/Rezoom.ADO.Test/Environment.fs +++ /dev/null @@ -1,129 +0,0 @@ -[] -module Rezoom.ADO.Test.Environment -open Rezoom -open Rezoom.ADO -open Rezoom.Execution -open System -open System.Collections -open System.Collections.Generic -open System.Data.SQLite -open FSharp.Reflection - -let private initializeSchema (conn : SQLiteConnection) = - let cmd = conn.CreateCommand() - cmd.CommandText <- @" - create table Users - ( Id int primary key - , Name text - ) - ; - create table Groups - ( Id int primary key - , Name text - ) - ; - create table UserGroupMaps - ( UserId int - , GroupId int - , primary key(UserId, GroupId) - , foreign key(UserId) references Users(Id) - , foreign key(GroupId) references Groups(Id) - ) - ; - " - cmd.Connection <- conn - cmd.ExecuteNonQuery() - -let private initializeData (conn : SQLiteConnection) = - let cmd = conn.CreateCommand() - cmd.CommandText <- @" - insert into Users(Id, Name) - values - ( 1, ""Jim"" ) - , ( 2, ""Mary"" ) - , ( 3, ""Ellen"" ) - , ( 4, ""Rick"" ) - ; - insert into Groups(Id, Name) - values - ( 1, ""Admins"" ) - , ( 2, ""Content Creators"" ) - , ( 3, ""Content Reviewers"" ) - , ( 4, ""Content Organizers"" ) - ; - insert into UserGroupMaps(UserId, GroupId) - values - ( 1, 1 ) - , ( 1, 2 ) - , ( 2, 3 ) - , ( 2, 4 ) - , ( 3, 2 ) - , ( 4, 3 ) - ; - " - cmd.Connection <- conn - cmd.ExecuteNonQuery() - -let private initializeDb (conn : SQLiteConnection) = - ignore <| initializeSchema conn - ignore <| initializeData conn - -type TestDbServiceFactory() = - inherit DbServiceFactory() - override __.CreateConnection() = - let filename = (Guid.NewGuid().ToString("n").Substring(0, 4) + ".db") - SQLiteConnection.CreateFile(filename) - let conn = new SQLiteConnection("Data Source=" + filename) - conn.Open() - initializeDb conn - upcast conn - -let query query args = - let args = Array.ofList args - let command = - { new FormattableString() with - member __.Format = query - member __.GetArguments() = args - member __.ArgumentCount = args.Length - member __.GetArgument(index) = args.[index] - member __.ToString(provider) = String.Format(provider, query, args) - } |> Command.Query - plan { - let! rs = CommandErrand(command).ToPlan() - return rs.[0] - } - -type 'a ExpectedResult = - | Exception of (exn -> bool) - | Value of 'a - -type 'a TestTask = - { - Task : 'a Plan - ExpectedResult : 'a ExpectedResult - } - -let test (task : 'a TestTask) = - use context = - new ExecutionContext(new TestDbServiceFactory(), new DebugExecutionLog()) - let answer = - try - context.Execute(task.Task).Result |> Some - with - | ex -> - match task.ExpectedResult with - | Exception predicate -> - if predicate ex then None - else reraise() - | _ -> reraise() - match answer with - | None -> () - | Some v -> - match task.ExpectedResult with - | Exception predicate -> - failwithf "Got value %A when exception was expected" v - | Value expect -> - if expect <> v then - failwithf "Got %A; expected %A" v expect - else () - \ No newline at end of file diff --git a/Rezoom.ADO.Test/TestQueries.fs b/Rezoom.ADO.Test/TestQueries.fs deleted file mode 100644 index 24bedae..0000000 --- a/Rezoom.ADO.Test/TestQueries.fs +++ /dev/null @@ -1,32 +0,0 @@ -namespace Rezoom.ADO.Test -open Rezoom -open Microsoft.VisualStudio.TestTools.UnitTesting - -[] -type TestQueries() = - [] - member __.TestQuerySingleUser() = - { - Task = - plan { - let! user = query "select Id, Name from Users where Id = {0}" [1] - return string <| user.Rows.[0].[1] - } - ExpectedResult = Value "Jim" - } |> test - - [] - member __.TestConcurrentQueries() = - { - Task = - plan { - let! users = query "select Id from Users" [] - let names = new ResizeArray() - for user in batch users.Rows do - let id = user.[0] - let! name = query "select Name from Users where Id = {0}" [id] - names.Add(name.Rows.[0].[0] |> string) - return names |> Set.ofSeq - } - ExpectedResult = Value (["Ellen"; "Jim"; "Mary"; "Rick"] |> Set.ofSeq) - } |> test \ No newline at end of file diff --git a/Rezoom.ADO.Test/packages.config b/Rezoom.ADO.Test/packages.config deleted file mode 100644 index 1c00e18..0000000 --- a/Rezoom.ADO.Test/packages.config +++ /dev/null @@ -1,4 +0,0 @@ - - - - \ No newline at end of file diff --git a/Rezoom.ADO/Command.cs b/Rezoom.ADO/Command.cs deleted file mode 100644 index bccd608..0000000 --- a/Rezoom.ADO/Command.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System; - -namespace Rezoom.ADO -{ - public class Command - { - public Command(FormattableString text, bool mutation, bool idempotent) - { - Text = text; - Mutation = mutation; - Idempotent = idempotent; - } - public bool Mutation { get; } - public bool Idempotent { get; } - public FormattableString Text { get; } - - public static Command Query(FormattableString text, bool idempotent = true) - => new Command(text, mutation: false, idempotent: idempotent); - public static Command Mutate(FormattableString text, bool idempotent = false) - => new Command(text, mutation: true, idempotent: idempotent); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/CommandBatch.cs b/Rezoom.ADO/CommandBatch.cs deleted file mode 100644 index fa5c626..0000000 --- a/Rezoom.ADO/CommandBatch.cs +++ /dev/null @@ -1,128 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Data.Common; -using System.IO; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Rezoom.ADO -{ - /// - /// Builds up a batch of SQL strings to run in a single IDbCommand. - /// - internal class CommandBatch : IDisposable - { - /// - /// We put this after each command to terminate it. - /// The extra characters are intended to guard against accidental issues like - /// unclosed string literals or block comments spilling into the next command. - /// - private const string CommandTerminator = ";--'*/;"; - - private readonly IDbTypeRecognizer _typeRecognizer; - private readonly DbCommand _command; - private readonly List _commands = new List(); - - private Task>> _executing; - - public CommandBatch(DbConnection connection, IDbTypeRecognizer typeRecognizer) - { - _typeRecognizer = typeRecognizer; - _command = connection.CreateCommand(); - _command.Connection = connection; - } - - public Func>> Prepare(Command command) - { - if (_executing != null) - throw new InvalidOperationException("Command is already executing"); - var parameterValues = command.Text.GetArguments(); - var parameterNames = new object[parameterValues.Length]; - for (var i = 0; i < parameterValues.Length; i++) - { - var dbParamName = $"@DRBATCHPARAM_{_command.Parameters.Count}"; - var dbParam = _command.CreateParameter(); - dbParam.ParameterName = dbParamName; - dbParam.Value = parameterValues[i]; - dbParam.DbType = _typeRecognizer.GetDbType(parameterValues[i]); - _command.Parameters.Add(dbParam); - parameterNames[i] = dbParamName; - } - var sqlReferencingParams = string.Format(command.Text.Format, parameterNames); - - var commandIndex = _commands.Count; - _commands.Add(sqlReferencingParams); - return () => GetResultSet(commandIndex); - } - - private async Task>> GetAllResultSets() - { - var separators = new List(); - var gluedText = new StringBuilder(); - foreach (var command in _commands) - { - if (gluedText.Length > 0) - { - var sep = $"DRSEP_{Guid.NewGuid():N}"; - gluedText.AppendLine(CommandTerminator); - gluedText.AppendLine($"SELECT NULL as {sep};"); - separators.Add(sep); - } - gluedText.AppendLine(command); - } - _command.CommandText = gluedText.ToString(); - using (var reader = await _command.ExecuteReaderAsync().ConfigureAwait(false)) - { - var sepi = 0; - var allResults = new List>(); - var currentCommandResults = new List(); - do - { - var fieldNames = Enumerable.Range(0, reader.FieldCount) - .Select(reader.GetName) - .ToArray(); - // If we hit our separator, that's the end of a command's result sets. - if (separators.Count > sepi && fieldNames.Length == 1 && fieldNames[0] == separators[sepi]) - { - allResults.Add(currentCommandResults); - currentCommandResults = new List(); - sepi++; - } - else - { - var rows = new List>(); - while (await reader.ReadAsync().ConfigureAwait(false)) - { - var row = new object[reader.FieldCount]; - reader.GetValues(row); - rows.Add(row); - } - currentCommandResults.Add(new CommandResponse(fieldNames, rows)); - } - } while (await reader.NextResultAsync().ConfigureAwait(false)); - if (sepi < allResults.Count) - { - throw new InvalidDataException($"Unexpected result sets missing separator"); - } - // The last result set doesn't have a trailing separator, so we need to add it on here. - allResults.Add(currentCommandResults); - return allResults; - } - } - - private async Task> GetResultSet(int index) - { - if (_executing != null) - { - var allResults = await _executing.ConfigureAwait(false); - return allResults[index]; - } - _executing = GetAllResultSets(); - var results = await _executing.ConfigureAwait(false); - return results[index]; - } - - public void Dispose() => _command.Dispose(); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/CommandRequest.cs b/Rezoom.ADO/CommandRequest.cs deleted file mode 100644 index 75d8d4b..0000000 --- a/Rezoom.ADO/CommandRequest.cs +++ /dev/null @@ -1,25 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading.Tasks; - -namespace Rezoom.ADO -{ - public class CommandErrand : CS.AsynchronousErrand> - { - private readonly Command _command; - - public CommandErrand(Command command) - { - _command = command; - } - - public override object Identity => FormattableString.Invariant(_command.Text); - public override object DataSource => typeof(CommandBatch); - public override bool Mutation => _command.Mutation; - public override bool Idempotent => _command.Idempotent; - public override object SequenceGroup => typeof(CommandBatch); - - public override Func>> Prepare(ServiceContext context) - => context.GetService().Prepare(_command); - } -} diff --git a/Rezoom.ADO/CommandResponse.cs b/Rezoom.ADO/CommandResponse.cs deleted file mode 100644 index 0cf20c2..0000000 --- a/Rezoom.ADO/CommandResponse.cs +++ /dev/null @@ -1,20 +0,0 @@ -using System.Collections.Generic; - -namespace Rezoom.ADO -{ - public class CommandResponse - { - public CommandResponse - (IReadOnlyList columnNames, IReadOnlyList> rows) - { - ColumnNames = columnNames; - Rows = rows; - } - - public IReadOnlyList ColumnNames { get; } - public IReadOnlyList> Rows { get; } - - public static readonly CommandResponse Empty - = new CommandResponse(new string[0], new IReadOnlyList[0]); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/DbServiceFactory.cs b/Rezoom.ADO/DbServiceFactory.cs deleted file mode 100644 index 440f4bd..0000000 --- a/Rezoom.ADO/DbServiceFactory.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System.Data.Common; -using Rezoom; - -namespace Rezoom.ADO -{ - public abstract class DbServiceFactory : ServiceFactory - { - protected abstract DbConnection CreateConnection(); - protected virtual IDbTypeRecognizer CreateDbTypeRecognizer() => new DbTypeRecognizer(); - - public override LivingService CreateService(ServiceContext context) - { - if (typeof(T) == typeof(DbConnection)) - { - var conn = CreateConnection(); - return new LivingService(ServiceLifetime.ExecutionLocal, (T)(object)conn); - } - if (typeof(T) == typeof(IDbTypeRecognizer)) - { - var recognizer = CreateDbTypeRecognizer(); - return new LivingService(ServiceLifetime.ExecutionLocal, (T)recognizer); - } - if (typeof(T) == typeof(CommandBatch)) - { - var dbConnection = context.GetService(); - var dbTypeRecognizer = context.GetService(); - var cmdContext = new CommandBatch(dbConnection, dbTypeRecognizer); - return new LivingService(ServiceLifetime.StepLocal, (T)(object)cmdContext); - } - return null; - } - } -} diff --git a/Rezoom.ADO/DbTypeRecognizer.cs b/Rezoom.ADO/DbTypeRecognizer.cs deleted file mode 100644 index bf37e93..0000000 --- a/Rezoom.ADO/DbTypeRecognizer.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Data; - -namespace Rezoom.ADO -{ - public class DbTypeRecognizer : IDbTypeRecognizer - { - public virtual DbType StringType => DbType.AnsiString; - public virtual DbType DateTimeType => DbType.DateTime2; - private static readonly Dictionary Primitives = new Dictionary - { - { typeof(bool), DbType.Boolean }, - - { typeof(long), DbType.Int64 }, - { typeof(int), DbType.Int32 }, - { typeof(short), DbType.Int16 }, - { typeof(sbyte), DbType.SByte }, - - { typeof(ulong), DbType.UInt64 }, - { typeof(uint), DbType.UInt32 }, - { typeof(ushort), DbType.UInt16 }, - { typeof(byte), DbType.Byte }, - - { typeof(double), DbType.Double }, - { typeof(float), DbType.Single }, - { typeof(decimal), DbType.Decimal }, - - { typeof(Guid), DbType.Guid }, - }; - public DbType GetDbType(object value) - { - if (value == null) return DbType.Object; - if (value is string) return StringType; - if (value is DateTime) return DateTimeType; - var type = value.GetType(); - DbType result; - if (Primitives.TryGetValue(type, out result)) return result; - throw new NotSupportedException($"The type {type} is not supported for a database parameter"); - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/IDbTypeRecognizer.cs b/Rezoom.ADO/IDbTypeRecognizer.cs deleted file mode 100644 index 0bd7e1e..0000000 --- a/Rezoom.ADO/IDbTypeRecognizer.cs +++ /dev/null @@ -1,9 +0,0 @@ -using System.Data; - -namespace Rezoom.ADO -{ - public interface IDbTypeRecognizer - { - DbType GetDbType(object value); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/ColumnMap.cs b/Rezoom.ADO/Materialization/ColumnMap.cs deleted file mode 100644 index 130e0a6..0000000 --- a/Rezoom.ADO/Materialization/ColumnMap.cs +++ /dev/null @@ -1,63 +0,0 @@ -using System; -using System.Collections.Generic; - -namespace Rezoom.ADO.Materialization -{ - public class ColumnMap - { - private readonly Dictionary _columnIndices = - new Dictionary(StringComparer.OrdinalIgnoreCase); - private readonly Dictionary _subMaps = - new Dictionary(StringComparer.OrdinalIgnoreCase); - - private ColumnMap GetOrCreateSubMap(string name) - { - ColumnMap val; - if (_subMaps.TryGetValue(name, out val)) return val; - val = new ColumnMap(); - _subMaps[name] = val; - return val; - } - - private void SetColumnIndex(string name, int index) => _columnIndices[name] = index; - - private void Load(IReadOnlyList columnNames) - { - var root = this; - var current = this; - for (var i = 0; i < columnNames.Count; i++) - { - var path = columnNames[i].Split('.', '$'); - if (path.Length == 1) current.SetColumnIndex(path[0], i); - else - { - current = root; - for (var j = 0; j < path.Length - 1; j++) - { - current = current.GetOrCreateSubMap(path[j]); - } - current.SetColumnIndex(path[path.Length - 1], i); - } - } - } - - public int ColumnIndex(string propertyName) - { - int idx; - return _columnIndices.TryGetValue(propertyName, out idx) ? idx : -1; - } - - public ColumnMap SubMap(string propertyName) - { - ColumnMap sub; - return _subMaps.TryGetValue(propertyName, out sub) ? sub : null; - } - - internal static ColumnMap Parse(IReadOnlyList columnNames) - { - var map = new ColumnMap(); - map.Load(columnNames); - return map; - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/GenBuilderProperties/ManyNavGenBuilderProperty.cs b/Rezoom.ADO/Materialization/GenBuilderProperties/ManyNavGenBuilderProperty.cs deleted file mode 100644 index 435635b..0000000 --- a/Rezoom.ADO/Materialization/GenBuilderProperties/ManyNavGenBuilderProperty.cs +++ /dev/null @@ -1,124 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Reflection; -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization.GenBuilderProperties -{ - internal class ManyNavGenBuilderProperty : NavGenBuilderProperty - { - private readonly Type _dictionaryType; - private readonly Type _collectionType; - - private FieldBuilder _dict; - - public ManyNavGenBuilderProperty(string fieldName, Type entityType, Type collectionType) - : base(fieldName, entityType) - { - _collectionType = collectionType; - _dictionaryType = typeof(Dictionary<,>).MakeGenericType(KeyColumn.Type, EntityReaderType); - } - - public override void InstallFields(TypeBuilder type, ILGenerator constructor) - { - base.InstallFields(type, constructor); - _dict = type.DefineField("_dr_dict_" + FieldName, _dictionaryType, FieldAttributes.Private); - var cons = _dictionaryType.GetConstructor(Type.EmptyTypes); - if (cons == null) throw new Exception("Unexpected lack of default constructor on dictionary type"); - - constructor.Emit(OpCodes.Dup); // dup "this" - constructor.Emit(OpCodes.Newobj, cons); // new dictionary - constructor.Emit(OpCodes.Stfld, _dict); // assign field - } - - public override void InstallProcessingLogic(GenProcessRowContext cxt) - { - var il = cxt.IL; - var skip = il.DefineLabel(); - var subProcess = il.DefineLabel(); - var keyRaw = il.DeclareLocal(typeof(object)); - var key = il.DeclareLocal(KeyColumn.Type); - var entReader = il.DeclareLocal(EntityReaderType); - il.Emit(OpCodes.Dup); // this, this - il.Emit(OpCodes.Ldfld, SubColumnMap); // this, cmap - il.Emit(OpCodes.Brfalse, skip); // skip if we have no column map (for recursive case) - - // get the key value from the row - il.Emit(OpCodes.Ldloc, cxt.Row); // this, row - il.Emit(OpCodes.Ldloc, cxt.This); // this, row, this - il.Emit(OpCodes.Ldfld, KeyColumnIndex); // this, row, index - il.Emit(OpCodes.Ldelem_Ref); // this, rval - // store it in a local - il.Emit(OpCodes.Dup); // this, rval, rval - il.Emit(OpCodes.Stloc, keyRaw); // this, rval - // if our id is null, bail - il.Emit(OpCodes.Brfalse, skip); - { - // stack clean (this at top) - il.Emit(OpCodes.Dup); // this, this - il.Emit(OpCodes.Ldfld, _dict); // this, dict - il.Emit(OpCodes.Ldloc, keyRaw); // this, dict, rval - il.Emit(OpCodes.Call, PrimitiveConverter.ToType(KeyColumn.Type)); // this, dict, key - il.Emit(OpCodes.Dup); // this, dict, key, key - il.Emit(OpCodes.Stloc, key); // this, dict, key - il.Emit(OpCodes.Ldloca, entReader); // this, dict, key, &reader - il.Emit(OpCodes.Call, _dictionaryType.GetMethod(nameof(Dictionary.TryGetValue))); - // this, gotv - // if we've got one, skip to sub-processing the row - il.Emit(OpCodes.Brtrue, subProcess); - { - // otherwise, make one - // stack clean (this at top) - il.Emit(OpCodes.Ldsfld, EntityReaderStaticTemplateType.GetField - (nameof(RowReaderTemplate.Template))); - // this, template - il.Emit(OpCodes.Callvirt, EntityReaderTemplateType.GetMethod - (nameof(IRowReaderTemplate.CreateReader))); - // this, newreader - il.Emit(OpCodes.Dup); - // this, newreader, newreader - il.Emit(OpCodes.Stloc, entReader); - // this, newreader - - // process column map - il.Emit(OpCodes.Ldloc, cxt.This); // this, newreader, this - il.Emit(OpCodes.Ldfld, SubColumnMap); // this, newreader, columnmap - il.Emit(OpCodes.Callvirt, EntityReaderType.GetMethod - (nameof(IRowReader.ProcessColumnMap))); - - // save in dictionary - // stack clean (this at top) - il.Emit(OpCodes.Dup); // this, this - il.Emit(OpCodes.Ldfld, _dict); // this, dict - il.Emit(OpCodes.Ldloc, key); // this, dict, key - il.Emit(OpCodes.Ldloc, entReader); // this, dict, key, reader - il.Emit(OpCodes.Call, _dictionaryType.GetMethod - (nameof(Dictionary.Add))); - // stack clean (this at top) - } - il.MarkLabel(subProcess); - // have the entity reader process the row - il.Emit(OpCodes.Ldloc, entReader); // this, reader - il.Emit(OpCodes.Ldloc, cxt.Row); // this, reader, row - il.Emit(OpCodes.Callvirt, EntityReaderType.GetMethod(nameof(IRowReader.ProcessRow))); - // this - } - il.MarkLabel(skip); - } - - public override void InstallPushValue(GenInstanceMethodContext cxt) - { - var il = cxt.IL; - - il.Emit(OpCodes.Ldloc, cxt.This); // this - il.Emit(OpCodes.Ldfld, _dict); // dict - var manyNavConverter = typeof(ManyNavConverter<,>) - .MakeGenericType(KeyColumn.Type, EntityType); - var conversion = - (MethodInfo) - manyNavConverter.GetMethod(nameof(ManyNavConverter.ToType)) - .Invoke(null, new object[] { _collectionType }); - il.Emit(OpCodes.Call, conversion); - } - } -} diff --git a/Rezoom.ADO/Materialization/GenBuilderProperties/NavGenBuilderProperty.cs b/Rezoom.ADO/Materialization/GenBuilderProperties/NavGenBuilderProperty.cs deleted file mode 100644 index 56a90dd..0000000 --- a/Rezoom.ADO/Materialization/GenBuilderProperties/NavGenBuilderProperty.cs +++ /dev/null @@ -1,89 +0,0 @@ -using System; -using System.Reflection; -using System.Reflection.Emit; -using Rezoom.ADO.Materialization.TypeInfo; - -namespace Rezoom.ADO.Materialization.GenBuilderProperties -{ - internal abstract class NavGenBuilderProperty : IGenBuilderProperty - { - protected readonly Type EntityReaderStaticTemplateType; - protected readonly Type EntityReaderTemplateType; - protected readonly Type EntityReaderType; - protected readonly Type EntityType; - protected readonly string FieldName; - protected readonly TypeColumn KeyColumn; - - protected NavGenBuilderProperty(string fieldName, Type entityType) - { - FieldName = fieldName; - EntityType = entityType; - EntityReaderType = typeof(IRowReader<>).MakeGenericType(entityType); - EntityReaderStaticTemplateType = typeof(RowReaderTemplate<>).MakeGenericType(entityType); - EntityReaderTemplateType = typeof(IRowReaderTemplate<>).MakeGenericType(entityType); - var key = typeof(TypeProfile<>).MakeGenericType(entityType) - .GetField(nameof(TypeProfile.Profile)) - .GetValue(null) as TypeProfile; - if (key == null) throw new NullReferenceException("Unexpected null type profile"); - KeyColumn = key.KeyColumn; - if (KeyColumn == null) throw new InvalidOperationException($"Type {entityType} has no key column"); - } - - protected FieldBuilder SubColumnMap; - protected FieldBuilder KeyColumnIndex; - - public bool Singular => false; - - public virtual void InstallFields(TypeBuilder type, ILGenerator constructor) - { - SubColumnMap = type.DefineField("_dr_cmap_" + FieldName, typeof(ColumnMap), FieldAttributes.Private); - KeyColumnIndex = type.DefineField("_dr_kcol_" + FieldName, typeof(int), FieldAttributes.Private); - } - - public virtual void InstallProcessingLogic(GenProcessColumnMapContext cxt) - { - var il = cxt.IL; - var skip = il.DefineLabel(); - var done = il.DefineLabel(); - il.Emit(OpCodes.Dup); // this, this - il.Emit(OpCodes.Dup); // this, this, this - // Get submap for this nav property - il.Emit(OpCodes.Dup); // this, this, this, this - il.Emit(OpCodes.Ldloc, cxt.ColumnMap); // this, this, this, this, colmap - il.Emit(OpCodes.Dup); - il.Emit(OpCodes.Brfalse_S, skip); - { - if (FieldName != null) - { - il.Emit(OpCodes.Ldstr, FieldName); // this, this, this, this colmap, fieldname - il.Emit(OpCodes.Call, typeof(ColumnMap).GetMethod(nameof(ColumnMap.SubMap))); - } - // this, this, this, this, submap - il.Emit(OpCodes.Dup); - il.Emit(OpCodes.Brfalse_S, skip); - // Set column map field to submap - il.Emit(OpCodes.Stfld, SubColumnMap); // this, this, this - il.Emit(OpCodes.Ldfld, SubColumnMap); // this, this, submap - // Get key column index from submap - il.Emit(OpCodes.Ldstr, KeyColumn.Name); // this, this, submap, keyfield - il.Emit(OpCodes.Call, typeof(ColumnMap).GetMethod(nameof(ColumnMap.ColumnIndex))); - // this, this, keyindex - il.Emit(OpCodes.Stfld, KeyColumnIndex); - il.Emit(OpCodes.Br_S, done); - } - il.MarkLabel(skip); - { - // this, this, this, this, submap - il.Emit(OpCodes.Pop); - il.Emit(OpCodes.Pop); - il.Emit(OpCodes.Pop); - il.Emit(OpCodes.Pop); - } - il.MarkLabel(done); - } - - public abstract void InstallProcessingLogic(GenProcessRowContext cxt); - - public abstract void InstallPushValue(GenInstanceMethodContext cxt); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/GenBuilderProperties/PrimitiveGenBuilderProperty.cs b/Rezoom.ADO/Materialization/GenBuilderProperties/PrimitiveGenBuilderProperty.cs deleted file mode 100644 index 3db93aa..0000000 --- a/Rezoom.ADO/Materialization/GenBuilderProperties/PrimitiveGenBuilderProperty.cs +++ /dev/null @@ -1,103 +0,0 @@ -using System; -using System.Reflection; -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization.GenBuilderProperties -{ - /// - /// Implements IGenBuilderProperty for a primitive field -- for example, one of type int, string, or Guid. - /// - internal class PrimitiveGenBuilderProperty : IGenBuilderProperty - { - private readonly string _fieldName; - private readonly Type _fieldType; - private readonly bool _nonNull; - - /// - /// The field that stores the value for this property. - /// - private FieldBuilder _value; - /// - /// The boolean field that stores whether or not we've loaded the value for this property yet. - /// - private FieldBuilder _seen; - /// - /// The field that stores the column index for this property. - /// - private FieldBuilder _columnIndex; - - public PrimitiveGenBuilderProperty(string fieldName, Type fieldType) - { - _fieldName = fieldName; - _fieldType = fieldType; - _nonNull = fieldType.IsValueType && - (!fieldType.IsGenericType || fieldType.GetGenericTypeDefinition() != typeof(Nullable<>)); - } - - public bool Singular => true; - - public void InstallFields(TypeBuilder type, ILGenerator constructor) - { - _value = type.DefineField("_dr_" + _fieldName, _fieldType, FieldAttributes.Private); - _seen = type.DefineField("_dr_seen_" + _fieldName, typeof(bool), FieldAttributes.Private); - _columnIndex = type.DefineField("_dr_col_" + _fieldName, typeof(int), FieldAttributes.Private); - } - - public void InstallProcessingLogic(GenProcessColumnMapContext cxt) - { - var il = cxt.IL; - il.Emit(OpCodes.Dup); // dup this - il.Emit(OpCodes.Ldloc, cxt.ColumnMap); - il.Emit(OpCodes.Ldstr, _fieldName); - il.Emit(OpCodes.Callvirt, typeof(ColumnMap).GetMethod(nameof(ColumnMap.ColumnIndex))); - il.Emit(OpCodes.Stfld, _columnIndex); - } - - public void InstallProcessingLogic(GenProcessRowContext cxt) - { - var il = cxt.IL; - var skipOnNull = _nonNull ? cxt.SkipSingularProperties : il.DefineLabel(); - - // First check if we can skip singular properties - il.Emit(OpCodes.Dup); - il.Emit(OpCodes.Ldfld, _seen); - il.Emit(OpCodes.Brtrue, cxt.SkipSingularProperties); - { - // If not, attempt to load the value. - // Load the row array - il.Emit(OpCodes.Ldloc, cxt.Row); - // Get the column index - il.Emit(OpCodes.Ldloc, cxt.This); - il.Emit(OpCodes.Ldfld, _columnIndex); - // Load the value from the array - il.Emit(OpCodes.Ldelem_Ref); - var obj = il.DeclareLocal(typeof(object)); - il.Emit(OpCodes.Dup); - il.Emit(OpCodes.Stloc, obj); - // If the value is null, we can skip this property - // ... in fact, if the value is null but we're non-nullable, we can skip all singulars - il.Emit(OpCodes.Brfalse, skipOnNull); - { - // Convert and save the object - il.Emit(OpCodes.Dup); // dup "this" instance - il.Emit(OpCodes.Ldloc, obj); - il.Emit(OpCodes.Call, PrimitiveConverter.ToType(_fieldType)); - il.Emit(OpCodes.Stfld, _value); - // Set seen to true - il.Emit(OpCodes.Dup); // dup "this" instance - il.Emit(OpCodes.Ldc_I4_1); // 1 (true) - il.Emit(OpCodes.Stfld, _seen); - } - } - if (_nonNull) return; - il.MarkLabel(skipOnNull); - } - - public void InstallPushValue(GenInstanceMethodContext cxt) - { - var il = cxt.IL; - il.Emit(OpCodes.Ldloc, cxt.This); - il.Emit(OpCodes.Ldfld, _value); - } - } -} diff --git a/Rezoom.ADO/Materialization/GenBuilderProperties/SingleNavGenBuilderProperty.cs b/Rezoom.ADO/Materialization/GenBuilderProperties/SingleNavGenBuilderProperty.cs deleted file mode 100644 index d737b2a..0000000 --- a/Rezoom.ADO/Materialization/GenBuilderProperties/SingleNavGenBuilderProperty.cs +++ /dev/null @@ -1,100 +0,0 @@ -using System; -using System.Reflection; -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization.GenBuilderProperties -{ - internal class SingleNavGenBuilderProperty : NavGenBuilderProperty - { - private FieldBuilder _reader; - - public SingleNavGenBuilderProperty(string fieldName, Type entityType) - : base(fieldName, entityType) - { - } - - public override void InstallFields(TypeBuilder type, ILGenerator constructor) - { - base.InstallFields(type, constructor); - _reader = type.DefineField("_dr_single_" + FieldName, EntityReaderType, FieldAttributes.Private); - } - - public override void InstallProcessingLogic(GenProcessRowContext cxt) - { - var il = cxt.IL; - var skip = il.DefineLabel(); - var read = il.DefineLabel(); - var localReader = il.DeclareLocal(EntityReaderType); - - il.Emit(OpCodes.Dup); // this, this - il.Emit(OpCodes.Ldfld, _reader); // this, reader - il.Emit(OpCodes.Dup); // this, reader, reader - il.Emit(OpCodes.Stloc, localReader); // this, reader - il.Emit(OpCodes.Brtrue_S, read); - { - // ok, if we don't have a reader, let's check our id column and see if we can get one going - il.Emit(OpCodes.Ldloc, cxt.Row); // this, row - il.Emit(OpCodes.Ldloc, cxt.This); // this, row, this - il.Emit(OpCodes.Ldfld, KeyColumnIndex); // this, row, kidx - il.Emit(OpCodes.Ldelem_Ref); // this, kval - il.Emit(OpCodes.Brfalse, skip); // if we have a null value, skip - // ok, we definitely have an id, but no reader yet -- let's make one - il.Emit(OpCodes.Ldsfld, EntityReaderStaticTemplateType.GetField - (nameof(RowReaderTemplate.Template))); - // this, template - il.Emit(OpCodes.Callvirt, EntityReaderTemplateType.GetMethod - (nameof(IRowReaderTemplate.CreateReader))); - // this, newreader - il.Emit(OpCodes.Dup); // this, newreader, newreader - il.Emit(OpCodes.Stloc, localReader); // this, newreader - il.Emit(OpCodes.Dup); // this, newreader, newreader - il.Emit(OpCodes.Ldloc, cxt.This); // this, newreader, newreader, this - il.Emit(OpCodes.Ldfld, SubColumnMap); // this, newreader, newreader, submap - il.Emit(OpCodes.Callvirt, EntityReaderType.GetMethod(nameof(IRowReader.ProcessColumnMap))); - // this, newreader - il.Emit(OpCodes.Stfld, _reader); - il.Emit(OpCodes.Ldloc, cxt.This); - // this - } - il.MarkLabel(read); - il.Emit(OpCodes.Ldloc, localReader); // this, reader - il.Emit(OpCodes.Ldloc, cxt.Row); // this, reader, row - il.Emit(OpCodes.Callvirt, EntityReaderType.GetMethod(nameof(IRowReader.ProcessRow))); // this - il.MarkLabel(skip); - } - - public override void InstallPushValue(GenInstanceMethodContext cxt) - { - var il = cxt.IL; - var val = il.DefineLabel(); - var done = il.DefineLabel(); - il.Emit(OpCodes.Ldloc, cxt.This); - il.Emit(OpCodes.Ldfld, _reader); // reader - il.Emit(OpCodes.Dup); // reader, reader - il.Emit(OpCodes.Brtrue_S, val); - - { - il.Emit(OpCodes.Pop); - if (EntityType.IsValueType) - { - var defaultValue = il.DeclareLocal(EntityType); - il.Emit(OpCodes.Ldloca, defaultValue); - il.Emit(OpCodes.Initobj, EntityType); - il.Emit(OpCodes.Ldloc, defaultValue); - } - else - { - il.Emit(OpCodes.Ldnull); - } - il.Emit(OpCodes.Br_S, done); - } - - il.MarkLabel(val); // reader - - il.Emit(OpCodes.Callvirt, EntityReaderType.GetMethod - (nameof(IRowReader.ToEntity))); // entity - - il.MarkLabel(done); - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/GenBuilderProperty.cs b/Rezoom.ADO/Materialization/GenBuilderProperty.cs deleted file mode 100644 index c38d88d..0000000 --- a/Rezoom.ADO/Materialization/GenBuilderProperty.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System; -using Rezoom.ADO.Materialization.GenBuilderProperties; - -namespace Rezoom.ADO.Materialization -{ - internal static class GenBuilderProperty - { - public static IGenBuilderProperty GetProperty(string name, Type propertyType) - { - if (PrimitiveConverter.IsPrimitive(propertyType)) - { - return new PrimitiveGenBuilderProperty(name, propertyType); - } - var elementType = ManyNavConverter.IsMany(propertyType); - if (elementType != null) - { - return new ManyNavGenBuilderProperty(name, elementType, propertyType); - } - return new SingleNavGenBuilderProperty(name, propertyType); - } - } -} diff --git a/Rezoom.ADO/Materialization/GenBuilders/ConstructorGenBuilder.cs b/Rezoom.ADO/Materialization/GenBuilders/ConstructorGenBuilder.cs deleted file mode 100644 index 908ac0b..0000000 --- a/Rezoom.ADO/Materialization/GenBuilders/ConstructorGenBuilder.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization.GenBuilders -{ - /// - /// Implements IGenBuilder for a type we'll initialize by calling its constructor with parameters. - /// - internal class ConstructorGenBuilder : IGenBuilder - { - private readonly ConstructorInfo _constructor; - - public ConstructorGenBuilder(ConstructorInfo constructor) - { - _constructor = constructor; - Properties = - _constructor.GetParameters() - .Select(p => GenBuilderProperty.GetProperty(p.Name, p.ParameterType)) - .ToList(); - } - - public IReadOnlyList Properties { get; } - - public void InstallConstructor(ILGenerator il) => il.Emit(OpCodes.Newobj, _constructor); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/GenBuilders/ManyGenBuilder.cs b/Rezoom.ADO/Materialization/GenBuilders/ManyGenBuilder.cs deleted file mode 100644 index 6b21d95..0000000 --- a/Rezoom.ADO/Materialization/GenBuilders/ManyGenBuilder.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Reflection.Emit; -using Rezoom.ADO.Materialization.GenBuilderProperties; - -namespace Rezoom.ADO.Materialization.GenBuilders -{ - internal class ManyGenBuilder : IGenBuilder - { - public ManyGenBuilder(Type collectionType) - { - var element = ManyNavConverter.IsMany(collectionType); - if (element == null) throw new NotSupportedException - ($"{collectionType} is not a supported collection type"); - Properties = new IGenBuilderProperty[] - { - new ManyNavGenBuilderProperty(fieldName: null, entityType: element, collectionType: collectionType) - }; - } - public IReadOnlyList Properties { get; } - public void InstallConstructor(ILGenerator il) - { - // the single property value *is* the collection, so no need to mess with the stack - } - } -} diff --git a/Rezoom.ADO/Materialization/GenBuilders/PropertyAssignmentGenBuilder.cs b/Rezoom.ADO/Materialization/GenBuilders/PropertyAssignmentGenBuilder.cs deleted file mode 100644 index e8448aa..0000000 --- a/Rezoom.ADO/Materialization/GenBuilders/PropertyAssignmentGenBuilder.cs +++ /dev/null @@ -1,65 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization.GenBuilders -{ - /// - /// Implements IGenBuilder for a type we'll initialize by calling its default constructor, then - /// assigning its writable properties (even those with private setters). - /// - internal class PropertyAssignmentGenBuilder : IGenBuilder - { - private readonly Type _targetType; - private readonly ConstructorInfo _parameterlessConstructor; - private readonly List _props; - public PropertyAssignmentGenBuilder(Type type) - { - _targetType = type; - _parameterlessConstructor = type.GetConstructor(Type.EmptyTypes); - _props = type.GetProperties() - .Where(p => p.CanWrite) - .ToList(); - Properties = _props - .Select(p => GenBuilderProperty.GetProperty(p.Name, p.PropertyType)) - .ToList(); - } - - public IReadOnlyList Properties { get; } - public void InstallConstructor(ILGenerator il) - { - // initialize the object - il.Emit(OpCodes.Newobj, _parameterlessConstructor); - // now we have: - // prop1, prop2, prop3, object - // calling the setters is going to require a bit of stack-shuffling - // since for each one, we need: - // object, prop - var oloc = il.DeclareLocal(_targetType); - il.Emit(OpCodes.Stloc, oloc); // pop object to local - for (var i = _props.Count - 1; i >= 0; i--) - { - var prop = _props[i]; - var ploc = il.DeclareLocal(prop.PropertyType); - il.Emit(OpCodes.Stloc, ploc); // pop property to local - il.Emit(OpCodes.Ldloc, oloc); // push object - il.Emit(OpCodes.Ldloc, ploc); // push property - - var setter = prop.GetSetMethod(nonPublic: true); - // pop both by calling setter - if (setter.IsVirtual) - { - il.Emit(OpCodes.Constrained, _targetType); - il.Emit(OpCodes.Callvirt, setter); - } - else - { - il.Emit(OpCodes.Call, setter); - } - } - il.Emit(OpCodes.Ldloc, oloc); // push object - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/GenInstanceMethodContext.cs b/Rezoom.ADO/Materialization/GenInstanceMethodContext.cs deleted file mode 100644 index cd5c9f3..0000000 --- a/Rezoom.ADO/Materialization/GenInstanceMethodContext.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization -{ - internal class GenInstanceMethodContext - { - public GenInstanceMethodContext(ILGenerator il, LocalBuilder @this) - { - IL = il; - This = @this; - } - public LocalBuilder This { get; } - public ILGenerator IL { get; } - } - - internal class GenProcessColumnMapContext : GenInstanceMethodContext - { - public GenProcessColumnMapContext(ILGenerator il, LocalBuilder @this) : base(il, @this) - { - ColumnMap = il.DeclareLocal(typeof(ColumnMap)); - } - public LocalBuilder ColumnMap { get; } - } - - internal class GenProcessRowContext : GenInstanceMethodContext - { - public GenProcessRowContext(ILGenerator il, LocalBuilder @this) : base(il, @this) - { - SkipSingularProperties = il.DefineLabel(); - Row = il.DeclareLocal(typeof(object[])); - } - public LocalBuilder Row { get; } - /// - /// Label to skip to after all "singular" properties. - /// - public Label SkipSingularProperties { get; } - } -} diff --git a/Rezoom.ADO/Materialization/IGenBuilder.cs b/Rezoom.ADO/Materialization/IGenBuilder.cs deleted file mode 100644 index c8df26e..0000000 --- a/Rezoom.ADO/Materialization/IGenBuilder.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System.Collections.Generic; -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization -{ - internal interface IGenBuilder - { - /// - /// Get the properties of this type in constructor order. - /// - IReadOnlyList Properties { get; } - /// - /// Assuming all the values of the properties are on the stack in constructor order, - /// add the logic to call the constructor (and perform any post-constructor assignments). - /// - /// Should effectively pop all the property values and push the constructed object. - /// - /// - /// Used for the generated IBuilder's Materialize() method. - /// - /// - void InstallConstructor(ILGenerator il); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/IGenBuilderProperty.cs b/Rezoom.ADO/Materialization/IGenBuilderProperty.cs deleted file mode 100644 index dd59a92..0000000 --- a/Rezoom.ADO/Materialization/IGenBuilderProperty.cs +++ /dev/null @@ -1,49 +0,0 @@ -using System.Reflection.Emit; - -namespace Rezoom.ADO.Materialization -{ - /// - /// Describes a property for code generation of an IBuilder. - /// - internal interface IGenBuilderProperty - { - /// - /// If true, this property has only one value for an instance of the target type, - /// and that value will appear on the first row that has a value for *any* singular property. - /// - /// - /// This is assumed to be true for all properties that aren't lists, arrays, or other collection types. - /// - bool Singular { get; } - - /// - /// Add the fields this properties needs to keep track of to , which is - /// going to be an IBuilder. Assume "this" is on the top of the stack and keep it there. - /// - /// - /// - void InstallFields(TypeBuilder type, ILGenerator constructor); - - /// - /// Add logic to process this proeprty to the IBuiler's ProcessColumnMap() method. - /// Assume "this" reference is on the stack (and leave it there). - /// - /// - void InstallProcessingLogic(GenProcessColumnMapContext cxt); - - /// - /// Add the logic to process this property to the IBuilder's ProcessRow() method. - /// Assume that a "this" reference to the IBuilder is currently on top of the stack. - /// Leave it there when done. - /// - /// - void InstallProcessingLogic(GenProcessRowContext cxt); - - /// - /// Add logic to push the value of this property onto the stack within the IBuilder's Materialize() - /// method. - /// - /// - void InstallPushValue(GenInstanceMethodContext cxt); - } -} diff --git a/Rezoom.ADO/Materialization/IRowReader.cs b/Rezoom.ADO/Materialization/IRowReader.cs deleted file mode 100644 index 864650a..0000000 --- a/Rezoom.ADO/Materialization/IRowReader.cs +++ /dev/null @@ -1,19 +0,0 @@ -namespace Rezoom.ADO.Materialization -{ - public interface IRowReaderTemplate - { - IRowReader CreateReader(); - } - public interface IRowReader - { - void ProcessColumnMap(ColumnMap map); - void ProcessRow(object[] row); - T ToEntity(); - } - - public static class RowReaderTemplate - { - public static readonly IRowReaderTemplate Template = (IRowReaderTemplate) - RowReaderTemplateGenerator.GenerateReaderTemplate(typeof(T)); - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/ManyNavConverter.cs b/Rezoom.ADO/Materialization/ManyNavConverter.cs deleted file mode 100644 index 892b9d2..0000000 --- a/Rezoom.ADO/Materialization/ManyNavConverter.cs +++ /dev/null @@ -1,129 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Linq; -using System.Reflection; -using Microsoft.FSharp.Collections; - -namespace Rezoom.ADO.Materialization -{ - public static class ManyNavConverter - { - private static readonly HashSet SupportedGenerics = - new HashSet(new[] - { - typeof(IEnumerable<>), - typeof(IDictionary<,>), - typeof(IReadOnlyDictionary<,>), - typeof(ICollection<>), - typeof(IReadOnlyCollection<>), - typeof(IList<>), - typeof(IReadOnlyList<>), - typeof(List<>), - typeof(Dictionary<,>), - typeof(ReadOnlyCollection<>), - typeof(FSharpList<>), - typeof(FSharpMap<,>), - typeof(FSharpSet<>), - }); - public static Type IsMany(Type collectionType) - { - if (collectionType.IsArray) return collectionType.GetElementType(); - if (!collectionType.IsConstructedGenericType) return null; - var def = collectionType.GetGenericTypeDefinition(); - // TODO: this is a hack, but happens to work for all the supported types - return SupportedGenerics.Contains(def) ? collectionType.GetGenericArguments().Last() : null; - } - } - public static class ManyNavConverter - { - public static Dictionary ToDictionary(Dictionary> collection) - { - var dictionary = new Dictionary(collection.Count); - foreach (var kv in collection) - { - dictionary[kv.Key] = kv.Value.ToEntity(); - } - return dictionary; - } - public static IDictionary ToIDictionary(Dictionary> collection) - => ToDictionary(collection); - public static IReadOnlyDictionary ToIReadOnlyDictionary - (Dictionary> collection) => ToDictionary(collection); - - public static TEntity[] ToArray(Dictionary> collection) - { - var i = 0; - var arr = new TEntity[collection.Count]; - foreach (var element in collection.Values) - { - arr[i++] = element.ToEntity(); - } - return arr; - } - - public static IEnumerable ToIEnumerable - (Dictionary> collection) => ToArray(collection); - - public static ICollection ToICollection - (Dictionary> collection) => ToArray(collection); - - public static IReadOnlyCollection ToIReadOnlyCollection - (Dictionary> collection) => ToArray(collection); - - public static IList ToIList - (Dictionary> collection) => ToArray(collection); - - public static IReadOnlyList ToIReadOnlyList - (Dictionary> collection) => ToArray(collection); - - public static ReadOnlyCollection ToReadOnlyCollection - (Dictionary> collection) => new ReadOnlyCollection(ToArray(collection)); - - public static List ToList(Dictionary> collection) - { - var list = new List(collection.Count); - foreach (var element in collection.Values) - { - list.Add(element.ToEntity()); - } - return list; - } - - public static FSharpList ToFSharpList(Dictionary> collection) - { - var list = FSharpList.Empty; - foreach (var element in collection.Values) - { - list = FSharpList.Cons(element.ToEntity(), list); - } - return list; - } - - public static FSharpMap ToFSharpMap(Dictionary> collection) - => MapModule.OfSeq(collection.Select(kv => Tuple.Create(kv.Key, kv.Value.ToEntity()))); - - public static FSharpSet ToFSharpSet(Dictionary> collection) - => SetModule.OfSeq(collection.Values.Select(v => v.ToEntity())); - - private static readonly Dictionary Converters = - typeof(ManyNavConverter) - .GetMethods(BindingFlags.Public | BindingFlags.Static) - .Where(m => - { - var pars = m.GetParameters(); - return pars.Length == 1 - && pars[0].ParameterType == typeof(Dictionary>) - && m.ReturnType != typeof(void); - }).ToDictionary(m => m.ReturnType); - - public static bool IsManyNav(Type targetType) => Converters.ContainsKey(targetType); - - public static MethodInfo ToType(Type targetType) - { - MethodInfo converter; - if (Converters.TryGetValue(targetType, out converter)) return converter; - throw new NotSupportedException($"Can't convert to {targetType}"); - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/PrimitiveConverter.cs b/Rezoom.ADO/Materialization/PrimitiveConverter.cs deleted file mode 100644 index 6ef1af9..0000000 --- a/Rezoom.ADO/Materialization/PrimitiveConverter.cs +++ /dev/null @@ -1,741 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Globalization; -using System.Linq; -using System.Reflection; - -namespace Rezoom.ADO.Materialization -{ - public static class PrimitiveConverter - { - public static string ToString(object obj) => obj?.ToString(); - - private enum JumpTag - { - Int32, - Int64, - Int16, - Byte, - Double, - Float, - Decimal, - String, - UInt32, - UInt64, - UInt16, - SByte, - Char, - Invalid, - } - - #region Signed Integers - - #region FastInt8 - private static JumpTag _t8; - public static sbyte ToInt8(object obj) - { - var thru0 = false; - switch (_t8) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _t8 = JumpTag.Int32; return (sbyte)(int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _t8 = JumpTag.Int64; return (sbyte)(long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _t8 = JumpTag.Int16; return (sbyte)(short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _t8 = JumpTag.Byte; return (sbyte)(byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _t8 = JumpTag.Double; return (sbyte)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _t8 = JumpTag.Float; return (sbyte)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _t8 = JumpTag.Decimal; return (sbyte)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _t8 = JumpTag.String; return sbyte.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _t8 = JumpTag.UInt32; return (sbyte)(uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _t8 = JumpTag.UInt64; return (sbyte)(ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _t8 = JumpTag.UInt16; return (sbyte)(ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _t8 = JumpTag.SByte; return (sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _t8 = JumpTag.Char; return (sbyte)(char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to SByte"); - } - } - #endregion - - public static sbyte? ToNullableInt8(object obj) => obj == null ? (sbyte?)null : ToInt8(obj); - - #region FastInt16 - private static JumpTag _t16; - public static short ToInt16(object obj) - { - var thru0 = false; - switch (_t16) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _t16 = JumpTag.Int32; return (short)(int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _t16 = JumpTag.Int64; return (short)(long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _t16 = JumpTag.Int16; return (short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _t16 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _t16 = JumpTag.Double; return (short)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _t16 = JumpTag.Float; return (short)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _t16 = JumpTag.Decimal; return (short)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _t16 = JumpTag.String; return short.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _t16 = JumpTag.UInt32; return (short)(uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _t16 = JumpTag.UInt64; return (short)(ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _t16 = JumpTag.UInt16; return (short)(ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _t16 = JumpTag.SByte; return (sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _t16 = JumpTag.Char; return (short)(char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to Int16"); - } - } - #endregion - - public static short? ToNullableInt16(object obj) => obj == null ? (short?)null : ToInt16(obj); - - #region FastInt32 - private static JumpTag _t32; - public static int ToInt32(object obj) - { - var thru0 = false; - switch (_t32) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _t32 = JumpTag.Int32; return (int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _t32 = JumpTag.Int64; return (int)(long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _t32 = JumpTag.Int16; return (short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _t32 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _t32 = JumpTag.Double; return (int)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _t32 = JumpTag.Float; return (int)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _t32 = JumpTag.Decimal; return (int)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _t32 = JumpTag.String; return int.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _t32 = JumpTag.UInt32; return (int)(uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _t32 = JumpTag.UInt64; return (int)(ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _t32 = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _t32 = JumpTag.SByte; return (sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _t32 = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to Int32"); - } - } - #endregion FastInt32 - - public static int? ToNullableInt32(object obj) => obj == null ? (int?)null : ToInt32(obj); - - #region FastInt64 - private static JumpTag _t64; - public static long ToInt64(object obj) - { - var thru0 = false; - switch (_t64) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _t64 = JumpTag.Int32; return (int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _t64 = JumpTag.Int64; return (long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _t64 = JumpTag.Int16; return (short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _t64 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _t64 = JumpTag.Double; return (long)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _t64 = JumpTag.Float; return (long)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _t64 = JumpTag.Decimal; return (long)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _t64 = JumpTag.String; return long.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _t64 = JumpTag.UInt32; return (uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _t64 = JumpTag.UInt64; return (long)(ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _t64 = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _t64 = JumpTag.SByte; return (sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _t64 = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to Int64"); - } - } - #endregion - - public static long? ToNullableInt64(object obj) => obj == null ? (long?)null : ToInt64(obj); - - #endregion - - #region Unsigned Integers - - #region FastUInt8 - private static JumpTag _tu8; - public static byte ToUInt8(object obj) - { - var thru0 = false; - switch (_tu8) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _tu8 = JumpTag.Int32; return (byte)(int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _tu8 = JumpTag.Int64; return (byte)(long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _tu8 = JumpTag.Int16; return (byte)(short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _tu8 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _tu8 = JumpTag.Double; return (byte)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _tu8 = JumpTag.Float; return (byte)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _tu8 = JumpTag.Decimal; return (byte)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _tu8 = JumpTag.String; return byte.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _tu8 = JumpTag.UInt32; return (byte)(uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _tu8 = JumpTag.UInt64; return (byte)(ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _tu8 = JumpTag.UInt16; return (byte)(ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _tu8 = JumpTag.SByte; return (byte)(sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _tu8 = JumpTag.Char; return (byte)(char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to Byte"); - } - } - #endregion - - public static byte? ToNullableUInt8(object obj) => obj == null ? (byte?)null : ToUInt8(obj); - - #region FastUInt16 - private static JumpTag _tu16; - public static ushort ToUInt16(object obj) - { - var thru0 = false; - switch (_tu16) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _tu16 = JumpTag.Int32; return (ushort)(int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _tu16 = JumpTag.Int64; return (ushort)(long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _tu16 = JumpTag.Int16; return (ushort)(short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _tu16 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _tu16 = JumpTag.Double; return (ushort)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _tu16 = JumpTag.Float; return (ushort)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _tu16 = JumpTag.Decimal; return (ushort)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _tu16 = JumpTag.String; return ushort.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _tu16 = JumpTag.UInt32; return (ushort)(uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _tu16 = JumpTag.UInt64; return (ushort)(ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _tu16 = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _tu16 = JumpTag.SByte; return (ushort)(sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _tu16 = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to UInt16"); - } - } - #endregion - - public static ushort? ToNullableUInt16(object obj) => obj == null ? (ushort?)null : ToUInt16(obj); - - #region FastUInt32 - private static JumpTag _tu32; - public static uint ToUInt32(object obj) - { - var thru0 = false; - switch (_tu32) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _tu32 = JumpTag.Int32; return (uint)(int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _tu32 = JumpTag.Int64; return (uint)(long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _tu32 = JumpTag.Int16; return (uint)(short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _tu32 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _tu32 = JumpTag.Double; return (uint)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _tu32 = JumpTag.Float; return (uint)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _tu32 = JumpTag.Decimal; return (uint)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _tu32 = JumpTag.String; return uint.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _tu32 = JumpTag.UInt32; return (uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _tu32 = JumpTag.UInt64; return (uint)(ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _tu32 = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _tu32 = JumpTag.SByte; return (uint)(sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _tu32 = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to UInt32"); - } - } - #endregion - - public static uint? ToNullableUInt32(object obj) => obj == null ? (uint?)null : ToUInt32(obj); - - #region FastUInt64 - private static JumpTag _tu64; - public static ulong ToUInt64(object obj) - { - var thru0 = false; - switch (_tu64) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _tu64 = JumpTag.Int32; return (ulong)(int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _tu64 = JumpTag.Int64; return (ulong)(long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _tu64 = JumpTag.Int16; return (ulong)(short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _tu64 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _tu64 = JumpTag.Double; return (ulong)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _tu64 = JumpTag.Float; return (ulong)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _tu64 = JumpTag.Decimal; return (ulong)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _tu64 = JumpTag.String; return ulong.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _tu64 = JumpTag.UInt32; return (uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _tu64 = JumpTag.UInt64; return (ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _tu64 = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _tu64 = JumpTag.SByte; return (ulong)(sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _tu64 = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to UInt64"); - } - } - #endregion - - public static ulong? ToNullableUInt64(object obj) => obj == null ? (ulong?)null : ToUInt64(obj); - - #endregion - - #region Floating Point - - #region FastSingle - private static JumpTag _tf32; - public static float ToSingle(object obj) - { - var thru0 = false; - switch (_tf32) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _tf32 = JumpTag.Int32; return (int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _tf32 = JumpTag.Int64; return (long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _tf32 = JumpTag.Int16; return (short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _tf32 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _tf32 = JumpTag.Double; return (float)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _tf32 = JumpTag.Float; return (float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _tf32 = JumpTag.Decimal; return (float)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _tf32 = JumpTag.String; return float.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _tf32 = JumpTag.UInt32; return (uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _tf32 = JumpTag.UInt64; return (ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _tf32 = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _tf32 = JumpTag.SByte; return (sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _tf32 = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to Single"); - } - } - #endregion - - public static float? ToNullableSingle(object obj) => obj == null ? (float?)null : ToSingle(obj); - - #region FastDouble - private static JumpTag _tf64; - public static double ToDouble(object obj) - { - var thru0 = false; - switch (_tf64) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _tf64 = JumpTag.Int32; return (int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _tf64 = JumpTag.Int64; return (long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _tf64 = JumpTag.Int16; return (short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _tf64 = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _tf64 = JumpTag.Double; return (double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _tf64 = JumpTag.Float; return (float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _tf64 = JumpTag.Decimal; return (double)(decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _tf64 = JumpTag.String; return double.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _tf64 = JumpTag.UInt32; return (uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _tf64 = JumpTag.UInt64; return (ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _tf64 = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _tf64 = JumpTag.SByte; return (sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _tf64 = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to Double"); - } - } - #endregion - - public static double? ToNullableDouble(object obj) => obj == null ? (double?)null : ToDouble(obj); - - #region FastDecimal - private static JumpTag _tfdec; - public static decimal ToDecimal(object obj) - { - var thru0 = false; - switch (_tfdec) - { - case JumpTag.Int32: - thru0 = true; - if (obj is int) { _tfdec = JumpTag.Int32; return (int)obj; } - goto case JumpTag.Int64; - case JumpTag.Int64: - if (obj is long) { _tfdec = JumpTag.Int64; return (long)obj; } - goto case JumpTag.Int16; - case JumpTag.Int16: - if (obj is short) { _tfdec = JumpTag.Int16; return (short)obj; } - goto case JumpTag.Byte; - case JumpTag.Byte: - if (obj is byte) { _tfdec = JumpTag.Byte; return (byte)obj; } - goto case JumpTag.Double; - case JumpTag.Double: - if (obj is double) { _tfdec = JumpTag.Double; return (decimal)(double)obj; } - goto case JumpTag.Float; - case JumpTag.Float: - if (obj is float) { _tfdec = JumpTag.Float; return (decimal)(float)obj; } - goto case JumpTag.Decimal; - case JumpTag.Decimal: - if (obj is decimal) { _tfdec = JumpTag.Decimal; return (decimal)obj; } - goto case JumpTag.String; - case JumpTag.String: - var str = obj as string; - if (str != null) { _tfdec = JumpTag.String; return decimal.Parse(str, CultureInfo.InvariantCulture); } - goto case JumpTag.UInt32; - case JumpTag.UInt32: - if (obj is uint) { _tfdec = JumpTag.UInt32; return (uint)obj; } - goto case JumpTag.UInt64; - case JumpTag.UInt64: - if (obj is ulong) { _tfdec = JumpTag.UInt64; return (ulong)obj; } - goto case JumpTag.UInt16; - case JumpTag.UInt16: - if (obj is ushort) { _tfdec = JumpTag.UInt16; return (ushort)obj; } - goto case JumpTag.SByte; - case JumpTag.SByte: - if (obj is sbyte) { _tfdec = JumpTag.SByte; return (sbyte)obj; } - goto case JumpTag.Char; - case JumpTag.Char: - if (obj is char) { _tfdec = JumpTag.Char; return (char)obj; } - goto case JumpTag.Invalid; - case JumpTag.Invalid: - if (thru0) goto default; - goto case 0; - default: - throw new ArgumentOutOfRangeException(nameof(obj), $"Can't convert {obj} to Decimal"); - } - } - #endregion - - public static decimal? ToNullableDecimal(object obj) => obj == null ? (decimal?)null : ToDecimal(obj); - - #endregion - - public static Guid ToGuid(object obj) - { - if (obj is Guid) return (Guid)obj; - return Guid.Parse(obj.ToString()); - } - - public static Guid? ToNullableGuid(object obj) => obj == null ? (Guid?)null : ToGuid(obj); - - public static DateTime ToDateTime(object obj) => Convert.ToDateTime(obj); - public static DateTime? ToNullableDateTime(object obj) => obj == null ? (DateTime?)null : ToDateTime(obj); - - public static DateTimeOffset ToDateTimeOffset(object obj) - { - if (obj is DateTimeOffset) return (DateTimeOffset)obj; - if (obj is DateTime) return (DateTime)obj; // this conversion is evil, but maybe better than nothing - return DateTimeOffset.Parse(obj.ToString()); - } - - public static DateTimeOffset? ToNullableDateTimeOffset(object obj) => - obj == null ? (DateTimeOffset?)null : ToDateTimeOffset(obj); - - public static TimeSpan ToTimeSpan(object obj) - { - if (obj is TimeSpan) return (TimeSpan)obj; - return TimeSpan.Parse(obj.ToString()); - } - - public static TimeSpan? ToNullableTimeSpan(object obj) => obj == null ? (TimeSpan?)null : ToTimeSpan(obj); - - private static readonly Dictionary Converters = - typeof(PrimitiveConverter) - .GetMethods(BindingFlags.Public | BindingFlags.Static) - .Where(m => - { - var pars = m.GetParameters(); - return pars.Length == 1 - && pars[0].ParameterType == typeof(object) - && m.ReturnType != typeof(void); - }).ToDictionary(m => m.ReturnType); - - public static bool IsPrimitive(Type targetType) => Converters.ContainsKey(targetType); - - public static MethodInfo ToType(Type targetType) - { - MethodInfo converter; - if (Converters.TryGetValue(targetType, out converter)) return converter; - throw new NotSupportedException($"Can't convert to {targetType}"); - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/RowReaderTemplateGenerator.cs b/Rezoom.ADO/Materialization/RowReaderTemplateGenerator.cs deleted file mode 100644 index fe7c1bc..0000000 --- a/Rezoom.ADO/Materialization/RowReaderTemplateGenerator.cs +++ /dev/null @@ -1,140 +0,0 @@ -using System; -using System.Linq; -using System.Reflection; -using System.Reflection.Emit; -using Rezoom.ADO.Materialization.GenBuilders; -using Rezoom.ADO.Materialization.TypeInfo; - -namespace Rezoom.ADO.Materialization -{ - internal static class RowReaderTemplateGenerator - { - private static void ImplementRowReader(TypeBuilder builder, Type targetType) - { - var profile = TypeProfile.OfType(targetType); - var gen = profile.IsCollection - ? new ManyGenBuilder(targetType) - : profile.Columns.Any(c => c.Setter == null) - ? new ConstructorGenBuilder(profile.PrimaryConstructor) as IGenBuilder - : new PropertyAssignmentGenBuilder(targetType); - builder.AddInterfaceImplementation(typeof(IRowReader<>).MakeGenericType(targetType)); - - GenInstanceMethodContext toEntityContext; - { - var toEntity = builder.DefineMethod - (nameof(IRowReader.ToEntity), MethodAttributes.Public | MethodAttributes.Virtual); - toEntity.SetParameters(); - toEntity.SetReturnType(targetType); - var il = toEntity.GetILGenerator(); - var thisLocal = il.DeclareLocal(builder); - il.Emit(OpCodes.Ldarg_0); // load this - il.Emit(OpCodes.Stloc, thisLocal); - toEntityContext = new GenInstanceMethodContext(il, thisLocal); - } - - GenProcessColumnMapContext columnContext; - { - var processColumnMap = builder.DefineMethod - (nameof(IRowReader.ProcessColumnMap), MethodAttributes.Public | MethodAttributes.Virtual); - processColumnMap.SetParameters(typeof(ColumnMap)); - var il = processColumnMap.GetILGenerator(); - var thisLocal = il.DeclareLocal(builder); - il.Emit(OpCodes.Ldarg_0); // load this - il.Emit(OpCodes.Dup); - il.Emit(OpCodes.Stloc, thisLocal); - columnContext = new GenProcessColumnMapContext(il, thisLocal); - il.Emit(OpCodes.Ldarg_1); // load column map - il.Emit(OpCodes.Stloc, columnContext.ColumnMap); - } - - GenProcessRowContext rowContext; - { - var processRow = builder.DefineMethod - (nameof(IRowReader.ProcessRow), MethodAttributes.Public | MethodAttributes.Virtual); - processRow.SetParameters(typeof(object[])); - var il = processRow.GetILGenerator(); - var thisLocal = il.DeclareLocal(builder); - il.Emit(OpCodes.Ldarg_0); // load this - il.Emit(OpCodes.Dup); - il.Emit(OpCodes.Stloc, thisLocal); - rowContext = new GenProcessRowContext(il, thisLocal); - il.Emit(OpCodes.Ldarg_1); // load row - il.Emit(OpCodes.Stloc, rowContext.Row); - } - - var consIL = builder.DefineConstructor - (MethodAttributes.Public, CallingConventions.HasThis, Type.EmptyTypes).GetILGenerator(); - consIL.Emit(OpCodes.Ldarg_0); // load this - - foreach (var prop in gen.Properties) - { - prop.InstallFields(builder, consIL); - prop.InstallProcessingLogic(columnContext); - prop.InstallPushValue(toEntityContext); - } - { - var marked = false; - foreach (var prop in gen.Properties.OrderByDescending(p => p.Singular)) - { - if (!marked && !prop.Singular) - { - rowContext.IL.MarkLabel(rowContext.SkipSingularProperties); - marked = true; - } - prop.InstallProcessingLogic(rowContext); - } - if (!marked) rowContext.IL.MarkLabel(rowContext.SkipSingularProperties); - } - - gen.InstallConstructor(toEntityContext.IL); - consIL.Emit(OpCodes.Pop); // pop this - consIL.Emit(OpCodes.Ret); - columnContext.IL.Emit(OpCodes.Pop); // pop this - columnContext.IL.Emit(OpCodes.Ret); - rowContext.IL.Emit(OpCodes.Pop); // pop this - rowContext.IL.Emit(OpCodes.Ret); - toEntityContext.IL.Emit(OpCodes.Ret); // return constructed object - } - - private static void ImplementRowReaderTemplate(TypeBuilder builder, Type targetType, Type readerType) - { - var cons = readerType.GetConstructor(Type.EmptyTypes); - if (cons == null) throw new Exception("No default constructor for reader"); - builder.DefineDefaultConstructor(MethodAttributes.Public); - builder.AddInterfaceImplementation(typeof(IRowReaderTemplate<>).MakeGenericType(targetType)); - var creator = builder.DefineMethod - (nameof(IRowReaderTemplate.CreateReader), MethodAttributes.Public | MethodAttributes.Virtual); - creator.SetParameters(); - creator.SetReturnType(typeof(IRowReader<>).MakeGenericType(targetType)); - var il = creator.GetILGenerator(); - il.Emit(OpCodes.Newobj, cons); - il.Emit(OpCodes.Ret); - } - - public static object GenerateReaderTemplate(Type targetType) - { - // create a dynamic assembly to house our dynamic type - var assembly = new AssemblyName($"Readers.{targetType.Name}{Guid.NewGuid():N}"); - var appDomain = System.Threading.Thread.GetDomain(); - var assemblyBuilder = appDomain.DefineDynamicAssembly(assembly, AssemblyBuilderAccess.Run); - var moduleBuilder = assemblyBuilder.DefineDynamicModule(assembly.Name); - - // create the dynamic IRowReader type - var reader = moduleBuilder.DefineType - ( $"{targetType.Name}Reader" - , TypeAttributes.Public | TypeAttributes.AutoClass | TypeAttributes.AnsiClass - , typeof(object) - ); - ImplementRowReader(reader, targetType); - - var template = moduleBuilder.DefineType - ($"{targetType.Name}ReaderTemplate" - , TypeAttributes.Public | TypeAttributes.AutoClass | TypeAttributes.AnsiClass - , typeof(object) - ); - ImplementRowReaderTemplate(template, targetType, reader.CreateType()); - var templateType = template.CreateType(); - return Activator.CreateInstance(templateType); - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/TypeInfo/TypeColumn.cs b/Rezoom.ADO/Materialization/TypeInfo/TypeColumn.cs deleted file mode 100644 index ac5703d..0000000 --- a/Rezoom.ADO/Materialization/TypeInfo/TypeColumn.cs +++ /dev/null @@ -1,19 +0,0 @@ -using System; -using System.Reflection; - -namespace Rezoom.ADO.Materialization.TypeInfo -{ - public class TypeColumn - { - public TypeColumn(string name, Type type, MethodInfo setter) - { - Name = name; - Type = type; - Setter = setter; - } - - public string Name { get; } - public Type Type { get; } - public MethodInfo Setter { get; } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Materialization/TypeInfo/TypeProfile.cs b/Rezoom.ADO/Materialization/TypeInfo/TypeProfile.cs deleted file mode 100644 index 0a8335c..0000000 --- a/Rezoom.ADO/Materialization/TypeInfo/TypeProfile.cs +++ /dev/null @@ -1,87 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; - -namespace Rezoom.ADO.Materialization.TypeInfo -{ - public static class TypeProfile - { - public static readonly TypeProfile Profile = TypeProfile.OfType(typeof(T)); - } - public class TypeProfile - { - private readonly Dictionary _columnsByNameCI; - private TypeProfile - ( Type type - , Type collectionElementType - , ConstructorInfo primaryConstructor - , IReadOnlyList typeColumns - ) - { - Type = type; - CollectionElementType = collectionElementType; - PrimaryConstructor = primaryConstructor; - Columns = typeColumns; - _columnsByNameCI = typeColumns.ToDictionary(t => t.Name, StringComparer.OrdinalIgnoreCase); - } - - public bool IsCollection => CollectionElementType != null; - public Type Type { get; } - public Type CollectionElementType { get; } - public ConstructorInfo PrimaryConstructor { get; } - public IReadOnlyList Columns { get; } - - public TypeColumn KeyColumn - { - get - { - TypeColumn found; - if (_columnsByNameCI.TryGetValue("id", out found)) return found; - if (_columnsByNameCI.TryGetValue(Type.Name + "id", out found)) return found; - if (_columnsByNameCI.TryGetValue(Type.Name + "_id", out found)) return found; - return null; - } - } - - public static TypeProfile OfType(Type type) - { - var elementType = ManyNavConverter.IsMany(type); - if (elementType != null) - { - return new TypeProfile(type, elementType, null, new TypeColumn[0]); - } - var settableProperties = type.GetProperties() - .Where(p => p.CanWrite) - .Select(p => new - { - Property = p, - Setter = p.GetSetMethod(nonPublic: false) - }) - .Where(p => p.Setter != null) - .ToList(); - var constructors = type.GetConstructors().Select(c => new - { - Constructor = c, - Parameters = c.GetParameters(), - }).ToList(); - var longestConstructor = constructors.OrderByDescending(c => c.Parameters.Length).FirstOrDefault(); - if (longestConstructor == null) throw new ArgumentException("Type has no public constructors", nameof(type)); - var defaultConstructor = constructors.FirstOrDefault(c => c.Parameters.Length == 0); - if (defaultConstructor == null || longestConstructor.Parameters.Length >= settableProperties.Count) - { - var columns = longestConstructor.Parameters - .Select(p => new TypeColumn(p.Name, p.ParameterType, null)) - .ToList(); - return new TypeProfile(type, null, longestConstructor.Constructor, columns); - } - else - { - var columns = settableProperties - .Select(p => new TypeColumn(p.Property.Name, p.Property.PropertyType, p.Setter)) - .ToList(); - return new TypeProfile(type, null, defaultConstructor.Constructor, columns); - } - } - } -} \ No newline at end of file diff --git a/Rezoom.ADO/Properties/AssemblyInfo.cs b/Rezoom.ADO/Properties/AssemblyInfo.cs deleted file mode 100644 index 75952e4..0000000 --- a/Rezoom.ADO/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; - -// General Information about an assembly is controlled through the following -// set of attributes. Change these attribute values to modify the information -// associated with an assembly. -[assembly: AssemblyTitle("Rezoom.ADO")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("Rezoom.ADO")] -[assembly: AssemblyCopyright("Copyright © 2016")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// Setting ComVisible to false makes the types in this assembly not visible -// to COM components. If you need to access a type in this assembly from -// COM, set the ComVisible attribute to true on that type. -[assembly: ComVisible(false)] - -// The following GUID is for the ID of the typelib if this project is exposed to COM -[assembly: Guid("13bb08a8-8135-4630-beab-1f35d660b52b")] - -// Version information for an assembly consists of the following four values: -// -// Major Version -// Minor Version -// Build Number -// Revision -// -// You can specify all the values or you can default the Build and Revision Numbers -// by using the '*' as shown below: -// [assembly: AssemblyVersion("1.0.*")] -[assembly: AssemblyVersion("1.0.0.0")] -[assembly: AssemblyFileVersion("1.0.0.0")] - -[assembly: InternalsVisibleTo("Rezoom.ADO.Test.Internals")] diff --git a/Rezoom.ADO/Rezoom.ADO.csproj b/Rezoom.ADO/Rezoom.ADO.csproj deleted file mode 100644 index 830c89a..0000000 --- a/Rezoom.ADO/Rezoom.ADO.csproj +++ /dev/null @@ -1,89 +0,0 @@ - - - - - Debug - AnyCPU - {13BB08A8-8135-4630-BEAB-1F35D660B52B} - Library - Properties - Rezoom.ADO - Rezoom.ADO - v4.6 - 512 - - - true - full - false - bin\Debug\ - DEBUG;TRACE - prompt - 4 - - - pdbonly - true - bin\Release\ - TRACE - prompt - 4 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - {d98acbeb-a039-4340-a7c5-6ed2b677268b} - Rezoom - - - {9db721d3-da97-4be3-b60b-9b7a682e803e} - Rezoom.Execution - - - - - \ No newline at end of file diff --git a/Rezoom.Documentation.sln b/Rezoom.Documentation.sln new file mode 100644 index 0000000..ce97090 --- /dev/null +++ b/Rezoom.Documentation.sln @@ -0,0 +1,22 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 14 +VisualStudioVersion = 14.0.25420.1 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.Documentation", "Rezoom.Documentation\Rezoom.Documentation.fsproj", "{62669A70-A41C-44BC-B860-EDA3550018FA}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {62669A70-A41C-44BC-B860-EDA3550018FA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {62669A70-A41C-44BC-B860-EDA3550018FA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {62669A70-A41C-44BC-B860-EDA3550018FA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {62669A70-A41C-44BC-B860-EDA3550018FA}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/Rezoom.Documentation/Intro.fsx b/Rezoom.Documentation/Intro.fsx new file mode 100644 index 0000000..b784d3d --- /dev/null +++ b/Rezoom.Documentation/Intro.fsx @@ -0,0 +1,24 @@ +(*** hide ***) + +#r "../Rezoom.SQL.Provider/bin/Debug/Rezoom.dll" +#r "../Rezoom.SQL.Provider/bin/Debug/LicenseToCIL.dll" +#r "../Rezoom.SQL.Provider/bin/Debug/Rezoom.SQL.Compiler.dll" +#r "../Rezoom.SQL.Provider/bin/Debug/Rezoom.SQL.Mapping.dll" +#r "../Rezoom.SQL.Provider/bin/Debug/Rezoom.SQL.Provider.dll" +#nowarn "193" +open System +open Rezoom +open Rezoom.SQL +open Rezoom.SQL.Provider + +(** + +# Rezoom: a resumption monad for .NET data access. + +Rezoom is a library intended to help you deal with data that lives somewhere else. + +Between database servers, web APIs, and specialized protocols like SNMP, data has an annoying habit of +hanging out on other machines and waiting for us to request it. Unfortunately, while programmers can +afford to be inefficient with memory and CPU, we still have to count network round trips on our fingers. + +*) diff --git a/Rezoom.Documentation/Rezoom.Documentation.fsproj b/Rezoom.Documentation/Rezoom.Documentation.fsproj new file mode 100644 index 0000000..28e4937 --- /dev/null +++ b/Rezoom.Documentation/Rezoom.Documentation.fsproj @@ -0,0 +1,115 @@ + + + + + Debug + AnyCPU + 2.0 + 62669a70-a41c-44bc-b860-eda3550018fa + Library + Rezoom.Documentation + Rezoom.Documentation + v4.6 + 4.4.0.0 + true + Rezoom.Documentation + + + true + full + false + false + bin\Debug\ + DEBUG;TRACE + 3 + bin\Debug\Rezoom.Documentation.XML + + + pdbonly + true + true + bin\Release\ + TRACE + 3 + bin\Release\Rezoom.Documentation.XML + + + 11 + + + + + $(MSBuildExtensionsPath32)\..\Microsoft SDKs\F#\3.0\Framework\v4.0\Microsoft.FSharp.Targets + + + + + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\FSharp\Microsoft.FSharp.Targets + + + + + + + + + + + + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\CSharpFormat.dll + True + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\FSharp.CodeFormat.dll + True + + + ..\packages\FSharp.Compiler.Service.2.0.0.6\lib\net45\FSharp.Compiler.Service.dll + True + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\FSharp.Formatting.Common.dll + True + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\FSharp.Literate.dll + True + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\FSharp.Markdown.dll + True + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\FSharp.MetadataFormat.dll + True + + + ..\packages\FSharpVSPowerTools.Core.2.3.0\lib\net45\FSharpVSPowerTools.Core.dll + True + + + + True + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\RazorEngine.dll + True + + + + + + ..\packages\FSharp.Formatting.2.14.4\lib\net40\System.Web.Razor.dll + True + + + + \ No newline at end of file diff --git a/Rezoom.Documentation/build.fsx b/Rezoom.Documentation/build.fsx new file mode 100644 index 0000000..a651209 --- /dev/null +++ b/Rezoom.Documentation/build.fsx @@ -0,0 +1,21 @@ +#load "../packages/FSharp.Formatting.2.14.4/FSharp.Formatting.fsx" +open FSharp.Literate +open System.IO + +let source = __SOURCE_DIRECTORY__ +let template = Path.Combine(source, "template.html") + +let files = + [ "Intro.fsx" + ] + +let replacements = + [ "project-name", "Rezoom" + "github-link", "https://github.com/rspeele/Rezoom" + ] + +for file in files do + Literate.ProcessScriptFile + ( Path.Combine(source, file), template, lineNumbers = false + , replacements = replacements + ) \ No newline at end of file diff --git a/Rezoom.Documentation/packages.config b/Rezoom.Documentation/packages.config new file mode 100644 index 0000000..f4c7d3f --- /dev/null +++ b/Rezoom.Documentation/packages.config @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Rezoom.Documentation/style.css b/Rezoom.Documentation/style.css new file mode 100644 index 0000000..4732b8b --- /dev/null +++ b/Rezoom.Documentation/style.css @@ -0,0 +1,224 @@ +@import url(https://fonts.googleapis.com/css?family=Droid+Sans|Droid+Sans+Mono|Open+Sans:400,600,700); + +/*-------------------------------------------------------------------------- + Formatting for F# code snippets +/*--------------------------------------------------------------------------*/ + +/* strings --- and stlyes for other string related formats */ +span.s { color: #D69D85; } +/* printf formatters */ +span.pf { color: #E0C57F; } +/* escaped chars */ +span.e { color: #E0E268; } + +/* identifiers --- and styles for more specific identifier types */ +span.i { color: #d1d1d1; } +/* type or module */ +span.t { color: #4EC9B0; } +/* function */ +span.f { color: #e1e1e1; } +/* DU case or active pattern */ +span.p { color: #4ec9b0; } + +/* keywords */ +span.k { color: #569CD6; } +/* comment */ +span.c { color: #57A64A; } +/* operators */ +span.o { color: #569CD6;font-weight: bold; } +/* numbers */ +span.n { color: #B5CEA8; } +/* line number */ +span.l { color: #80b0b0; } +/* mutable var or ref cell */ +span.v { color: #d1d1d1; font-weight: bold; } +/* inactive code */ +span.inactive { color: #808080; } +/* preprocessor */ +span.prep { color: #9B9B9B; } +/* fsi output */ +span.fsi { color: #808080; } + +/* omitted */ +span.omitted { + background: #3c4e52; + border-radius: 5px; + color: #808080; + padding: 0px 0px 1px 0px; +} +/* tool tip */ +div.tip { + background: #475b5f; + border-radius: 4px; + font: 11pt consolas,'Droid Sans Mono',monospace; + padding: 6px 8px 6px 8px; + display: none; + color: #d1d1d1; +} +table.pre pre { + padding: 0px; + margin: 0px; + border: none; +} +table.pre, pre.fssnip, pre { + line-height: 13pt; + border: 1px solid #d8d8d8; + border-collapse: separate; + white-space: pre; + font: 9pt consolas,'Droid Sans Mono',monospace; + width: 90%; + margin: 10px 20px 20px 20px; + background-color: #212d30; + padding: 10px; + border-radius: 5px; + color: #d1d1d1; +} +pre.fssnip code { + font: 9pt consolas,'Droid Sans Mono',monospace; +} +table.pre pre { + padding: 0px; + margin: 0px; + border-radius: 0px; + width: 100%; +} +table.pre td { + padding: 0px; + white-space: normal; + margin: 0px; +} +table.pre td.lines { + width: 30px; +} + +/*-------------------------------------------------------------------------- + Formatting for page & standard document content +/*--------------------------------------------------------------------------*/ + +body { + font-family: 'Open Sans', serif; + padding-top: 0px; + padding-bottom: 40px; +} + +pre { + word-wrap: inherit; +} + +/* Format the heading - nicer spacing etc. */ +.masthead { + overflow: hidden; +} +.masthead .muted a { + text-decoration: none; + color: #999999; +} +.masthead ul, .masthead li { + margin-bottom: 0px; +} +.masthead .nav li { + margin-top: 15px; + font-size: 110%; +} +.masthead h3 { + margin-bottom: 5px; + font-size: 170%; +} +hr { + margin: 0px 0px 20px 0px; +} + +/* Make table headings and td.title bold */ +td.title, thead { + font-weight: bold; +} + +/* Format the right-side menu */ +#menu { + margin-top: 50px; + font-size: 11pt; + padding-left: 20px; +} + +#menu .nav-header { + font-size: 12pt; + color: #606060; + margin-top: 20px; +} + +#menu li { + line-height: 25px; +} + +/* Change font sizes for headings etc. */ +#main h1 { font-size: 26pt; margin: 10px 0px 15px 0px; font-weight: 400; } +#main h2 { font-size: 20pt; margin: 20px 0px 0px 0px; font-weight: 400; } +#main h3 { font-size: 14pt; margin: 15px 0px 0px 0px; font-weight: 600; } +#main p { font-size: 11pt; margin: 5px 0px 15px 0px; } +#main ul { font-size: 11pt; margin-top: 10px; } +#main li { font-size: 11pt; margin: 5px 0px 5px 0px; } +#main strong { font-weight: 700; } + +/*-------------------------------------------------------------------------- + Formatting for API reference +/*--------------------------------------------------------------------------*/ + +.type-list .type-name, .module-list .module-name { + width: 25%; + font-weight: bold; +} +.member-list .member-name { + width: 35%; +} +#main .xmldoc h2 { + font-size: 14pt; + margin: 10px 0px 0px 0px; +} +#main .xmldoc h3 { + font-size: 12pt; + margin: 10px 0px 0px 0px; +} +.github-link { + float: right; + text-decoration: none; +} +.github-link img { + border-style: none; + margin-left: 10px; +} +.github-link .hover { display: none; } +.github-link:hover .hover { display: block; } +.github-link .normal { display: block; } +.github-link:hover .normal { display: none; } + +/*-------------------------------------------------------------------------- + Links +/*--------------------------------------------------------------------------*/ + +h1 a, h1 a:hover, h1 a:focus, +h2 a, h2 a:hover, h2 a:focus, +h3 a, h3 a:hover, h3 a:focus, +h4 a, h4 a:hover, h4 a:focus, +h5 a, h5 a:hover, h5 a:focus, +h6 a, h6 a:hover, h6 a:focus { color: inherit; text-decoration: inherit; outline: none; } + +/*-------------------------------------------------------------------------- + Additional formatting for the homepage +/*--------------------------------------------------------------------------*/ + +#nuget { + margin-top: 20px; + font-size: 11pt; + padding: 20px; +} + +#nuget pre { + font-size: 11pt; + -moz-border-radius: 0px; + -webkit-border-radius: 0px; + border-radius: 0px; + background: #404040; + border-style: none; + color: #e0e0e0; + margin-top: 15px; +} \ No newline at end of file diff --git a/Rezoom.Documentation/template.html b/Rezoom.Documentation/template.html new file mode 100644 index 0000000..f70931a --- /dev/null +++ b/Rezoom.Documentation/template.html @@ -0,0 +1,64 @@ + + + + + + {page-title} + + + + + + + + + + + + + + +
+
+ +

{project-name}

+
+
+
+
+ {document} + {tooltips} +
+
+ + +
+
+
+ Fork me on GitHub + + \ No newline at end of file diff --git a/Rezoom.EF/App.config b/Rezoom.EF/App.config deleted file mode 100644 index 7e1d79c..0000000 --- a/Rezoom.EF/App.config +++ /dev/null @@ -1,17 +0,0 @@ - - - - -
- - - - - - - - - - - - \ No newline at end of file diff --git a/Rezoom.EF/Cloner.cs b/Rezoom.EF/Cloner.cs deleted file mode 100644 index 29e1801..0000000 --- a/Rezoom.EF/Cloner.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; - -namespace Rezoom.EF -{ - internal class Cloner where TEntity : new() - { - private static PropertyInfo[] ClonableProperties() - { - var props = typeof(TEntity).GetProperties(); - var names = new HashSet(props.Select(p => p.Name), StringComparer.OrdinalIgnoreCase); - return props.Where(p - => p.CanRead - && p.CanWrite - && !names.Contains(p.Name + "Id")) // don't clone navigation properties - .ToArray(); - } - - private readonly PropertyInfo[] _cloneProperties = ClonableProperties(); - - public TEntity Clone(TEntity entity) - { - var newEntity = new TEntity(); - foreach (var prop in _cloneProperties) - { - prop.SetValue(newEntity, prop.GetValue(entity)); - } - return newEntity; - } - - public static readonly Cloner Instance = new Cloner(); - } -} \ No newline at end of file diff --git a/Rezoom.EF/ContextRequest.cs b/Rezoom.EF/ContextRequest.cs deleted file mode 100644 index c269a1c..0000000 --- a/Rezoom.EF/ContextRequest.cs +++ /dev/null @@ -1,21 +0,0 @@ -using System; -using System.Data.Entity; -using System.Threading.Tasks; - -namespace Rezoom.EF -{ - public abstract class ContextErrand : CS.AsynchronousErrand - where TContext : DbContext - { - public override object DataSource => typeof(TContext); - public override object SequenceGroup => typeof(TContext); - - protected abstract Func> Prepare(TContext db); - - public sealed override Func> Prepare(ServiceContext context) - { - var db = context.GetService(); - return Prepare(db); - } - } -} \ No newline at end of file diff --git a/Rezoom.EF/DeleteRequest.cs b/Rezoom.EF/DeleteRequest.cs deleted file mode 100644 index 5ea56d7..0000000 --- a/Rezoom.EF/DeleteRequest.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.Data.Entity; -using System.Threading.Tasks; - -namespace Rezoom.EF -{ - public class DeleteErrand : ContextErrand - where TContext : DbContext - where TEntity : class - { - private readonly TEntity _entity; - private readonly Func> _set; - - public DeleteErrand(Func> set, TEntity entity) - { - _set = set; - _entity = entity; - } - - public override bool Mutation => true; - public override bool Idempotent => true; - - protected override Func> Prepare(TContext db) - { - var set = _set(db); - set.Attach(_entity); - set.Remove(_entity); - return async () => - { - await db.SaveChangesAsync(); - return null; - }; - } - } -} \ No newline at end of file diff --git a/Rezoom.EF/InsertRequest.cs b/Rezoom.EF/InsertRequest.cs deleted file mode 100644 index e701e68..0000000 --- a/Rezoom.EF/InsertRequest.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System; -using System.Data.Entity; -using System.Threading.Tasks; - -namespace Rezoom.EF -{ - public class InsertErrand : ContextErrand - where TContext : DbContext - where TEntity : class - { - private readonly Func> _set; - private readonly TEntity _toInsert; - - public InsertErrand(Func> set, TEntity toInsert) - { - _set = set; - _toInsert = toInsert; - } - - public override bool Mutation => true; - public override bool Idempotent => false; - - protected override Func> Prepare(TContext db) - { - _set(db).Add(_toInsert); - return async () => - { - await db.SaveChangesAsync(); - return _toInsert; - }; - } - } -} \ No newline at end of file diff --git a/Rezoom.EF/Properties/AssemblyInfo.cs b/Rezoom.EF/Properties/AssemblyInfo.cs deleted file mode 100644 index 18f60d5..0000000 --- a/Rezoom.EF/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; - -// General Information about an assembly is controlled through the following -// set of attributes. Change these attribute values to modify the information -// associated with an assembly. -[assembly: AssemblyTitle("Rezoom.EF")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("Rezoom.EF")] -[assembly: AssemblyCopyright("Copyright © 2016")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// Setting ComVisible to false makes the types in this assembly not visible -// to COM components. If you need to access a type in this assembly from -// COM, set the ComVisible attribute to true on that type. -[assembly: ComVisible(false)] - -// The following GUID is for the ID of the typelib if this project is exposed to COM -[assembly: Guid("51023e89-6081-4bbf-8945-8f17e6c4d65c")] - -// Version information for an assembly consists of the following four values: -// -// Major Version -// Minor Version -// Build Number -// Revision -// -// You can specify all the values or you can default the Build and Revision Numbers -// by using the '*' as shown below: -// [assembly: AssemblyVersion("1.0.*")] -[assembly: AssemblyVersion("1.0.0.0")] -[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/Rezoom.EF/QueryRequest.cs b/Rezoom.EF/QueryRequest.cs deleted file mode 100644 index 25ad520..0000000 --- a/Rezoom.EF/QueryRequest.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Data.Entity; -using System.Linq; -using System.Threading.Tasks; -using EntityFramework.Extensions; - -namespace Rezoom.EF -{ - public class QueryErrand : ContextErrand> - where TContext : DbContext - where T : class - { - private readonly Func> _query; - - public QueryErrand(Func> query) - { - _query = query; - } - - public override bool Mutation => false; - public override bool Idempotent => true; - - protected override Func>> Prepare(TContext db) - { - var future = _query(db).AsNoTracking().Future(); - return () => Task.FromResult(future.ToList()); // unfortunately, futures don't support ToListAsync - } - } -} diff --git a/Rezoom.EF/Rezoom.EF.csproj b/Rezoom.EF/Rezoom.EF.csproj deleted file mode 100644 index 142fa6f..0000000 --- a/Rezoom.EF/Rezoom.EF.csproj +++ /dev/null @@ -1,87 +0,0 @@ - - - - - Debug - AnyCPU - {51023E89-6081-4BBF-8945-8F17E6C4D65C} - Library - Properties - Rezoom.EF - Rezoom.EF - v4.6 - 512 - - - true - full - false - bin\Debug\ - DEBUG;TRACE - prompt - 4 - - - pdbonly - true - bin\Release\ - TRACE - prompt - 4 - - - - ..\packages\EntityFramework.6.1.3\lib\net45\EntityFramework.dll - True - - - ..\packages\EntityFramework.Extended.6.1.0.168\lib\net45\EntityFramework.Extended.dll - True - - - ..\packages\EntityFramework.6.1.3\lib\net45\EntityFramework.SqlServer.dll - True - - - - - - - - - - - - - - - - - - - - - - - {d98acbeb-a039-4340-a7c5-6ed2b677268b} - Rezoom - - - {9db721d3-da97-4be3-b60b-9b7a682e803e} - Rezoom.Execution - - - - - - - - - - diff --git a/Rezoom.EF/UpdateRequest.cs b/Rezoom.EF/UpdateRequest.cs deleted file mode 100644 index 7ea402d..0000000 --- a/Rezoom.EF/UpdateRequest.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System; -using System.Data.Entity; -using System.Threading.Tasks; - -namespace Rezoom.EF -{ - public class UpdateErrand : ContextErrand - where TContext : DbContext - where TEntity : class, new() - { - private readonly TEntity _entity; - private readonly Func> _set; - private readonly Action _change; - - public UpdateErrand(Func> set, TEntity entity, Action change) - { - _set = set; - _entity = entity; - _change = change; - } - - public override bool Mutation => true; - public override bool Idempotent => false; - - protected override Func> Prepare(TContext db) - { - var copy = Cloner.Instance.Clone(_entity); - var set = _set(db); - set.Attach(copy); - _change(copy); - return async () => - { - await db.SaveChangesAsync(); - return null; - }; - - } - } -} \ No newline at end of file diff --git a/Rezoom.EF/packages.config b/Rezoom.EF/packages.config deleted file mode 100644 index 16e1a5e..0000000 --- a/Rezoom.EF/packages.config +++ /dev/null @@ -1,5 +0,0 @@ - - - - - \ No newline at end of file diff --git a/Rezoom.Execution/DebugExecutionLog.cs b/Rezoom.Execution/DebugExecutionLog.cs deleted file mode 100644 index 94a2a65..0000000 --- a/Rezoom.Execution/DebugExecutionLog.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Diagnostics; - -namespace Rezoom.Execution -{ - /// - /// An that writes events to the debug output window. - /// - public class DebugExecutionLog : IExecutionLog - { - public void OnStepStart() => Debug.WriteLine("OnStepStart()"); - - public void OnStepFinish() => Debug.WriteLine("OnStepFinish()"); - - public void OnPrepare(Errand request) - => Debug.WriteLine($"OnPrepare({request.Identity})"); - - public void OnPrepareFailure(Exception exception) - => Debug.WriteLine($"OnException({exception.Message})"); - - public void OnComplete(Errand request, DataResponse response) - => Debug.WriteLine($"OnComplete({request.Identity},{response})"); - } -} diff --git a/Rezoom.Execution/ExecutionContext.cs b/Rezoom.Execution/ExecutionContext.cs deleted file mode 100644 index 9d82cb9..0000000 --- a/Rezoom.Execution/ExecutionContext.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System; -using System.Threading.Tasks; -using Microsoft.FSharp.Core; - -namespace Rezoom.Execution -{ - /// - /// Handles execution of an by stepping through it and running its pending - /// s with caching and deduplication. - /// - public class ExecutionContext : IDisposable - { - private readonly IExecutionLog _log; - private readonly DefaultServiceContext _serviceContext; - private readonly ResponseCache _responseCache = new ResponseCache(); - - /// - /// Create an execution context by giving it an to provide - /// services required by the s that it'll be responsible for executing. - /// - /// - /// - public ExecutionContext(ServiceFactory serviceFactory, IExecutionLog log = null) - { - _serviceContext = new DefaultServiceContext(serviceFactory); - _log = log; - } - - private async Task> ExecutePending - ( Batch pending - , FSharpFunc, Plan> resume - ) - { - _log?.OnStepStart(); - _serviceContext.BeginStep(); - Batch responses; - try - { - var stepContext = new StepContext(_serviceContext, _log, _responseCache); - var retrievals = pending.MapCS - (request => stepContext.AddRequest(request)); - await stepContext.Execute().ConfigureAwait(false); - responses = retrievals.MapCS(retrieve => retrieve()); - } - finally - { - _serviceContext.EndStep(); - _log?.OnStepFinish(); - } - return resume.Invoke(responses); - } - - /// - /// Asynchronously run the given to completion. - /// - /// - /// - /// - public async Task Execute(Plan task) - { - while (true) - { - if (task.IsStep) - { - var step = (Plan.Step)task; - task = await ExecutePending(step.Item.Item1, step.Item.Item2).ConfigureAwait(false); - } - else - { - return ((Plan.Result)task).Item; - } - } - } - - public TService GetService() => _serviceContext.GetService(); - public void Dispose() => _serviceContext.Dispose(); - } -} diff --git a/Rezoom.Execution/IExecutionLog.cs b/Rezoom.Execution/IExecutionLog.cs deleted file mode 100644 index 6f67299..0000000 --- a/Rezoom.Execution/IExecutionLog.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; - -namespace Rezoom.Execution -{ - /// - /// Receives all the data requests and their responses - /// that occur during execution of a data task. - /// - /// - /// Execution log implementations are responsible for catching their own exceptions. - /// They should never attempt to mutate the responses passed through them. - /// - public interface IExecutionLog - { - /// - /// Called when an execution step begins. - /// - void OnStepStart(); - /// - /// Called when an execution step finishes. - /// - void OnStepFinish(); - /// - /// Called when a data request is prepared for execution. - /// - /// - void OnPrepare(Errand request); - /// - /// Called when a data request's prepare method throws an exception. - /// - /// - void OnPrepareFailure(Exception exception); - /// - /// Called when a data request has finished executing. - /// - /// - /// - void OnComplete(Errand request, DataResponse response); - } -} diff --git a/Rezoom.Execution/Properties/AssemblyInfo.cs b/Rezoom.Execution/Properties/AssemblyInfo.cs deleted file mode 100644 index 94a361a..0000000 --- a/Rezoom.Execution/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; - -// General Information about an assembly is controlled through the following -// set of attributes. Change these attribute values to modify the information -// associated with an assembly. -[assembly: AssemblyTitle("Rezoom.Execution")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("Rezoom.Execution")] -[assembly: AssemblyCopyright("Copyright © Robert Peele 2016")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// Setting ComVisible to false makes the types in this assembly not visible -// to COM components. If you need to access a type in this assembly from -// COM, set the ComVisible attribute to true on that type. -[assembly: ComVisible(false)] - -// The following GUID is for the ID of the typelib if this project is exposed to COM -[assembly: Guid("9db721d3-da97-4be3-b60b-9b7a682e803e")] - -// Version information for an assembly consists of the following four values: -// -// Major Version -// Minor Version -// Build Number -// Revision -// -// You can specify all the values or you can default the Build and Revision Numbers -// by using the '*' as shown below: -// [assembly: AssemblyVersion("1.0.*")] -[assembly: AssemblyVersion("1.0.0.0")] -[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/Rezoom.Execution/Rezoom.Execution.csproj b/Rezoom.Execution/Rezoom.Execution.csproj deleted file mode 100644 index d5b9730..0000000 --- a/Rezoom.Execution/Rezoom.Execution.csproj +++ /dev/null @@ -1,64 +0,0 @@ - - - - - Debug - AnyCPU - {9DB721D3-DA97-4BE3-B60B-9B7A682E803E} - Library - Properties - Rezoom.Execution - Rezoom.Execution - v4.6 - 512 - - - true - full - false - bin\Debug\ - DEBUG;TRACE - prompt - 4 - - - pdbonly - true - bin\Release\ - TRACE - prompt - 4 - - - - - - - - - - - - - - - - - - - - - - {d98acbeb-a039-4340-a7c5-6ed2b677268b} - Rezoom - - - - - \ No newline at end of file diff --git a/Rezoom.Execution/StepContext.cs b/Rezoom.Execution/StepContext.cs deleted file mode 100644 index 3ed8517..0000000 --- a/Rezoom.Execution/StepContext.cs +++ /dev/null @@ -1,162 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.FSharp.Core; - -namespace Rezoom.Execution -{ - internal class StepContext - { - private readonly ServiceContext _serviceContext; - private readonly IExecutionLog _executionLog; - private readonly ResponseCache _cache; - private readonly List> _unsequenced = new List>(); - private readonly Dictionary>> _sequenceGroups - = new Dictionary>>(); - private readonly Dictionary> _deduped - = new Dictionary>(); - - public StepContext(ServiceContext serviceContext, IExecutionLog executionLog, ResponseCache cache) - { - _serviceContext = serviceContext; - _executionLog = executionLog; - _cache = cache; - } - - private class PendingResult - { - private DataResponse _result; - public DataResponse Get() => _result; - public async Task Run(Errand request, IExecutionLog log, FSharpFunc> prepared) - { - try - { - _result = DataResponse.NewRetrievalSuccess(await prepared.Invoke(null).ConfigureAwait(false)); - } - catch (Exception ex) - { - _result = DataResponse.NewRetrievalException(ex); - } - log?.OnComplete(request, _result); - } - } - - private Func AddRequestToRun(Errand request) - { - if (request.Mutation) - { - _cache.Invalidate(request.DataSource); - } - var eventual = new PendingResult(); - FSharpFunc> prepared; - try - { - prepared = request.InternalPrepare(_serviceContext); - _executionLog?.OnPrepare(request); - } - catch (Exception ex) - { - _executionLog?.OnPrepareFailure(ex); - return () => DataResponse.NewRetrievalException(ex); - } - Func run = () => eventual.Run(request, _executionLog, prepared); - if (!request.Parallelizable) - { - _unsequenced.Add(run); - } - else - { - var sequenceGroupId = request.SequenceGroup; - if (sequenceGroupId == null) - { - _unsequenced.Add(run); - } - else - { - List> sequenceGroup; - if (!_sequenceGroups.TryGetValue(sequenceGroupId, out sequenceGroup)) - { - sequenceGroup = new List>(); - _sequenceGroups[sequenceGroupId] = sequenceGroup; - } - sequenceGroup.Add(run); - } - } - return eventual.Get; - } - - public Func AddRequest(Errand request) - { - var identity = request.Identity; - // If this request is not cachable, we have to run it. - if (!request.Idempotent || identity == null) return AddRequestToRun(request); - // Otherwise... - var dataSource = request.DataSource; - // Check for a cached result. - object value = null; - if (_cache.TryGetValue(dataSource, identity, ref value)) - { - return () => DataResponse.NewRetrievalSuccess(value); - } - // Check for de-duplication of this request within this step. - Func existing; - if (_deduped.TryGetValue(identity, out existing)) return existing; - // Otherwise, we really need to run this request. - var toRun = AddRequestToRun(request); - _deduped[identity] = toRun; - return () => - { - var result = toRun(); - if (result.IsRetrievalSuccess) - { - _cache.Store(dataSource, identity, ((DataResponse.RetrievalSuccess)result).Item); - } - return result; - }; - } - - private static async Task ExecuteSequentialGroup(Task pending, IEnumerator> rest) - { - await pending.ConfigureAwait(false); - while (rest.MoveNext()) - { - await rest.Current().ConfigureAwait(false); - } - } - - private static Task ExecuteSequentialGroup(IEnumerable> tasks) - { - using (var enumerator = tasks.GetEnumerator()) - { - while (enumerator.MoveNext()) - { - var task = enumerator.Current(); - if (task.IsCompleted) continue; - return ExecuteSequentialGroup(task, enumerator); - } - } - return Task.CompletedTask; - } - - public Task Execute() - { - var tasks = new Task[_sequenceGroups.Values.Count + _unsequenced.Count]; - var i = 0; - var allDone = true; - foreach (var sgroup in _sequenceGroups.Values) - { - var task = ExecuteSequentialGroup(sgroup); - tasks[i++] = task; - allDone &= task.IsCompleted; - } - foreach (var unseq in _unsequenced) - { - var task = unseq(); - tasks[i++] = task; - allDone &= task.IsCompleted; - } - return allDone ? Task.CompletedTask : Task.WhenAll(tasks); - } - } -} \ No newline at end of file diff --git a/Rezoom.IPGeo.Test/Environment.fs b/Rezoom.IPGeo.Test/Environment.fs index 5dfc34f..96b784b 100644 --- a/Rezoom.IPGeo.Test/Environment.fs +++ b/Rezoom.IPGeo.Test/Environment.fs @@ -11,31 +11,27 @@ type 'a ExpectedResult = | Value of 'a type 'a TestTask = - { - Task : 'a Plan + { Task : 'a Plan Batches : string list list ExpectedResult : 'a ExpectedResult } type TestExecutionLog() = + inherit ExecutionLog() let batches = new ResizeArray() member __.Batches = batches |> Seq.map List.ofSeq |> List.ofSeq - interface IExecutionLog with - member this.OnComplete(request, response) = () - member this.OnPrepareFailure(exn) = () - member this.OnPrepare(request) = - batches.[batches.Count - 1].Add(string request.Identity) - member this.OnStepFinish() = () - member this.OnStepStart() = batches.Add(new ResizeArray<_>()) + override this.OnPreparedErrand(errand) = + batches.[batches.Count - 1].Add(string errand.CacheInfo.Identity) + override this.OnBeginStep() = batches.Add(new ResizeArray<_>()) let test (task : 'a TestTask) = let log = new TestExecutionLog() - use context = - new ExecutionContext(new ZeroServiceFactory(), log) + let config = { ExecutionConfig.Default with Log = log } let answer = try - context.Execute(task.Task).Result |> Some + let task = execute config task.Task + Some task.Result with | ex -> match task.ExpectedResult with diff --git a/Rezoom.IPGeo.Test/Rezoom.IPGeo.Test.fsproj b/Rezoom.IPGeo.Test/Rezoom.IPGeo.Test.fsproj index a7370ca..4e0e817 100644 --- a/Rezoom.IPGeo.Test/Rezoom.IPGeo.Test.fsproj +++ b/Rezoom.IPGeo.Test/Rezoom.IPGeo.Test.fsproj @@ -1,4 +1,4 @@ - + @@ -33,22 +33,50 @@ 3 bin\Release\Rezoom.IPGeo.Test.XML + + 11 + + + + + $(MSBuildExtensionsPath32)\..\Microsoft SDKs\F#\3.0\Framework\v4.0\Microsoft.FSharp.Targets + + + + + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\FSharp\Microsoft.FSharp.Targets + + + + - + + + + + + + + + ..\packages\FSharp.Core.4.0.0.1\lib\net40\FSharp.Core.dll + True + + + ..\packages\FsUnit.2.3.2\lib\net45\FsUnit.NUnit.dll + True + - + + ..\packages\FsUnit.2.3.2\lib\net45\NHamcrest.dll + True + + + ..\packages\NUnit.3.5.0\lib\net45\nunit.framework.dll True - - - - - - - Rezoom.IPGeo {ceb9e01b-71c6-468b-8c3e-a1617f036370} @@ -59,28 +87,7 @@ {d98acbeb-a039-4340-a7c5-6ed2b677268b} True - - Rezoom.Execution - {9db721d3-da97-4be3-b60b-9b7a682e803e} - True - - - 11 - - - - - $(MSBuildExtensionsPath32)\..\Microsoft SDKs\F#\3.0\Framework\v4.0\Microsoft.FSharp.Targets - - - - - $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\FSharp\Microsoft.FSharp.Targets - - - - - + \ No newline at end of file diff --git a/Rezoom.IPGeo.Test/TestGeo.fs b/Rezoom.IPGeo.Test/TestGeo.fs index dcbd93a..7834e35 100644 --- a/Rezoom.IPGeo.Test/TestGeo.fs +++ b/Rezoom.IPGeo.Test/TestGeo.fs @@ -1,65 +1,59 @@ -namespace Rezoom.IPGeo.Test +module Rezoom.IPGeo.Test.TestGeo open Rezoom open Rezoom.IPGeo -open Microsoft.VisualStudio.TestTools.UnitTesting +open NUnit.Framework +open FsUnit -[] -type TestGeo() = - static let googleDNS = "8.8.8.8" - static let googleDNS2 = "8.8.4.4" - static let openDNS = "208.67.220.220" - static let fb = "2a03:2880:2110:df07:face:b00c::1" +let googleDNS = "8.8.8.8" +let googleDNS2 = "8.8.4.4" +let openDNS = "208.67.220.220" +let fb = "2a03:2880:2110:df07:face:b00c::1" - static let googleDNSISP = "Google" - static let googleDNS2ISP = "Level 3 Communications" - static let openDNSISP = "OpenDNS, LLC" - static let fbISP = "Facebook" +let googleDNSISP = "Google" +let googleDNS2ISP = "Level 3 Communications" +let openDNSISP = "OpenDNS, LLC" +let fbISP = "Facebook" - [] - member __.TestBatches() = - { - Task = - plan { - let! g1, g2 = Geo.Locate(googleDNS), Geo.Locate(googleDNS2) - let! o = Geo.Locate(openDNS) - return g1.Isp, g2.Isp, o.Isp - } - Batches = - [ - [googleDNS; googleDNS2] - [openDNS] - ] - ExpectedResult = Value (googleDNSISP, googleDNS2ISP, openDNSISP) - } |> test +[] +let ``batches`` () = + { Task = + plan { + let! g1, g2 = Geo.Locate(googleDNS), Geo.Locate(googleDNS2) + let! o = Geo.Locate(openDNS) + return g1.Isp, g2.Isp, o.Isp + } + Batches = + [ [googleDNS; googleDNS2] + [openDNS] + ] + ExpectedResult = Value (googleDNSISP, googleDNS2ISP, openDNSISP) + } |> test - [] - member __.TestCaching() = - { - Task = - plan { - let! g1 = Geo.Locate(googleDNS) - let! g2, o = Geo.Locate(googleDNS), Geo.Locate(openDNS) - return g1.Isp, g2.Isp, o.Isp - } - Batches = - [ - [googleDNS] - [openDNS] // note that we don't request google DNS again - ] - ExpectedResult = Value (googleDNSISP, googleDNSISP, openDNSISP) - } |> test +[] +let ``caching`` () = + { Task = + plan { + let! g1 = Geo.Locate(googleDNS) + let! g2, o = Geo.Locate(googleDNS), Geo.Locate(openDNS) + return g1.Isp, g2.Isp, o.Isp + } + Batches = + [ [googleDNS] + [] // empty batch because we defer openDNS after getting google from the cache + [openDNS] + ] + ExpectedResult = Value (googleDNSISP, googleDNSISP, openDNSISP) + } |> test - [] - member __.TestDedup() = - { - Task = - plan { - let! g1, fb1, fb2 = Geo.Locate(googleDNS), Geo.Locate(fb), Geo.Locate(fb) - return g1.Isp, fb1.Isp, fb2.Isp - } - Batches = - [ - [googleDNS; fb] // note that we don't request fb twice - ] - ExpectedResult = Value (googleDNSISP, fbISP, fbISP) - } |> test \ No newline at end of file +[] +let ``dedup`` () = + { Task = + plan { + let! g1, fb1, fb2 = Geo.Locate(googleDNS), Geo.Locate(fb), Geo.Locate(fb) + return g1.Isp, fb1.Isp, fb2.Isp + } + Batches = + [ [googleDNS; fb] // note that we don't request fb twice + ] + ExpectedResult = Value (googleDNSISP, fbISP, fbISP) + } |> test \ No newline at end of file diff --git a/Rezoom.IPGeo.Test/app.config b/Rezoom.IPGeo.Test/app.config new file mode 100644 index 0000000..c130c89 --- /dev/null +++ b/Rezoom.IPGeo.Test/app.config @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/Rezoom.IPGeo.Test/packages.config b/Rezoom.IPGeo.Test/packages.config new file mode 100644 index 0000000..02f2331 --- /dev/null +++ b/Rezoom.IPGeo.Test/packages.config @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Rezoom.IPGeo/Geo.cs b/Rezoom.IPGeo/Geo.cs index fe23ab6..7e22832 100644 --- a/Rezoom.IPGeo/Geo.cs +++ b/Rezoom.IPGeo/Geo.cs @@ -1,4 +1,8 @@ -namespace Rezoom.IPGeo +using Microsoft.FSharp.Core; +using Rezoom.IPGeo.Internals; +using Rezoom.CS; + +namespace Rezoom.IPGeo { public static class Geo { @@ -7,6 +11,6 @@ public static class Geo /// /// /// - public static Plan Locate(string ip) => new GeoErrand(ip).ToPlan(); + public static FSharpFunc> Locate(string ip) => new GeoErrand(ip).ToPlan(); } } diff --git a/Rezoom.IPGeo/Internals/GeoBatch.cs b/Rezoom.IPGeo/Internals/GeoBatch.cs index 924d1db..a83b956 100644 --- a/Rezoom.IPGeo/Internals/GeoBatch.cs +++ b/Rezoom.IPGeo/Internals/GeoBatch.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace Rezoom.IPGeo @@ -14,7 +15,7 @@ internal class GeoBatch private Task> _runningTask; - public Func> Prepare(GeoQuery query) + public Func> Prepare(GeoQuery query) { if (_runningTask != null) { @@ -22,7 +23,7 @@ public Func> Prepare(GeoQuery query) } var index = _queries.Count; _queries.Add(query); - return () => GetResult(index); + return _ => GetResult(index); } private static async Task> GetResults(List requests) diff --git a/Rezoom.IPGeo/Internals/GeoRequest.cs b/Rezoom.IPGeo/Internals/GeoRequest.cs index 622f832..1c10ba7 100644 --- a/Rezoom.IPGeo/Internals/GeoRequest.cs +++ b/Rezoom.IPGeo/Internals/GeoRequest.cs @@ -1,8 +1,21 @@ using System; +using System.Threading; using System.Threading.Tasks; +using Rezoom; -namespace Rezoom.IPGeo +namespace Rezoom.IPGeo.Internals { + internal class GeoCacheInfo : CacheInfo + { + private readonly string _ip; + public GeoCacheInfo(string ip) + { + _ip = ip; + } + public override object Category => typeof(GeoCacheInfo).Assembly; + public override object Identity => _ip; + public override bool Cacheable => true; + } /// /// Implements for looking up for an IP address. /// @@ -15,21 +28,11 @@ public GeoErrand(string ip) _ip = ip; } - // Requests for the same ip can be deduped/cached. - public override object Identity => _ip; - // The cache will only be cleared when there is a mutation with the same datasource. - public override object DataSource => typeof(GeoBatch); - // Requests with the same sequence group will be prepared and executed sequentially, so GeoBatch doesn't need - // to be thread-safe. - public override object SequenceGroup => typeof(GeoBatch); - // Looking up an IP 3 times is the same as looking it up once and returning the result 3 times. - public override bool Idempotent => true; - // Looking up an IP doesn't change anything, so it shouldn't invalidate any caches. - public override bool Mutation => false; + public override CacheInfo CacheInfo => new GeoCacheInfo(_ip); - public override Func> Prepare(ServiceContext context) + public override Func> Prepare(ServiceContext context) { - var batch = context.GetService>().Service; + var batch = context.GetService, GeoBatch>(); return batch.Prepare(new GeoQuery { Query = _ip }); } } diff --git a/Rezoom.IPGeo/Rezoom.IPGeo.csproj b/Rezoom.IPGeo/Rezoom.IPGeo.csproj index 2b94d88..927bcae 100644 --- a/Rezoom.IPGeo/Rezoom.IPGeo.csproj +++ b/Rezoom.IPGeo/Rezoom.IPGeo.csproj @@ -1,4 +1,4 @@ - + @@ -30,6 +30,7 @@ 4 + ..\packages\Newtonsoft.Json.9.0.1\lib\net45\Newtonsoft.Json.dll True @@ -60,10 +61,6 @@ {d98acbeb-a039-4340-a7c5-6ed2b677268b} Rezoom - - {9db721d3-da97-4be3-b60b-9b7a682e803e} - Rezoom.Execution - - + \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/AST.fs b/Rezoom.SQL.Compiler/AST.fs new file mode 100644 index 0000000..1bb2229 --- /dev/null +++ b/Rezoom.SQL.Compiler/AST.fs @@ -0,0 +1,773 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic +open System.Globalization + +type NumericLiteral = + | IntegerLiteral of uint64 + | FloatLiteral of float + +type SignedNumericLiteral = + { Sign : int // -1, 0, 1 + Value : NumericLiteral + } + +type Literal = + | NullLiteral + | BooleanLiteral of bool + | StringLiteral of string + | BlobLiteral of byte array + | NumericLiteral of NumericLiteral + | DateTimeLiteral of DateTime + | DateTimeOffsetLiteral of DateTimeOffset + +type SavepointName = Name + +type Alias = Name option + +type IntegerSize = + | Integer8 + | Integer16 + | Integer32 + | Integer64 + +type FloatSize = + | Float32 + | Float64 + +type TypeName = + | StringTypeName of maxLength : int option + | BinaryTypeName of maxLength : int option + | IntegerTypeName of IntegerSize + | FloatTypeName of FloatSize + | DecimalTypeName + | BooleanTypeName + | DateTimeTypeName + | DateTimeOffsetTypeName + +[] +[] +type ObjectName<'t> = + { Source : SourceInfo + SchemaName : Name option + ObjectName : Name + Info : 't + } + override this.ToString() = + string <| + match this.SchemaName with + | None -> this.ObjectName + | Some schema -> schema + "." + this.ObjectName + member this.Equals(other) = + this.SchemaName = other.SchemaName + && this.ObjectName = other.ObjectName + override this.Equals(other) = + match other with + | :? ObjectName<'t> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = this.SchemaName +@+ this.ObjectName + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + +type ColumnName<'t> = + { Table : ObjectName<'t> option + ColumnName : Name + } + override this.ToString() = + string <| + match this.Table with + | None -> this.ColumnName + | Some tbl -> string tbl + "." + this.ColumnName + +type BindParameter = + | NamedParameter of Name // prefix character : or $ or @ is ignored + +type BinaryOperator = + | Concatenate + | Multiply + | Divide + | Modulo + | Add + | Subtract + | BitShiftLeft + | BitShiftRight + | BitAnd + | BitOr + | LessThan + | LessThanOrEqual + | GreaterThan + | GreaterThanOrEqual + | Equal + | NotEqual + | Is + | IsNot + | And + | Or + /// True if this operator expects boolean inputs and has a boolean output. + member this.IsLogicalOperator = + match this with + | And + | Or -> true + | _ -> false + +type UnaryOperator = + | Negative + | Not + | BitNot + | NotNull + | IsNull + /// True if this operator expects boolean inputs and has a boolean output. + member this.IsLogicalOperator = + match this with + | Not -> true + | _ -> false + +type SimilarityOperator = + | Like + | Glob + | Match + | Regexp + +type Raise = + | RaiseIgnore + | RaiseRollback of string + | RaiseAbort of string + | RaiseFail of string + +type ExprType<'t, 'e> = + | LiteralExpr of Literal + | BindParameterExpr of BindParameter + | ColumnNameExpr of ColumnName<'t> + | CastExpr of CastExpr<'t, 'e> + | CollateExpr of CollationExpr<'t, 'e> + | FunctionInvocationExpr of FunctionInvocationExpr<'t, 'e> + | SimilarityExpr of SimilarityExpr<'t, 'e> + | BinaryExpr of BinaryExpr<'t, 'e> + | UnaryExpr of UnaryExpr<'t, 'e> + | BetweenExpr of BetweenExpr<'t, 'e> + | InExpr of InExpr<'t, 'e> + | ExistsExpr of SelectStmt<'t, 'e> + | CaseExpr of CaseExpr<'t, 'e> + | ScalarSubqueryExpr of SelectStmt<'t, 'e> + | RaiseExpr of Raise + +and + [] + [] + Expr<'t, 'e> = + { Value : ExprType<'t, 'e> + Info : 'e + Source : SourceInfo + } + member this.Equals(other) = this.Value = other.Value + override this.Equals(other) = + match other with + | :? Expr<'t, 'e> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = this.Value.GetHashCode() + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + +and InExpr<'t, 'e> = + { Invert : bool + Input : Expr<'t, 'e> + Set : InSet<'t, 'e> WithSource + } + +and CollationExpr<'t, 'e> = + { Input : Expr<'t, 'e> + Collation : Name + } + +and BinaryExpr<'t, 'e> = + { Left : Expr<'t, 'e> + Operator : BinaryOperator + Right : Expr<'t, 'e> + } + +and UnaryExpr<'t, 'e> = + { Operator : UnaryOperator + Operand : Expr<'t, 'e> + } + +and SimilarityExpr<'t, 'e> = + { Invert : bool + Operator : SimilarityOperator + Input : Expr<'t, 'e> + Pattern : Expr<'t, 'e> + Escape : Expr<'t, 'e> option + } + +and BetweenExpr<'t, 'e> = + { Invert : bool + Input : Expr<'t, 'e> + Low : Expr<'t, 'e> + High : Expr<'t, 'e> + } + +and CastExpr<'t, 'e> = + { Expression : Expr<'t, 'e> + AsType : TypeName + } + +and TableInvocation<'t, 'e> = + { Table : ObjectName<'t> + Arguments : Expr<'t, 'e> array option // we use an option to distinguish between schema.table and schema.table() + } + +and FunctionInvocationExpr<'t, 'e> = + { FunctionName : Name + Arguments : FunctionArguments<'t, 'e> + } + +and CaseExpr<'t, 'e> = + { Input : Expr<'t, 'e> option + Cases : (Expr<'t, 'e> * Expr<'t, 'e>) array + Else : Expr<'t, 'e> option WithSource + } + +and Distinct = | Distinct + +and DistinctColumns = + | DistinctColumns + | AllColumns + +and FunctionArguments<'t, 'e> = + | ArgumentWildcard + | ArgumentList of (Distinct option * Expr<'t, 'e> array) + +and InSet<'t, 'e> = + | InExpressions of Expr<'t, 'e> array + | InSelect of SelectStmt<'t, 'e> + | InTable of TableInvocation<'t, 'e> + | InParameter of BindParameter + +and + [] + [] + SelectStmtCore<'t, 'e> = + { With : WithClause<'t, 'e> option + Compound : CompoundExpr<'t, 'e> + OrderBy : OrderingTerm<'t, 'e> array option + Limit : Limit<'t, 'e> option + Info : 't + } + member this.Equals(other) = + this.With = other.With + && this.Compound = other.Compound + && this.OrderBy = other.OrderBy + && this.Limit = other.Limit + override this.Equals(other) = + match other with + | :? SelectStmtCore<'t, 'e> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = + this.With + +@+ this.Compound + +@+ this.OrderBy + +@+ this.Limit + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + + +and SelectStmt<'t, 'e> = SelectStmtCore<'t, 'e> WithSource + +and WithClause<'t, 'e> = + { Recursive : bool + Tables : CommonTableExpression<'t, 'e> array + } + +and + [] + [] + CommonTableExpression<'t, 'e> = + { Name : Name + ColumnNames : Name WithSource array WithSource option + AsSelect : SelectStmt<'t, 'e> + Info : 't + } + member this.Equals(other) = + this.Name = other.Name + && this.ColumnNames = other.ColumnNames + && this.AsSelect = other.AsSelect + override this.Equals(other) = + match other with + | :? CommonTableExpression<'t, 'e> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = + this.Name + +@+ this.ColumnNames + +@+ this.AsSelect + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + +and OrderDirection = + | Ascending + | Descending + +and OrderingTerm<'t, 'e> = + { By : Expr<'t, 'e> + Direction : OrderDirection + } + +and Limit<'t, 'e> = + { Limit : Expr<'t, 'e> + Offset : Expr<'t, 'e> option + } + +and CompoundExprCore<'t, 'e> = + | CompoundTerm of CompoundTerm<'t, 'e> + | Union of CompoundExpr<'t, 'e> * CompoundTerm<'t, 'e> + | UnionAll of CompoundExpr<'t, 'e> * CompoundTerm<'t, 'e> + | Intersect of CompoundExpr<'t, 'e> * CompoundTerm<'t, 'e> + | Except of CompoundExpr<'t, 'e> * CompoundTerm<'t, 'e> + member this.Info = + match this with + | CompoundTerm term -> term.Info + | Union (ex, _) + | UnionAll (ex, _) + | Intersect (ex, _) + | Except (ex, _) -> ex.Value.Info + +and CompoundExpr<'t, 'e> = CompoundExprCore<'t, 'e> WithSource + +and CompoundTermCore<'t, 'e> = + | Values of Expr<'t, 'e> array WithSource array + | Select of SelectCore<'t, 'e> + +and + [] + [] + CompoundTerm<'t, 'e> = + { Value : CompoundTermCore<'t, 'e> + Source : SourceInfo + Info : 't + } + member this.Equals(other) = other.Value = this.Value + override this.Equals(other) = + match other with + | :? CompoundTerm<'t, 'e> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = this.Value.GetHashCode() + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + +and + [] + [] + SelectCore<'t, 'e> = + { Columns : ResultColumns<'t, 'e> + From : TableExpr<'t, 'e> option + Where : Expr<'t, 'e> option + GroupBy : GroupBy<'t, 'e> option + Info : 't + } + member this.Equals(other) = + this.Columns = other.Columns + && this.From = other.From + && this.Where = other.Where + && this.GroupBy = other.GroupBy + override this.Equals(other) = + match other with + | :? SelectCore<'t, 'e> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = + this.Columns + +@+ this.From + +@+ this.Where + +@+ this.GroupBy + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + + +and GroupBy<'t, 'e> = + { By : Expr<'t, 'e> array + Having : Expr<'t, 'e> option + } + +and ResultColumns<'t, 'e> = + { Distinct : DistinctColumns option + Columns : ResultColumn<'t, 'e> array + } + +and ResultColumnNavCardinality = + | NavOne + | NavOptional + | NavMany + member this.Separator = + match this with + | NavOne -> "$" + | NavOptional -> "?$" + | NavMany -> "*$" + +and ResultColumnNav<'t, 'e> = + { Cardinality : ResultColumnNavCardinality + Name : Name + Columns : ResultColumn<'t, 'e> array + } + +and ResultColumnCase<'t, 'e> = + | ColumnsWildcard + | TableColumnsWildcard of Name + | Column of Expr<'t, 'e> * Alias + | ColumnNav of ResultColumnNav<'t, 'e> + member this.AssumeColumn() = + match this with + | Column (expr, alias) -> expr, alias + | _ -> failwith "BUG: wildcard was assumed to be a single column (should've been expanded by now)" + +and ResultColumn<'t, 'e> = + { Case : ResultColumnCase<'t, 'e> + Source : SourceInfo + } + +and IndexHint = + | IndexedBy of Name + | NotIndexed + +and QualifiedTableName<'t> = + { TableName : ObjectName<'t> + IndexHint : IndexHint option + } + +and TableOrSubqueryType<'t, 'e> = + | Table of TableInvocation<'t, 'e> * IndexHint option // note: an index hint is invalid if the table has args + | Subquery of SelectStmt<'t, 'e> + +and + [] + [] + TableOrSubquery<'t, 'e> = + { Table : TableOrSubqueryType<'t, 'e> + Alias : Name option + Info : 't + } + member this.Equals(other) = + this.Table = other.Table + && this.Alias = other.Alias + override this.Equals(other) = + match other with + | :? TableOrSubquery<'t, 'e> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = this.Table +@+ this.Alias + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + +and JoinType = + | Inner + | LeftOuter + | Cross + | Natural of JoinType + member this.IsOuter = this = LeftOuter + +and JoinConstraint<'t, 'e> = + | JoinOn of Expr<'t, 'e> + | JoinUnconstrained + +and Join<'t, 'e> = + { JoinType : JoinType + LeftTable : TableExpr<'t, 'e> + RightTable : TableExpr<'t, 'e> + Constraint : JoinConstraint<'t, 'e> + } + +and TableExprCore<'t, 'e> = + | TableOrSubquery of TableOrSubquery<'t, 'e> + | Join of Join<'t, 'e> + +and TableExpr<'t, 'e> = TableExprCore<'t, 'e> WithSource + +type ForeignKeyEvent = + | OnDelete + | OnUpdate + +type ForeignKeyEventHandler = + | SetNull + | SetDefault + | Cascade + | Restrict + | NoAction + +type ForeignKeyRule = + | MatchRule of Name + | EventRule of (ForeignKeyEvent * ForeignKeyEventHandler) + +type ForeignKeyDeferClause = + { Deferrable : bool + InitiallyDeferred : bool option + } + +type ForeignKeyClause<'t> = + { ReferencesTable : ObjectName<'t> + ReferencesColumns : Name WithSource array + Rules : ForeignKeyRule array + Defer : ForeignKeyDeferClause option + } + +type PrimaryKeyClause = + { Order : OrderDirection + AutoIncrement : bool + } + +type ColumnConstraintType<'t, 'e> = + | NullableConstraint + | PrimaryKeyConstraint of PrimaryKeyClause + | UniqueConstraint + | DefaultConstraint of Expr<'t, 'e> + | CollateConstraint of Name + | ForeignKeyConstraint of ForeignKeyClause<'t> + member this.DefaultName(columnName : Name) = + match this with + | NullableConstraint -> columnName + "__NULL" + | PrimaryKeyConstraint _ -> columnName + "__PK" + | UniqueConstraint -> columnName + "__UNIQUE" + | DefaultConstraint _ -> columnName + "__DEFAULT" + | CollateConstraint _ -> columnName + "__COLLATION" + | ForeignKeyConstraint fk -> + columnName + + "__FK__" + + fk.ReferencesTable.ObjectName.Value + + "__" + + String.concat "_" [ for c in fk.ReferencesColumns -> c.Value.Value ] + +type ColumnConstraint<'t, 'e> = + { Name : Name + ColumnConstraintType : ColumnConstraintType<'t, 'e> + } + +type ColumnDef<'t, 'e> = + { Name : Name + Type : TypeName + Constraints : ColumnConstraint<'t, 'e> array + } + member this.Nullable = + this.Constraints + |> Array.exists (function | { ColumnConstraintType = NullableConstraint } -> true | _ -> false) + +type AlterTableAlteration<'t, 'e> = + | RenameTo of Name + | AddColumn of ColumnDef<'t, 'e> + +type AlterTableStmt<'t, 'e> = + { Table : ObjectName<'t> + Alteration : AlterTableAlteration<'t, 'e> + } + +type TableIndexConstraintType = + | PrimaryKey + | Unique + +type TableIndexConstraintClause<'t, 'e> = + { Type : TableIndexConstraintType + IndexedColumns : (Name * OrderDirection) array + } + +type TableConstraintType<'t, 'e> = + | TableIndexConstraint of TableIndexConstraintClause<'t, 'e> + | TableForeignKeyConstraint of Name WithSource array * ForeignKeyClause<'t> + | TableCheckConstraint of Expr<'t, 'e> + member this.DefaultName() = + match this with + | TableIndexConstraint con -> + String.concat "_" [ for name, _ in con.IndexedColumns -> name.Value ] + + "__" + + (match con.Type with + | PrimaryKey -> "PK" + | Unique -> "UNIQUE") + | TableForeignKeyConstraint (names, fk) -> + String.concat "_" [ for name in names -> name.Value.Value ] + + "__FK__" + + fk.ReferencesTable.ObjectName.Value + + "__" + + String.concat "_" [ for c in fk.ReferencesColumns -> c.Value.Value ] + | TableCheckConstraint _ -> "CHECK" + +type TableConstraint<'t, 'e> = + { Name : Name + TableConstraintType : TableConstraintType<'t, 'e> + } + +type CreateTableDefinition<'t, 'e> = + { Columns : ColumnDef<'t, 'e> array + Constraints : TableConstraint<'t, 'e> array + WithoutRowId : bool + } + member this.AllConstraints() = + seq { + for column in this.Columns do + for constr in column.Constraints -> + constr.Name, Set.singleton column.Name + for constr in this.Constraints -> + constr.Name, + match constr.TableConstraintType with + | TableIndexConstraint constr -> constr.IndexedColumns |> Seq.map fst |> Set.ofSeq + | TableForeignKeyConstraint (names, _) -> names |> Seq.map (fun v -> v.Value) |> Set.ofSeq + | TableCheckConstraint _ -> Set.empty + } + +type CreateTableAs<'t, 'e> = + | CreateAsDefinition of CreateTableDefinition<'t, 'e> + | CreateAsSelect of SelectStmt<'t, 'e> + +type CreateTableStmt<'t, 'e> = + { Temporary : bool + Name : ObjectName<'t> + As : CreateTableAs<'t, 'e> + } + +type CreateIndexStmt<'t, 'e> = + { Unique : bool + IndexName : ObjectName<'t> + TableName : ObjectName<'t> + IndexedColumns : (Name * OrderDirection) array + Where : Expr<'t, 'e> option + } + +type DeleteStmt<'t, 'e> = + { With : WithClause<'t, 'e> option + DeleteFrom : QualifiedTableName<'t> + Where : Expr<'t, 'e> option + OrderBy : OrderingTerm<'t, 'e> array option + Limit : Limit<'t, 'e> option + } + +type UpdateOr = + | UpdateOrRollback + | UpdateOrAbort + | UpdateOrReplace + | UpdateOrFail + | UpdateOrIgnore + +type UpdateStmt<'t, 'e> = + { With : WithClause<'t, 'e> option + UpdateTable : QualifiedTableName<'t> + Or : UpdateOr option + Set : (Name WithSource * Expr<'t, 'e>) array + Where : Expr<'t, 'e> option + OrderBy : OrderingTerm<'t, 'e> array option + Limit : Limit<'t, 'e> option + } + +type InsertOr = + | InsertOrRollback + | InsertOrAbort + | InsertOrReplace + | InsertOrFail + | InsertOrIgnore + +type InsertStmt<'t, 'e> = + { With : WithClause<'t, 'e> option + Or : InsertOr option + InsertInto : ObjectName<'t> + Columns : Name WithSource array option + Data : SelectStmt<'t, 'e> option // either select/values, or "default values" if none + } + +type CreateViewStmt<'t, 'e> = + { Temporary : bool + ViewName : ObjectName<'t> + ColumnNames : Name WithSource array option + AsSelect : SelectStmt<'t, 'e> + } + +type DropObjectType = + | DropIndex + | DropTable + | DropView + +type DropObjectStmt<'t> = + { Drop : DropObjectType + ObjectName : ObjectName<'t> + } + +type VendorStmtFragment<'t, 'e> = + | VendorEmbeddedExpr of Expr<'t, 'e> + | VendorRaw of string + +type VendorStmt<'t, 'e> = + { VendorName : Name WithSource + Fragments : VendorStmtFragment<'t, 'e> array + ImaginaryStmts : Stmt<'t, 'e> array option + } + +and Stmt<'t, 'e> = + | AlterTableStmt of AlterTableStmt<'t, 'e> + | CreateIndexStmt of CreateIndexStmt<'t, 'e> + | CreateTableStmt of CreateTableStmt<'t, 'e> + | CreateViewStmt of CreateViewStmt<'t, 'e> + | DeleteStmt of DeleteStmt<'t, 'e> + | DropObjectStmt of DropObjectStmt<'t> + | InsertStmt of InsertStmt<'t, 'e> + | SelectStmt of SelectStmt<'t, 'e> + | UpdateStmt of UpdateStmt<'t, 'e> + | BeginStmt + | CommitStmt + | RollbackStmt + +type TotalStmt<'t, 'e> = + | CoreStmt of Stmt<'t, 'e> + | VendorStmt of VendorStmt<'t, 'e> + member this.CoreStmts() = + match this with + | CoreStmt stmt -> Seq.singleton stmt + | VendorStmt { ImaginaryStmts = None } -> Seq.empty + | VendorStmt { ImaginaryStmts = Some stmts } -> stmts :> _ seq + member this.SelectStmts() = + this.CoreStmts() + |> Seq.choose (function | SelectStmt s -> Some s | _ -> None) + +type ExprType = ExprType +type Expr = Expr +type InExpr = InExpr +type CollationExpr = CollationExpr +type BetweenExpr = BetweenExpr +type SimilarityExpr = SimilarityExpr +type BinaryExpr = BinaryExpr +type UnaryExpr = UnaryExpr +type ObjectName = ObjectName +type ColumnName = ColumnName +type InSet = InSet +type CaseExpr = CaseExpr +type CastExpr = CastExpr +type FunctionArguments = FunctionArguments +type FunctionInvocationExpr = FunctionInvocationExpr + +type WithClause = WithClause +type CommonTableExpression = CommonTableExpression +type CompoundExprCore = CompoundExprCore +type CompoundExpr = CompoundExpr +type CompoundTermCore = CompoundTermCore +type CompoundTerm = CompoundTerm +type CreateTableDefinition = CreateTableDefinition +type CreateTableStmt = CreateTableStmt +type SelectCore = SelectCore +type Join = Join +type JoinConstraint = JoinConstraint +type GroupBy = GroupBy +type Limit = Limit +type OrderingTerm = OrderingTerm +type ResultColumnCase = ResultColumnCase +type ResultColumn = ResultColumn +type ResultColumns = ResultColumns +type TableOrSubquery = TableOrSubquery +type TableExprCore = TableExprCore +type TableExpr = TableExpr +type TableInvocation = TableInvocation +type SelectStmt = SelectStmt +type ColumnConstraint = ColumnConstraint +type ColumnDef = ColumnDef +type AlterTableStmt = AlterTableStmt +type AlterTableAlteration = AlterTableAlteration +type CreateIndexStmt = CreateIndexStmt +type TableIndexConstraintClause = TableIndexConstraintClause +type TableConstraint = TableConstraint +type CreateViewStmt = CreateViewStmt +type QualifiedTableName = QualifiedTableName +type DeleteStmt = DeleteStmt +type DropObjectStmt = DropObjectStmt +type UpdateStmt = UpdateStmt +type InsertStmt = InsertStmt +type VendorStmt = VendorStmt +type Stmt = Stmt +type TotalStmt = TotalStmt +type TotalStmts = TotalStmt IReadOnlyList diff --git a/Rezoom.SQL.Compiler/ASTMapping.fs b/Rezoom.SQL.Compiler/ASTMapping.fs new file mode 100644 index 0000000..310de76 --- /dev/null +++ b/Rezoom.SQL.Compiler/ASTMapping.fs @@ -0,0 +1,341 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic + +type ASTMapping<'t1, 'e1, 't2, 'e2>(mapT : 't1 -> 't2, mapE : 'e1 -> 'e2) = + member this.Binary(binary : BinaryExpr<'t1, 'e1>) = + { Operator = binary.Operator + Left = this.Expr(binary.Left) + Right = this.Expr(binary.Right) + } + member this.Unary(unary : UnaryExpr<'t1, 'e1>) = + { Operator = unary.Operator + Operand = this.Expr(unary.Operand) + } + member this.ObjectName(objectName : ObjectName<'t1>) = + { SchemaName = objectName.SchemaName + ObjectName = objectName.ObjectName + Source = objectName.Source + Info = mapT objectName.Info + } + member this.ColumnName(columnName : ColumnName<'t1>) = + { Table = Option.map this.ObjectName columnName.Table + ColumnName = columnName.ColumnName + } + member this.Cast(cast : CastExpr<'t1, 'e1>) = + { Expression = this.Expr(cast.Expression) + AsType = cast.AsType + } + member this.Collation(collation : CollationExpr<'t1, 'e1>) = + { Input = this.Expr(collation.Input) + Collation = collation.Collation + } + member this.FunctionInvocation(func : FunctionInvocationExpr<'t1, 'e1>) = + { FunctionName = func.FunctionName + Arguments = + match func.Arguments with + | ArgumentWildcard -> ArgumentWildcard + | ArgumentList (distinct, exprs) -> + ArgumentList (distinct, exprs |> rmap this.Expr) + } + member this.Similarity(sim : SimilarityExpr<'t1, 'e1>) = + { Invert = sim.Invert + Operator = sim.Operator + Input = this.Expr(sim.Input) + Pattern = this.Expr(sim.Pattern) + Escape = Option.map this.Expr sim.Escape + } + member this.Between(between : BetweenExpr<'t1, 'e1>) = + { Invert = between.Invert + Input = this.Expr(between.Input) + Low = this.Expr(between.Low) + High = this.Expr(between.High) + } + member this.In(inex : InExpr<'t1, 'e1>) = + { Invert = inex.Invert + Input = this.Expr(inex.Input) + Set = + { Source = inex.Set.Source + Value = + match inex.Set.Value with + | InExpressions exprs -> exprs |> rmap this.Expr |> InExpressions + | InSelect select -> InSelect <| this.Select(select) + | InTable table -> InTable <| this.TableInvocation(table) + | InParameter par -> InParameter par + } + } + member this.Case(case : CaseExpr<'t1, 'e1>) = + { Input = Option.map this.Expr case.Input + Cases = + [| + for whenExpr, thenExpr in case.Cases -> + this.Expr(whenExpr), this.Expr(thenExpr) + |] + Else = + { Source = case.Else.Source + Value = Option.map this.Expr case.Else.Value + } + } + member this.ExprType(expr : ExprType<'t1, 'e1>) : ExprType<'t2, 'e2> = + match expr with + | LiteralExpr lit -> LiteralExpr lit + | BindParameterExpr par -> BindParameterExpr par + | ColumnNameExpr name -> ColumnNameExpr <| this.ColumnName(name) + | CastExpr cast -> CastExpr <| this.Cast(cast) + | CollateExpr collation -> CollateExpr <| this.Collation(collation) + | FunctionInvocationExpr func -> FunctionInvocationExpr <| this.FunctionInvocation(func) + | SimilarityExpr sim -> SimilarityExpr <| this.Similarity(sim) + | BinaryExpr bin -> BinaryExpr <| this.Binary(bin) + | UnaryExpr un -> UnaryExpr <| this.Unary(un) + | BetweenExpr between -> BetweenExpr <| this.Between(between) + | InExpr inex -> InExpr <| this.In(inex) + | ExistsExpr select -> ExistsExpr <| this.Select(select) + | CaseExpr case -> CaseExpr <| this.Case(case) + | ScalarSubqueryExpr select -> ScalarSubqueryExpr <| this.Select(select) + | RaiseExpr raise -> RaiseExpr raise + member this.Expr(expr : Expr<'t1, 'e1>) = + { Value = this.ExprType(expr.Value) + Source = expr.Source + Info = mapE expr.Info + } + member this.TableInvocation(table : TableInvocation<'t1, 'e1>) = + { Table = this.ObjectName(table.Table) + Arguments = table.Arguments |> Option.map (rmap this.Expr) + } + member this.CTE(cte : CommonTableExpression<'t1, 'e1>) = + { Name = cte.Name + ColumnNames = cte.ColumnNames + AsSelect = this.Select(cte.AsSelect) + Info = mapT cte.Info + } + member this.WithClause(withClause : WithClause<'t1, 'e1>) = + { Recursive = withClause.Recursive + Tables = rmap this.CTE withClause.Tables + } + member this.OrderingTerm(orderingTerm : OrderingTerm<'t1, 'e1>) = + { By = this.Expr(orderingTerm.By) + Direction = orderingTerm.Direction + } + member this.Limit(limit : Limit<'t1, 'e1>) = + { Limit = this.Expr(limit.Limit) + Offset = Option.map this.Expr limit.Offset + } + member this.ResultColumn(resultColumn : ResultColumn<'t1, 'e1>) = + let case = + match resultColumn.Case with + | ColumnsWildcard -> ColumnsWildcard + | TableColumnsWildcard tbl -> TableColumnsWildcard tbl + | Column (expr, alias) -> Column (this.Expr(expr), alias) + | ColumnNav nav -> + { Cardinality = nav.Cardinality + Name = nav.Name + Columns = nav.Columns |> Array.map this.ResultColumn + } |> ColumnNav + { Case = case; Source = resultColumn.Source } + member this.ResultColumns(resultColumns : ResultColumns<'t1, 'e1>) = + { Distinct = resultColumns.Distinct + Columns = resultColumns.Columns |> rmap this.ResultColumn + } + member this.TableOrSubquery(table : TableOrSubquery<'t1, 'e1>) = + let tbl = + match table.Table with + | Table (tinvoc, index) -> + Table (this.TableInvocation(tinvoc), index) + | Subquery select -> + Subquery (this.Select(select)) + { Table = tbl + Alias = table.Alias + Info = mapT table.Info + } + member this.JoinConstraint(constr : JoinConstraint<'t1, 'e1>) = + match constr with + | JoinOn expr -> JoinOn <| this.Expr(expr) + | JoinUnconstrained -> JoinUnconstrained + member this.Join(join : Join<'t1, 'e1>) = + { JoinType = join.JoinType + LeftTable = this.TableExpr(join.LeftTable) + RightTable = this.TableExpr(join.RightTable) + Constraint = this.JoinConstraint(join.Constraint) + } + member this.TableExpr(table : TableExpr<'t1, 'e1>) = + { Source = table.Source + Value = + match table.Value with + | TableOrSubquery sub -> TableOrSubquery <| this.TableOrSubquery(sub) + | Join join -> Join <| this.Join(join) + } + member this.GroupBy(groupBy : GroupBy<'t1, 'e1>) = + { By = groupBy.By |> rmap this.Expr + Having = groupBy.Having |> Option.map this.Expr + } + member this.SelectCore(select : SelectCore<'t1, 'e1>) = + { Columns = this.ResultColumns(select.Columns) + From = Option.map this.TableExpr select.From + Where = Option.map this.Expr select.Where + GroupBy = Option.map this.GroupBy select.GroupBy + Info = mapT select.Info + } + member this.CompoundTerm(term : CompoundTerm<'t1, 'e1>) : CompoundTerm<'t2, 'e2> = + { Source = term.Source + Value = + match term.Value with + | Values vals -> + Values (vals |> rmap (fun w -> { Value = rmap this.Expr w.Value; Source = w.Source })) + | Select select -> + Select <| this.SelectCore(select) + Info = mapT term.Info + } + member this.Compound(compound : CompoundExpr<'t1, 'e1>) = + { CompoundExpr.Source = compound.Source + Value = + match compound.Value with + | CompoundTerm term -> CompoundTerm <| this.CompoundTerm(term) + | Union (expr, term) -> Union (this.Compound(expr), this.CompoundTerm(term)) + | UnionAll (expr, term) -> UnionAll (this.Compound(expr), this.CompoundTerm(term)) + | Intersect (expr, term) -> Intersect (this.Compound(expr), this.CompoundTerm(term)) + | Except (expr, term) -> Except (this.Compound(expr), this.CompoundTerm(term)) + } + member this.Select(select : SelectStmt<'t1, 'e1>) : SelectStmt<'t2, 'e2> = + { Source = select.Source + Value = + let select = select.Value + { With = Option.map this.WithClause select.With + Compound = this.Compound(select.Compound) + OrderBy = Option.map (rmap this.OrderingTerm) select.OrderBy + Limit = Option.map this.Limit select.Limit + Info = mapT select.Info + } + } + member this.ForeignKey(foreignKey) = + { ReferencesTable = this.ObjectName(foreignKey.ReferencesTable) + ReferencesColumns = foreignKey.ReferencesColumns + Rules = foreignKey.Rules + Defer = foreignKey.Defer + } + member this.ColumnConstraint(constr : ColumnConstraint<'t1, 'e1>) = + { Name = constr.Name + ColumnConstraintType = + match constr.ColumnConstraintType with + | NullableConstraint -> NullableConstraint + | PrimaryKeyConstraint clause -> PrimaryKeyConstraint clause + | UniqueConstraint -> UniqueConstraint + | DefaultConstraint def -> DefaultConstraint <| this.Expr(def) + | CollateConstraint name -> CollateConstraint name + | ForeignKeyConstraint foreignKey -> ForeignKeyConstraint <| this.ForeignKey(foreignKey) + } + member this.ColumnDef(cdef : ColumnDef<'t1, 'e1>) = + { Name = cdef.Name + Type = cdef.Type + Constraints = rmap this.ColumnConstraint cdef.Constraints + } + member this.Alteration(alteration : AlterTableAlteration<'t1, 'e1>) = + match alteration with + | RenameTo name -> RenameTo name + | AddColumn cdef -> AddColumn <| this.ColumnDef(cdef) + member this.CreateIndex(createIndex : CreateIndexStmt<'t1, 'e1>) = + { Unique = createIndex.Unique + IndexName = this.ObjectName(createIndex.IndexName) + TableName = this.ObjectName(createIndex.TableName) + IndexedColumns = createIndex.IndexedColumns + Where = createIndex.Where |> Option.map this.Expr + } + member this.TableIndexConstraint(constr : TableIndexConstraintClause<'t1, 'e1>) = + { Type = constr.Type + IndexedColumns = constr.IndexedColumns + } + member this.TableConstraint(constr : TableConstraint<'t1, 'e1>) = + { Name = constr.Name + TableConstraintType = + match constr.TableConstraintType with + | TableIndexConstraint clause -> + TableIndexConstraint <| this.TableIndexConstraint(clause) + | TableForeignKeyConstraint (names, foreignKey) -> + TableForeignKeyConstraint (names, this.ForeignKey(foreignKey)) + | TableCheckConstraint expr -> TableCheckConstraint <| this.Expr(expr) + } + member this.CreateTableDefinition(createTable : CreateTableDefinition<'t1, 'e1>) = + { Columns = createTable.Columns |> rmap this.ColumnDef + Constraints = createTable.Constraints |> rmap this.TableConstraint + WithoutRowId = createTable.WithoutRowId + } + member this.CreateTable(createTable : CreateTableStmt<'t1, 'e1>) = + { Temporary = createTable.Temporary + Name = this.ObjectName(createTable.Name) + As = + match createTable.As with + | CreateAsSelect select -> CreateAsSelect <| this.Select(select) + | CreateAsDefinition def -> CreateAsDefinition <| this.CreateTableDefinition(def) + } + member this.CreateView(createView : CreateViewStmt<'t1, 'e1>) = + { Temporary = createView.Temporary + ViewName = this.ObjectName(createView.ViewName) + ColumnNames = createView.ColumnNames + AsSelect = this.Select(createView.AsSelect) + } + member this.QualifiedTableName(qualified : QualifiedTableName<'t1>) = + { TableName = this.ObjectName(qualified.TableName) + IndexHint = qualified.IndexHint + } + member this.Delete(delete : DeleteStmt<'t1, 'e1>) = + { With = Option.map this.WithClause delete.With + DeleteFrom = this.QualifiedTableName(delete.DeleteFrom) + Where = Option.map this.Expr delete.Where + OrderBy = Option.map (rmap this.OrderingTerm) delete.OrderBy + Limit = Option.map this.Limit delete.Limit + } + member this.DropObject(drop : DropObjectStmt<'t1>) = + { Drop = drop.Drop + ObjectName = this.ObjectName(drop.ObjectName) + } + member this.Insert(insert : InsertStmt<'t1, 'e1>) = + { With = Option.map this.WithClause insert.With + Or = insert.Or + InsertInto = this.ObjectName(insert.InsertInto) + Columns = insert.Columns + Data = Option.map this.Select insert.Data + } + member this.Update(update : UpdateStmt<'t1, 'e1>) = + { With = Option.map this.WithClause update.With + UpdateTable = this.QualifiedTableName(update.UpdateTable) + Or = update.Or + Set = update.Set |> rmap (fun (name, expr) -> name, this.Expr(expr)) + Where = Option.map this.Expr update.Where + OrderBy = Option.map (rmap this.OrderingTerm) update.OrderBy + Limit = Option.map this.Limit update.Limit + } + + member this.Stmt(stmt : Stmt<'t1, 'e1>) = + match stmt with + | AlterTableStmt alter -> + AlterTableStmt <| + { Table = this.ObjectName(alter.Table) + Alteration = this.Alteration(alter.Alteration) + } + | CreateIndexStmt index -> CreateIndexStmt <| this.CreateIndex(index) + | CreateTableStmt createTable -> CreateTableStmt <| this.CreateTable(createTable) + | CreateViewStmt createView -> CreateViewStmt <| this.CreateView(createView) + | DeleteStmt delete -> DeleteStmt <| this.Delete(delete) + | DropObjectStmt drop -> DropObjectStmt <| this.DropObject(drop) + | InsertStmt insert -> InsertStmt <| this.Insert(insert) + | SelectStmt select -> SelectStmt <| this.Select(select) + | UpdateStmt update -> UpdateStmt <| this.Update(update) + | BeginStmt -> BeginStmt + | CommitStmt -> CommitStmt + | RollbackStmt -> RollbackStmt + + member this.Vendor(vendor : VendorStmt<'t1, 'e1>) = + let frag = function + | VendorEmbeddedExpr e -> VendorEmbeddedExpr (this.Expr(e)) + | VendorRaw str -> VendorRaw str + { VendorName = vendor.VendorName + Fragments = vendor.Fragments |> rmap frag + ImaginaryStmts = vendor.ImaginaryStmts |> Option.map (rmap this.Stmt) + } + + member this.TotalStmt(stmt : TotalStmt<'t1, 'e1>) = + match stmt with + | CoreStmt core -> this.Stmt(core) |> CoreStmt + | VendorStmt vendor -> VendorStmt <| this.Vendor(vendor) + +type ASTMapping = + static member Stripper() = ASTMapping<_, _, unit, unit>((fun _ -> ()), fun _ -> ()) \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/AggregateChecker.fs b/Rezoom.SQL.Compiler/AggregateChecker.fs new file mode 100644 index 0000000..f9ba0b3 --- /dev/null +++ b/Rezoom.SQL.Compiler/AggregateChecker.fs @@ -0,0 +1,150 @@ +/// Checks that aggregate expressions are used correctly: that is, aggregates are not mixed with non-aggregate +/// expressions of columns unless grouping by those columns. +module private Rezoom.SQL.Compiler.AggregateChecker +open System +open System.Collections.Generic +open Rezoom.SQL.Compiler.InferredTypes + +type private AggReference = + | Aggregate of SourceInfo + | ColumnOutsideAggregate of InfExpr + +let private columnOutside = function + | Aggregate _ -> None + | ColumnOutsideAggregate expr -> Some expr + +let rec private aggReferencesSelectCore (select : InfSelectCore) = + seq { + for col in select.Columns.Columns do + match col.Case with + | Column (ex, _) -> yield! aggReferences ex + | ColumnsWildcard + | TableColumnsWildcard _ + | ColumnNav _ -> bug "Typechecker should've eliminated these column cases" + match select.Where with + | None -> () + | Some where -> + for ref in aggReferences where do + match ref with + | Aggregate source -> + failAt source Error.aggregateInWhereClause + | _ -> () + yield! aggReferences where + } + +and private aggReferencesCompoundTerm (term : InfCompoundTerm) = + match term.Value with + | Values vs -> + seq { + for row in vs do + for col in row.Value do + yield! aggReferences col + } + | Select sel -> + aggReferencesSelectCore sel + +and private aggReferencesCompound (compound : InfCompoundExpr) = + match compound.Value with + | CompoundTerm term -> aggReferencesCompoundTerm term + | Union (expr, term) + | UnionAll (expr, term) + | Intersect (expr, term) + | Except (expr, term) -> + Seq.append (aggReferencesCompound expr) (aggReferencesCompoundTerm term) + +and private aggReferencesSelect (select : InfSelectStmt) = + let select = select.Value + seq { + yield! aggReferencesCompound select.Compound + match select.OrderBy with + | None -> () + | Some orderBy -> + for term in orderBy do yield! aggReferences term.By + match select.Limit with + | None -> () + | Some limit -> + yield! aggReferences limit.Limit + match limit.Offset with + | None -> () + | Some off -> yield! aggReferences off + } + +and private aggReferences (expr : InfExpr) = + match expr.Value with + | ExistsExpr _ + | LiteralExpr _ + | BindParameterExpr _ + | ScalarSubqueryExpr _ // scalar subqueries have been internally checked by typechecker + | RaiseExpr _ -> Seq.empty + | ColumnNameExpr _ -> Seq.singleton (ColumnOutsideAggregate expr) + | InExpr inex -> + seq { + yield! aggReferences inex.Input + match inex.Set.Value with + | InExpressions exs -> yield! Seq.collect aggReferences exs + | InSelect sel -> yield! aggReferencesSelect sel + | InTable _ | InParameter _ -> () + } + | CastExpr cast -> aggReferences cast.Expression + | CollateExpr collate -> aggReferences collate.Input + | FunctionInvocationExpr f -> + let mapping = ASTMapping((fun _ -> ()), fun _ -> ()) + match expr.Info.Function with + | Some funcInfo when mapping.FunctionInvocation(f).Arguments |> funcInfo.Aggregate |> Option.isSome -> + Seq.singleton (Aggregate expr.Source) + | _ -> + match f.Arguments with + | ArgumentWildcard -> Seq.empty + | ArgumentList (_, exprs) -> Seq.collect aggReferences exprs + | SimilarityExpr sim -> + seq { + yield! aggReferences sim.Input + yield! aggReferences sim.Pattern + match sim.Escape with + | Some escape -> + yield! aggReferences escape + | None -> () + } + | BinaryExpr bin -> Seq.append (aggReferences bin.Left) (aggReferences bin.Right) + | UnaryExpr un -> aggReferences un.Operand + | BetweenExpr bet -> + [ aggReferences bet.Input + aggReferences bet.Low + aggReferences bet.High + ] |> Seq.concat + | CaseExpr case -> + seq { + match case.Input with + | Some inp -> yield! aggReferences inp + | None -> () + for whenExpr, thenExpr in case.Cases do + yield! aggReferences whenExpr + yield! aggReferences thenExpr + match case.Else.Value with + | Some els -> yield! aggReferences els + | None -> () + } + +let check (select : InfSelectCore) = + let references = aggReferencesSelectCore select + match select.GroupBy with + | None -> + if references |> Seq.exists (function | Aggregate _ -> true | _ -> false) then + // If we have aggregates, but we're not grouping by anything, we better + // not have columns referenced outside the aggregates. + match references |> Seq.tryPick columnOutside with + | None -> () + | Some { Source = src } -> + failAt src Error.columnNotAggregated + | Some group -> + let legal = group.By |> HashSet + let havingReferences = + match group.Having with + | None -> Seq.empty + | Some having -> aggReferences having + let outside = Seq.append references havingReferences |> Seq.choose columnOutside + for outsideExpr in outside do + if not <| legal.Contains(outsideExpr) then + failAt outsideExpr.Source Error.columnNotGroupedBy + select + diff --git a/Rezoom.ADO.Test/AssemblyInfo.fs b/Rezoom.SQL.Compiler/AssemblyInfo.fs similarity index 86% rename from Rezoom.ADO.Test/AssemblyInfo.fs rename to Rezoom.SQL.Compiler/AssemblyInfo.fs index 9e7f696..7e8bac8 100644 --- a/Rezoom.ADO.Test/AssemblyInfo.fs +++ b/Rezoom.SQL.Compiler/AssemblyInfo.fs @@ -1,4 +1,4 @@ -namespace Rezoom.ADO.Test.AssemblyInfo +namespace Rezoom.SQL.AssemblyInfo open System.Reflection open System.Runtime.CompilerServices @@ -7,11 +7,11 @@ open System.Runtime.InteropServices // General Information about an assembly is controlled through the following // set of attributes. Change these attribute values to modify the information // associated with an assembly. -[] +[] [] [] [] -[] +[] [] [] [] @@ -22,7 +22,7 @@ open System.Runtime.InteropServices [] // The following GUID is for the ID of the typelib if this project is exposed to COM -[] +[] // Version information for an assembly consists of the following four values: // diff --git a/Rezoom.SQL.Compiler/Backend.fs b/Rezoom.SQL.Compiler/Backend.fs new file mode 100644 index 0000000..fb297a5 --- /dev/null +++ b/Rezoom.SQL.Compiler/Backend.fs @@ -0,0 +1,37 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Data +open System.Data.Common +open System.Collections.Generic +open Rezoom.SQL.Mapping +open Rezoom.SQL.Compiler +open FSharp.Quotations + +type IParameterIndexer = + abstract member ParameterIndex : parameter : BindParameter -> int + +type ParameterTransform = + { ParameterType : DbType + ValueTransform : Quotations.Expr -> Quotations.Expr + } + static member Default(columnType : ColumnType) = + let transform (expr : Quotations.Expr) = + let ty = expr.Type + let asObj = Expr.Coerce(expr, typeof) + if ty.IsConstructedGenericType && ty.GetGenericTypeDefinition() = typedefof<_ option> then + let invokeValue = Expr.Coerce(Expr.PropertyGet(expr, ty.GetProperty("Value")), typeof) + <@@ if isNull %%asObj then box DBNull.Value else %%invokeValue @@> + else + <@@ if isNull %%asObj then box DBNull.Value else %%asObj @@> + let ty = columnType.DbType + { ParameterType = ty + ValueTransform = transform + } + +type IBackend = + abstract member InitialModel : Model + abstract member MigrationBackend : Quotations.Expr Migrations.IMigrationBackend> + abstract member ParameterTransform + : columnType : ColumnType -> ParameterTransform + abstract member ToCommandFragments + : indexer : IParameterIndexer * stmts : TTotalStmts -> CommandFragment IReadOnlyList diff --git a/Rezoom.SQL.Compiler/BackendUtilities.fs b/Rezoom.SQL.Compiler/BackendUtilities.fs new file mode 100644 index 0000000..c017911 --- /dev/null +++ b/Rezoom.SQL.Compiler/BackendUtilities.fs @@ -0,0 +1,122 @@ +module Rezoom.SQL.Compiler.BackendUtilities +open System +open System.Data +open System.Data.Common +open System.Text +open System.Globalization +open System.Collections.Generic +open Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.Migrations +open Rezoom.SQL.Compiler + +type Fragment = CommandFragment +type Fragments = Fragment seq + +let simplifyFragments (fragments : Fragments) = + seq { + let mutable hasWhitespace = false + let builder = StringBuilder() + for fragment in fragments do + match fragment with + | CommandText text -> + ignore <| builder.Append(text) + hasWhitespace <- text.EndsWith(" ") + | Whitespace -> + if not hasWhitespace then ignore <| builder.Append(' ') + hasWhitespace <- true + | Parameter _ + | LocalName _ -> + if builder.Length > 0 then + yield CommandText <| builder.ToString() + ignore <| builder.Clear() + yield fragment + hasWhitespace <- false + if builder.Length > 0 then + yield CommandText <| builder.ToString() + ignore <| builder.Clear() + } + +let ws = Whitespace +let text str = CommandText str + +let join separator (fragments : Fragments seq) = + seq { + let separator = CommandText separator + let mutable first = true + for element in fragments do + if not first then yield separator + else first <- false + yield! element + } + +let join1 separator sequence = join separator (sequence |> Seq.map Seq.singleton) + +type DbMigration(majorVersion : int, name : string) = + [] + member __.MajorVersion = majorVersion + [] + member __.Name = name + member __.ToTuple() = (majorVersion, name) + +type DefaultMigrationBackend(conn : DbConnection) = + abstract member Initialize : unit -> unit + abstract member GetMigrationsRun : unit -> (int * string) seq + abstract member RunMigration : string Migration -> unit + default __.Initialize() = + use cmd = conn.CreateCommand() + cmd.CommandText <- + """ + CREATE TABLE IF NOT EXISTS __RZSQL_MIGRATIONS + ( MajorVersion int + , Name varchar(256) + , UNIQUE (MajorVersion, Name) + ); + """ + ignore <| cmd.ExecuteNonQuery() + default __.GetMigrationsRun() = + use cmd = conn.CreateCommand() + cmd.CommandText <- + """ + SELECT MajorVersion, Name + FROM __RZSQL_MIGRATIONS + """ + use reader = cmd.ExecuteReader() + let entReader = CodeGeneration.ReaderTemplate.Template().CreateReader() + entReader.ProcessColumns(DataReader.columnMap(reader)) + let row = DataReader.DataReaderRow(reader) + while reader.Read() do + entReader.Read(row) + let migrationsRan = entReader.ToEntity() + migrationsRan + |> Seq.map (fun m -> m.ToTuple()) + default __.RunMigration(migration) = + use tx = conn.BeginTransaction() + do + use cmd = conn.CreateCommand() + cmd.CommandText <- migration.Source + ignore <| cmd.ExecuteNonQuery() + do + use cmd = conn.CreateCommand() + cmd.CommandText <- + """ + INSERT INTO __RZSQL_MIGRATIONS + VALUES (@major, @name) + """ + do + let major = cmd.CreateParameter() + major.DbType <- DbType.Int32 + major.ParameterName <- "@major" + major.Value <- box migration.MajorVersion + ignore <| cmd.Parameters.Add(major) + do + let name = cmd.CreateParameter() + name.DbType <- DbType.String + name.ParameterName <- "@name" + name.Value <- box migration.Name + ignore <| cmd.Parameters.Add(name) + ignore <| cmd.ExecuteNonQuery() + tx.Commit() + interface IMigrationBackend with + member this.Initialize() = this.Initialize() + member this.GetMigrationsRun() = this.GetMigrationsRun() + member this.RunMigration(migration) = this.RunMigration(migration) diff --git a/Rezoom.SQL.Compiler/CommandEffect.fs b/Rezoom.SQL.Compiler/CommandEffect.fs new file mode 100644 index 0000000..5f7ed0f --- /dev/null +++ b/Rezoom.SQL.Compiler/CommandEffect.fs @@ -0,0 +1,119 @@ +// A command is a series of SQL statements. +// This module analyzes the effects of commands, including the tables they update, the changes they make to the model, +// and the result sets they output. +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic +open Rezoom.SQL.Compiler.InferredTypes + +type CommandEffectCacheInfo = + { Idempotent : bool + // schema name * table name + WriteTables : (Name * Name) IReadOnlyList + ReadTables : (Name * Name) IReadOnlyList + } + +type CommandEffect = + { Statements : TTotalStmt IReadOnlyList + Parameters : (BindParameter * ColumnType) IReadOnlyList + ModelChange : Model option + DestructiveUpdates : bool Lazy + CacheInfo : CommandEffectCacheInfo option Lazy // if we have any vendor stmts this is unknown + } + member this.ResultSets() = + this.Statements + |> Seq.collect (fun s -> s.SelectStmts()) + |> Seq.map (fun s -> s.Value.Info.Table.Query) + static member ParseSQL(descr: string, sql : string) : TotalStmts = + Parser.parseStatements descr sql |> toReadOnlyList + static member OfSQL(model : Model, stmts : TotalStmts) = + let builder = CommandEffectBuilder(model) + for stmt in stmts do + builder.AddTotalStmt(stmt) + builder.CommandEffect() + static member OfSQL(model : Model, descr : string, sql : string) = + catchSource descr sql <| fun () -> + let stmts = CommandEffect.ParseSQL(descr, sql) + CommandEffect.OfSQL(model, stmts) + +and private CommandEffectBuilder(model : Model) = + // shared throughout the whole command, since parameters are too. + let inference = TypeInferenceContext() :> ITypeInferenceContext + let inferredStmts = ResizeArray() + let mutable newModel = None + member private this.AddStmt(stmt : Stmt) = + let model = newModel |? model + let checker = TypeChecker(inference, InferredSelectScope.Root(model)) + let inferredStmt = checker.Stmt(stmt) + newModel <- ModelChange(model, inference).Stmt(inferredStmt) + inferredStmt + member this.AddTotalStmt(stmt : TotalStmt) = + match stmt with + | CoreStmt stmt -> this.AddStmt(stmt) |> CoreStmt |> inferredStmts.Add + | VendorStmt vendor -> + let model = newModel |? model + let checker = TypeChecker(inference, InferredSelectScope.Root(model)) + let frag = function + | VendorEmbeddedExpr e -> VendorEmbeddedExpr (checker.Expr(e)) + | VendorRaw str -> VendorRaw str + let checkedFrags = vendor.Fragments |> rmap frag + let checkedImaginary = vendor.ImaginaryStmts |> Option.map (rmap this.AddStmt) + { VendorName = vendor.VendorName + Fragments = checkedFrags + ImaginaryStmts = checkedImaginary + } |> VendorStmt |> inferredStmts.Add + + static member PerformsDestructiveUpdate(stmt : TStmt) = + match stmt with + | AlterTableStmt { Alteration = AddColumn _ } + | CreateIndexStmt _ + | CreateTableStmt _ + | SelectStmt _ + | BeginStmt + | CommitStmt + | RollbackStmt + | CreateViewStmt _ -> false + | AlterTableStmt { Alteration = RenameTo _ } + | DeleteStmt _ + | DropObjectStmt _ + | InsertStmt _ + | UpdateStmt _ -> true + + static member PerformsDestructiveUpdate(stmt : TTotalStmt) = + match stmt with + | CoreStmt core -> CommandEffectBuilder.PerformsDestructiveUpdate(core) + | VendorStmt { ImaginaryStmts = Some stmts } -> + stmts |> Seq.exists CommandEffectBuilder.PerformsDestructiveUpdate + | VendorStmt { ImaginaryStmts = None } -> false + + member this.CommandEffect() = + let mapping = concreteMapping inference + let stmts = inferredStmts |> Seq.map mapping.TotalStmt |> toReadOnlyList + let pars = + inference.Parameters + |> Seq.map (fun p -> p, inference.Concrete(inference.Variable(p))) + |> toReadOnlyList + let cacheInfo = + lazy ( + let vendorStmts = stmts |> Seq.choose (function | VendorStmt v -> Some v | _ -> None) + if vendorStmts |> Seq.forall (fun v -> Option.isSome v.ImaginaryStmts) then + let references = ReadWriteReferences.references (stmts |> Seq.collect (fun s -> s.CoreStmts())) + let toTuple (ref : SchemaTable) = ref.SchemaName, ref.TableName + let inline selectsIdempotent() = + stmts + |> Seq.collect (fun s -> s.SelectStmts()) + |> Seq.forall (fun s -> s.Value.Info.Idempotent) + { WriteTables = references.TablesWritten |> Seq.map toTuple |> toReadOnlyList + ReadTables = references.TablesRead |> Seq.map toTuple |> toReadOnlyList + Idempotent = references.TablesWritten.Count <= 0 && selectsIdempotent() + } |> Some + else + None + ) + let destructive = lazy (stmts |> Seq.exists CommandEffectBuilder.PerformsDestructiveUpdate) + { Statements = stmts + ModelChange = newModel + Parameters = pars + DestructiveUpdates = destructive + CacheInfo = cacheInfo + } \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/Config.fs b/Rezoom.SQL.Compiler/Config.fs new file mode 100644 index 0000000..a555b5e --- /dev/null +++ b/Rezoom.SQL.Compiler/Config.fs @@ -0,0 +1,121 @@ +module Rezoom.SQL.Compiler.Config +open System +open FParsec + +type ConfigBackend = + | Identity // outputs Rezoom.SQL that can be parsed back + | SQLite + | TSQL + | PostgreSQL + | MySQL + member this.ToBackend() = + match this with + | Identity -> DefaultBackend() :> IBackend + | SQLite -> SQLite.SQLiteBackend() :> IBackend + | TSQL -> TSQL.TSQLBackend() :> IBackend + | _ -> failwithf "Unimplemented backend %A" this // TODO + +type ConfigOptionalStyle = + | CsStyle // optional value types get wrapped in Nullable, optional reference types untouched + | FsStyle // all optional types wrapped in FSharpOption + +type Config = + { /// Which backend to use. + Backend : ConfigBackend + /// Path to the migrations folder relative to the directory the config file resides in. + MigrationsPath : string + /// Connection string name to use at runtime. + ConnectionName : string + /// Type generation style for optionals. + Optionals : ConfigOptionalStyle + } + +let defaultConfig = + { Backend = Identity + MigrationsPath = "." + ConnectionName = "rzsql" + Optionals = FsStyle + } + +module private Parser = + open FParsec.Pipes + + let backend = + %% '"' + -- +.[ %% ci "SQLITE" -|> SQLite + %% [ ci "TSQL"; ci "MSSQL" ] -|> TSQL + %% ci "POSTGRES" -- zeroOrOne * ci "QL" -|> PostgreSQL + %% ci "MYSQL" -|> MySQL + %% ci "RZSQL" -|> Identity + ] + -- '"' + -|> id + + let optionals = + %% '"' + -- +.[ %% ci "C#" -|> CsStyle + %% ci "F#" -|> FsStyle + ] + -- '"' + -|> id + + let stringLiteral = + let escape = + anyOf "\"\\/bfnrt" + |>> function + | 'b' -> '\b' + | 'f' -> '\u000C' + | 'n' -> '\n' + | 'r' -> '\r' + | 't' -> '\t' + | c -> c + + let unicodeEscape = + %% 'u' + -- +.(qty.[4] * hex) + -|> fun hexes -> Int32.Parse(String(hexes)) |> char + + let escapedChar = %% '\\' -- +.[ escape; unicodeEscape ] -|> string + let normalChars = manySatisfy (function | '"' | '\\' -> false | _ -> true) + + %% '"' + -- +.stringsSepBy normalChars escapedChar + -- '"' + -|> id + + let prop (name : string) (parser : Parser<'a, 'u>) = + %% ci ("\"" + name + "\"") + -- spaces + -- ':' + -- spaces + -- +.parser + -- spaces + -|> id + + let property = + %[ + prop "BACKEND" (backend |>> fun backend config -> { config with Backend = backend }) + prop "MIGRATIONS" (stringLiteral |>> fun path config -> { config with MigrationsPath = path }) + prop "CONNECTIONNAME" (stringLiteral |>> fun conn config -> { config with ConnectionName = conn }) + prop "OPTIONALS" (optionals |>> fun opts config -> { config with Optionals = opts }) + ] + + let config : Parser = + let comma = %% ',' -- spaces -|> () + %% '{' + -- spaces + -- +.(qty.[0..] /. comma * property) + -- '}' + -- spaces + -- eof + -|> Seq.fold (|>) defaultConfig + +let parseConfig sourceDescription source = + match runParserOnString Parser.config () sourceDescription source with + | Success (statements, _, _) -> statements + | Failure (reason, err, _) -> + let sourceInfo = SourceInfo.OfPosition(translatePosition err.Position) + failAt sourceInfo reason + +let parseConfigFile path = + parseConfig path (IO.File.ReadAllText(path)) \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/CoreParser.fs b/Rezoom.SQL.Compiler/CoreParser.fs new file mode 100644 index 0000000..87e0d80 --- /dev/null +++ b/Rezoom.SQL.Compiler/CoreParser.fs @@ -0,0 +1,1251 @@ +// Parses our typechecked subset of the SQL language. + +module private Rezoom.SQL.Compiler.CoreParser +open System +open System.Collections.Generic +open System.Globalization +open FParsec +open FParsec.Pipes +open FParsec.Pipes.Precedence +open Rezoom.SQL.Compiler + +/// Get the source position the parser is currently at. +let private sourcePosition = + %% +.p + -%> translatePosition + +/// Wraps any parser with source information. +let withSource (parser : Parser<'a, unit>) = + %% +.sourcePosition + -- +.parser + -- +.sourcePosition + -%> fun startPos value endPos -> + { WithSource.Source = { StartPosition = startPos; EndPosition = endPos } + Value = value + } + +/// A line comment begins with -- and continues through the end of the line. +let private lineComment = + %% "--" -- restOfLine true -|> () + +/// A block comment begins with /* and continues until a trailing */ is found. +/// Nested block comments are not allowed, so additional /* tokens found +/// after the first are ignored. +let private blockComment = + %% "/*" -- skipCharsTillString "*/" true Int32.MaxValue -|> () + +/// Where whitespace is expected, it can be one of... +let private whitespaceUnit = + %[ lineComment // a line comment + blockComment // a block comment + spaces1 // one or more whitespace characters + ] + +/// Optional whitespace: 0 or more whitespace units +let ws = skipMany whitespaceUnit + +/// Add optional trailing whitespace to a parser. +let inline tws parser = %parser .>> ws + +/// Required whitespace: 1 or more whitespace units +let ws1 = skipMany1 whitespaceUnit + +/// A name wrapped in double quotes (standard SQL). +let private quotedName = + let escapedQuote = + %% "\"\"" -|> "\"" // A pair of double quotes escapes a double quote character + let regularChars = + many1Satisfy ((<>) '"') // Any run of non-quote characters is literal + %% '"' -- +.([regularChars; escapedQuote] * qty.[0..]) -- '"' + -|> (String.Concat >> Name) // Glue together the parts of the string + +/// A name wrapped in square brackets (T-SQL style). +let private bracketedName = + let escapedBracket = + %% "]]" -|> "]" // A pair of right brackets escapes a right bracket character + let regularChars = + many1Satisfy ((<>) ']') // Any run of non-bracket characters is literal + %% '[' -- +.([regularChars; escapedBracket] * qty.[0..]) -- ']' + -|> (String.Concat >> Name) + +/// A name wrapped in backticks (MySQL style) +let private backtickedName = + let escapedTick = + %% "``" -|> "`" // A pair of backticks escapes a backtick character + let regularChars = + many1Satisfy ((<>) '`') // Any run of non-backtick characters is literal + %% '`' -- +.([regularChars; escapedTick] * qty.[0..]) -- '`' + -|> (String.Concat >> Name) + +let private sqlKeywords = + [ "ADD"; "ALL"; "ALTER"; + "AND"; "AS"; + "BETWEEN"; "CASE"; "CHECK"; "COLLATE"; + "COMMIT"; "CONFLICT"; "CONSTRAINT"; "CREATE"; "CROSS"; + "DEFAULT"; "DEFERRABLE"; "DELETE"; + "DISTINCT"; "DROP"; "ELSE"; "ESCAPE"; "EXCEPT"; + "EXISTS"; "FOREIGN"; "FROM"; + "FULL"; "GLOB"; "GROUP"; "HAVING"; "IN"; + "INNER"; "INSERT"; + "INTERSECT"; "INTO"; "IS"; "ISNULL"; "JOIN"; "LEFT"; + "LIMIT"; "NATURAL"; "NOT"; "NOTNULL"; "NULL"; + "ON"; "OR"; "ORDER"; "OUTER"; "PRIMARY"; + "REFERENCES"; + "RIGHT"; + "SELECT"; "SET"; "TABLE"; "THEN"; + "TO"; "TRANSACTION"; "UNION"; "UNIQUE"; "UPDATE"; "USING"; + "VALUES"; "WHEN"; "WHERE"; + // Note: we don't include TEMP in this list because it is a schema name. + ] |> fun kws -> + HashSet(kws, StringComparer.OrdinalIgnoreCase) + // Since SQL is case-insensitive, be sure to ignore case + // in this hash set. + +let private isInitialIdentifierCharacter c = + c = '_' + || c >= 'a' && c <= 'z' + || c >= 'A' && c <= 'Z' + +let private isFollowingIdentifierCharacter c = + isInitialIdentifierCharacter c + || c >= '0' && c <= '9' + || c = '$' + +let private unquotedNameOrKeyword = + many1Satisfy2 isInitialIdentifierCharacter isFollowingIdentifierCharacter + |>> Name + +/// A plain, unquoted name. +let private unquotedName = + unquotedNameOrKeyword >>=? fun ident -> + if sqlKeywords.Contains(ident.ToString()) then + fail (Error.reservedKeywordAsName ident) + else + preturn ident + +let name = + %[ quotedName + bracketedName + backtickedName + unquotedName + ] + +let private stringLiteral = + (let escapedQuote = + %% "''" -|> "'" // A pair of single quotes escapes a single quote character + let regularChars = + many1Satisfy ((<>) '\'') // Any run of non-quote characters is literal + %% '\'' -- +.([regularChars; escapedQuote] * qty.[0..]) -- '\'' + -|> String.Concat) + "string-literal" + +let private nameOrKeyword = + %[ quotedName + bracketedName + backtickedName + unquotedNameOrKeyword + ] + +let private objectName = + (%% +.sourcePosition + -- +.nameOrKeyword + -- ws + -- +.(zeroOrOne * (%% '.' -- ws -? +.nameOrKeyword -- ws -|> id)) + -- +.sourcePosition + -|> fun pos1 name1 name2 pos2 -> + let pos = { StartPosition = pos1; EndPosition = pos2 } + match name2 with + | None -> + { Source = pos; SchemaName = None; ObjectName = name1; Info = () } + | Some name2 -> + { Source = pos; SchemaName = Some name1; ObjectName = name2; Info = () }) + "object-name" + +let private columnName = + (qty.[1..3] / tws '.' * tws name + |> withSource + |>> fun { Value = names; Source = src } -> + match names.Count with + | 1 -> { Table = None; ColumnName = names.[0] } + | 2 -> + { Table = Some { Source = src; SchemaName = None; ObjectName = names.[0]; Info = () } + ColumnName = names.[1] + } + | 3 -> + { Table = Some { Source = src; SchemaName = Some names.[0]; ObjectName = names.[1]; Info = () } + ColumnName = names.[2] + } + | _ -> failwith "Unreachable") + "column-name" + +let private namedBindParameter = + %% '@' + -- +.unquotedNameOrKeyword + -|> fun name -> NamedParameter name + +let private bindParameter = namedBindParameter "bind-parameter" + +let private kw str = + %% ci str + -? notFollowedByL (satisfy isFollowingIdentifierCharacter) str + -- ws + -|> () + +let private nullLiteral = + %% kw "NULL" -|> NullLiteral + +let private booleanLiteral = + %[ %% kw "TRUE" -|> BooleanLiteral true + %% kw "FALSE" -|> BooleanLiteral false + ] + +let private blobLiteral = + let octet = + %% +.(qty.[2] * hex) + -|> fun pair -> Byte.Parse(String(pair), NumberStyles.HexNumber) + (%% ['x';'X'] + -? '\'' + -- +.(octet * qty.[0..]) + -- '\'' + -|> (Seq.toArray >> BlobLiteral)) + "blob-literal" + +let private dateTimeishLiteral = + let digit = digit |>> fun c -> int c - int '0' + let digits n = + qty.[n] * digit |>> Array.fold (fun acc next -> acc * 10 + next) 0 + let date = %% +.digits 4 -- '-' -- +.digits 2 -- '-' -- +.digits 2 -%> auto + let time = %% ci 'T' -- +.digits 2 -- ':' -- +.digits 2 -- ':' -- +.digits 2 -%> auto + let ms = + %% '.' -- +.(qty.[1..3] * digit) + -|> fun ds -> + let n = Seq.fold (fun acc next -> acc * 10 + next) 0 ds + let delta = ds.Count - 3 + if delta > 0 then n / pown 10 delta + elif delta < 0 then n * pown 10 (-delta) + else n + let offsetPart = + %% +.[ %% '+' -|> 1; %% '-' -|> -1 ] + -- +.digits 2 + -- ':' + -- +.digits 2 + -%> auto + let timePart = + %% +.time + -- +.(zeroOrOne * ms) + -- +.(zeroOrOne * offsetPart) + -%> auto + %% +.date + ?- +.(zeroOrOne * timePart) + -|> fun (year, month, day) time -> + match time with + | None -> DateTime(year, month, day, 0, 0, 0, DateTimeKind.Utc) |> DateTimeLiteral + | Some ((hour, minute, second), ms, offset) -> + let ms = ms |? 0 + let dateTime = DateTime(year, month, day, hour, minute, second, ms) + match offset with + | None -> + DateTime.SpecifyKind(dateTime, DateTimeKind.Utc) + |> DateTimeLiteral + | Some (sign, offsetHour, offsetMinute) -> + DateTimeOffset(dateTime, TimeSpan(offsetHour * sign, offsetMinute * sign, 0)) + |> DateTimeOffsetLiteral + +let private numericLiteral = + let options = + NumberLiteralOptions.AllowHexadecimal + ||| NumberLiteralOptions.AllowFraction + ||| NumberLiteralOptions.AllowFractionWOIntegerPart + ||| NumberLiteralOptions.AllowExponent + numberLiteral options "numeric-literal" >>= fun lit -> + if lit.IsInteger then + lit.String |> uint64 |> IntegerLiteral |> preturn + else if lit.IsHexadecimal then + fail "hexadecimal floats are not permitted" + else + lit.String |> float |> FloatLiteral |> preturn + +let private signedNumericLiteral = + let sign = + %[ %% '+' -|> 1 + %% '-' -|> -1 + preturn 0 + ] + %% +.sign + -- ws + -- +.numericLiteral + -|> fun sign value -> { Sign = sign; Value = value } + +let private literal = + %[ booleanLiteral + nullLiteral + blobLiteral + %% +.stringLiteral -|> StringLiteral + dateTimeishLiteral + %% +.numericLiteral -|> NumericLiteral + ] + +let private typeName = + let maxBound = %% '(' -- ws -- +.p -- ws -- ')' -- ws -%> id + %[ %% kw "STRING" -- +.(zeroOrOne * maxBound) -%> StringTypeName + %% kw "BINARY" -- +.(zeroOrOne * maxBound) -%> StringTypeName + %% kw "INT8" -%> IntegerTypeName Integer8 + %% kw "INT16" -%> IntegerTypeName Integer16 + %% kw "INT32" -%> IntegerTypeName Integer32 + %% kw "INT64" -%> IntegerTypeName Integer64 + %% kw "INT" -%> IntegerTypeName Integer32 + %% kw "FLOAT32" -%> FloatTypeName Float32 + %% kw "FLOAT64" -%> FloatTypeName Float64 + %% kw "FLOAT" -%> FloatTypeName Float64 + %% kw "DECIMAL" -%> DecimalTypeName + %% kw "BOOL" -%> BooleanTypeName + %% kw "DATETIME" -%> DateTimeTypeName + %% kw "DATETIMEOFFSET" -%> DateTimeOffsetTypeName + ] + +let private cast expr = + %% kw "CAST" + -- '(' + -- ws + -- +.expr + -- kw "AS" + -- +. typeName + -- ws + -- ')' + -|> fun ex typeName -> { Expression = ex; AsType = typeName } + +let private functionArguments (expr : Parser, unit>) = + %[ %% '*' -- ws -|> ArgumentWildcard + %% +.((%% kw "DISTINCT" -- ws -|> Distinct) * zeroOrOne) + -- +.(qty.[0..] / tws ',' * expr) + -|> fun distinct args -> ArgumentList (distinct, args.ToArray()) + ] + +let private functionInvocation expr = + %% +.nameOrKeyword + -- ws + -? '(' + -- ws + -- +.functionArguments expr + -- ')' + -|> fun name args -> { FunctionName = name; Arguments = args } + +let private case expr = + let whenClause = + %% kw "WHEN" + -- +.expr + -- kw "THEN" + -- +.expr + -%> auto + let elseClause = + %% kw "ELSE" + -- +.expr + -|> id + let whenForm = + %% +.(whenClause * qty.[1..]) + -- +.withSource (elseClause * zeroOrOne) + -- kw "END" + -|> fun cases els -> { Input = None; Cases = cases.ToArray(); Else = els } + let ofForm = + %% +.expr + -- +.whenForm + -|> fun ofExpr case -> { case with Input = Some ofExpr } + %% kw "CASE" + -- +.[ whenForm; ofForm ] + -|> id + +let expr, private exprImpl = createParserForwardedToRef, unit>() +let private selectStmt, private selectStmtImpl = createParserForwardedToRef, unit>() + +let private binary op e1 e2 = + { Expr.Value = BinaryExpr { BinaryExpr.Operator = op; Left = e1; Right = e2 } + Source = SourceInfo.Between(e1.Source, e2.Source) + Info = () + } + +let private unary op e1 = + { Expr.Value = UnaryExpr { UnaryExpr.Operator = op; Operand = e1 } + Source = e1.Source + Info = () + } + +let private tableInvocation = + let args = + %% '(' -- ws -- +.(qty.[0..] / tws ',' * expr) -- ')' -|> id + %% +.objectName + -- ws + -- +.(args * zeroOrOne) + -|> fun name args -> { Table = name; Arguments = args |> Option.map (fun r -> r.ToArray()) } + +let private collateOperator = + %% kw "COLLATE" + -- +.withSource name + -|> fun collation expr -> + { Expr.Value = CollateExpr { Input = expr; Collation = collation.Value } + Source = collation.Source + Info = () + } + +let private isOperator = + %% kw "IS" + -- +.(zeroOrOne * kw "NOT") + -|> function + | Some () -> binary IsNot + | None -> binary Is + +let private inOperator = + %% +.(zeroOrOne * kw "NOT") + -? +.withSource (kw "IN") + -- +.withSource + %[ %% '(' + -- ws + -- + +.[ + %% +.selectStmt -|> InSelect + %% +.(qty.[0..] / tws ',' * expr) -|> (fun exs -> exs.ToArray() |> InExpressions) + ] + -- ')' + -|> id + %% +.bindParameter -|> InParameter + %% +.tableInvocation -|> InTable + ] + -|> fun invert op inSet left -> + { Expr.Source = op.Source + Value = InExpr { Invert = Option.isSome invert; Input = left; Set = inSet } + Info = () + } + +let private similarityOperator = + let similar invert (op : SimilarityOperator WithSource) left right escape = + { Expr.Source = op.Source + Value = + { Invert = Option.isSome invert + Operator = op.Value + Input = left + Pattern = right + Escape = escape + } |> SimilarityExpr + Info = () + } + let op = + %[ %% kw "LIKE" -|> Like + %% kw "GLOB" -|> Glob + %% kw "MATCH" -|> Match + %% kw "REGEXP" -|> Regexp + ] |> withSource + %% +.(zeroOrOne * kw "NOT") + -? +.op + -|> similar + +let private notNullOperator = + %[ + kw "NOTNULL" + %% kw "NOT" -? kw "NULL" -|> () + ] + |> withSource + |>> fun op left -> + { Expr.Source = op.Source + Value = UnaryExpr { Operator = NotNull; Operand = left } + Info = () + } + +let private betweenOperator = + let between invert input low high = + { Invert = Option.isSome invert + Input = input + Low = low + High = high + } + %% +.(zeroOrOne * kw "NOT") + -? +.withSource (kw "BETWEEN") + -|> fun invert op input low high -> + { Expr.Source = op.Source + Value = BetweenExpr (between invert input low high) + Info = () + } + +let private raiseTrigger = + %% kw "RAISE" + -- '(' + -- ws + -- +.[ %% kw "IGNORE" -|> RaiseIgnore + %% kw "ROLLBACK" -- ',' -- ws -- +.stringLiteral -- ws -|> RaiseRollback + %% kw "ABORT" -- ',' -- ws -- +.stringLiteral -- ws -|> RaiseAbort + %% kw "FAIL" -- ',' -- ws -- +.stringLiteral -- ws -|> RaiseFail + ] + -- ')' + -|> RaiseExpr + +let private term (expr : Parser, unit>) = + let parenthesized = + %[ + %% +.selectStmt -|> ScalarSubqueryExpr + %% +.expr -|> fun e -> e.Value + ] + %% +.sourcePosition + -- +.[ + %% '(' -- ws -- +.parenthesized -- ')' -|> id + %% kw "EXISTS" -- ws -- '(' -- ws -- +.selectStmt -- ')' -|> ExistsExpr + %% +.literal -|> LiteralExpr + %% +.bindParameter -|> BindParameterExpr + %% +.cast expr -|> CastExpr + %% +.case expr -|> CaseExpr + raiseTrigger + %% +.functionInvocation expr -|> FunctionInvocationExpr + %% +.columnName -|> ColumnNameExpr + ] + -- +.sourcePosition + -%> fun startPos value endPos -> + { Expr.Value = value + Source = { StartPosition = startPos; EndPosition = endPos } + Info = () + } + +let private operators = [ + [ + postfixc collateOperator + ] + [ + prefix (kw "NOT") <| unary Not + prefix '~' <| unary BitNot + prefix '-' <| unary Negative + prefix '+' id + ] + [ + infixl "||" <| binary Concatenate + ] + [ + infixl '*' <| binary Multiply + infixl '/' <| binary Divide + infixl '%' <| binary Modulo + ] + [ + infixl '+' <| binary Add + infixl '-' <| binary Subtract + ] + [ + infixl "<<" <| binary BitShiftLeft + infixl ">>" <| binary BitShiftRight + infixl '&' <| binary BitAnd + infixl '|' <| binary BitOr + ] + [ + infixl ">=" <| binary GreaterThanOrEqual + infixl "<=" <| binary LessThanOrEqual + infixl (%% '<' -? notFollowedBy (skipChar '>') -|> ()) <| binary LessThan + infixl '>' <| binary GreaterThan + ] + [ + infixl "==" <| binary Equal + infixl "=" <| binary Equal + infixl "!=" <| binary NotEqual + infixl "<>" <| binary NotEqual + infixlc isOperator + ternaryolc similarityOperator (kw "ESCAPE") + postfix (kw "ISNULL") <| unary IsNull + postfixc notNullOperator + postfixc inOperator + ternarylc betweenOperator (kw "AND") + ] + [ + infixl (kw "AND") <| binary And + ] + [ + infixl (kw "OR") <| binary Or + ] +] + +do + exprImpl := + { Whitespace = ws + Term = term + Operators = operators + } |> Precedence.expression + +let private parenthesizedColumnNames = + %% '(' + -- ws + -- +.(qty.[0..] / tws ',' * tws (withSource name)) + -- ')' + -- ws + -|> fun vs -> vs.ToArray() + +let private commonTableExpression = + %% +.nameOrKeyword + -- ws + -- +.(zeroOrOne * withSource parenthesizedColumnNames) + -- kw "AS" + -- '(' + -- ws + -- +.selectStmt + -- ')' + -- ws + -|> fun table cols asSelect -> + { Name = table + ColumnNames = cols + AsSelect = asSelect + Info = () + } + +let private withClause = + %% kw "WITH" + -- +.(zeroOrOne * kw "RECURSIVE") + -- +.(qty.[1..] / tws ',' * commonTableExpression) + -|> fun recurs ctes -> + { Recursive = Option.isSome recurs; Tables = ctes.ToArray() } + +let private asAlias = + %% (zeroOrOne * kw "AS") + -? +.name + -|> id + +let private resultColumnNavCardinality = + %[ + %% kw "MANY" -|> NavMany + %% kw "OPTIONAL" -|> NavOptional + %% kw "ONE" -|> NavOne + ] + +let private resultColumnCase (resultColumns : Parser<_, unit>) = + let nav = + %% +.resultColumnNavCardinality + -? +.nameOrKeyword + -- ws + -- '(' + -- ws + -- +.resultColumns + -- ')' + -- ws + -|> fun cardinality name cols -> + { Cardinality = cardinality + Name = name + Columns = cols + } |> ColumnNav + %% +.[ + %% '*' -|> ColumnsWildcard + nav + %% +.name -- '.' -? '*' -|> TableColumnsWildcard + %% +.expr -- +.(asAlias * zeroOrOne) -|> fun ex alias -> Column (ex, alias) + ] -- ws -|> id + +let private resultColumns = + precursive <| fun resultColumns -> + let column = + %% +.withSource (resultColumnCase resultColumns) + -|> fun case -> + { ResultColumn.Case = case.Value + Source = case.Source + } + %% +.(qty.[1..] /. tws ',' * column) + -|> Seq.toArray + +let private selectColumns = + %% kw "SELECT" + -- +.[ %% kw "DISTINCT" -|> Some DistinctColumns + %% kw "ALL" -|> Some AllColumns + preturn None + ] + -- +.resultColumns + -|> fun distinct cols -> { Distinct = distinct; Columns = cols } + +let private indexHint = + %[ + %% kw "INDEXED" -- kw "BY" -- +.nameOrKeyword -- ws -|> IndexedBy + %% kw "NOT" -- kw "INDEXED" -|> NotIndexed + ] + +let private tableOrSubquery (tableExpr : Parser, unit>) = + let subterm = + %% +.selectStmt + -|> fun select alias -> TableOrSubquery { Table = Subquery select; Alias = alias; Info = () } + let by = + %[ %% +.indexHint -|> fun indexed table -> + TableOrSubquery { Table = Table (table, Some indexed); Alias = None; Info = () } + %% +.(asAlias * zeroOrOne) -- +.(indexHint * zeroOrOne) + -|> fun alias indexed table -> + TableOrSubquery { Table = Table (table, indexed); Alias = alias; Info = () } + ] + + %[ %% +.tableInvocation -- +.by -|> fun table by -> by table + %% '(' -- ws -- +.subterm -- ')' -- ws -- +.(asAlias * zeroOrOne) -|> (<|) + ] + +let private joinType = + %[ + %% kw "LEFT" -- (tws (kw "OUTER") * zeroOrOne) -|> LeftOuter + %% kw "INNER" -|> Inner + %% kw "CROSS" -|> Cross + %% ws -|> Inner + ] + +let private joinConstraint = + %[ + %% kw "ON" -- +.expr -- ws -|> JoinOn + preturn JoinUnconstrained + ] + +let private tableExpr = // parses table expr (with left-associative joins) + precursive <| fun tableExpr -> + let term = tableOrSubquery tableExpr |> withSource + let natural = %% kw "NATURAL" -|> () + let join = + %% +.( + %[ + %% ',' + -|> fun left right constr -> + { JoinType = Inner + LeftTable = left + RightTable = right + Constraint = constr + } |> Join + %% +.(natural * zeroOrOne) -- +.joinType -- kw "JOIN" + -|> fun natural join left right constr -> + let joinType = if Option.isSome natural then Natural join else join + { JoinType = joinType + LeftTable = left + RightTable = right + Constraint = constr + } |> Join + ] |> withSource) + -- ws + -- +.term + -- ws + -- +.joinConstraint + -|> fun f joinTo joinOn left -> { TableExpr.Source = f.Source; Value = f.Value left joinTo joinOn } + %% +.term + -- ws + -- +.(join * qty.[0..]) + -|> Seq.fold (|>) + +let private valuesClause = + let valuesRow = + %% '(' + -- ws + -- +.(qty.[0..] / tws ',' * expr) + -- ')' + -- ws + -|> fun vs -> vs.ToArray() + + %% kw "VALUES" + -- ws + -- +.(qty.[1..] / tws ',' * withSource valuesRow) + -- ws + -|> fun vs -> vs.ToArray() + +let private fromClause = + %% kw "FROM" + -- +.tableExpr + -|> id + +let private whereClause = + %% kw "WHERE" + -- +.expr + -|> id + +let private havingClause = + %% kw "HAVING" + -- +.expr + -|> id + +let private groupByClause = + %% kw "GROUP" + -- kw "BY" + -- +.(qty.[1..] / tws ',' * expr) + -- +.(zeroOrOne * havingClause) + -|> fun by having -> { By = by.ToArray(); Having = having } + +let private selectCore = + %% +.selectColumns + -- +.(fromClause * zeroOrOne) + -- +.(whereClause * zeroOrOne) + -- +.(groupByClause * zeroOrOne) + -|> fun cols table where groupBy -> + { Columns = cols + From = table + Where = where + GroupBy = groupBy + Info = () + } + +let private compoundTerm = + %% +.sourcePosition + -- +.[ %% +.valuesClause -|> Values + %% +.selectCore -|> Select + ] + -- +.sourcePosition + -|> fun pos1 term pos2 -> + { CompoundTerm.Source = { StartPosition = pos1; EndPosition = pos2 } + Value = term + Info = () + } + +let private compoundExpr = + let compoundOperation = + %[ %% kw "UNION" -- +.(zeroOrOne * kw "ALL") -|> function + | Some () -> fun left right -> UnionAll (left, right) + | None -> fun left right -> Union (left, right) + %% kw "INTERSECT" -|> fun left right -> Intersect (left, right) + %% kw "EXCEPT" -|> fun left right -> Except (left, right) + ] |> withSource + let compoundNext = + %% +.compoundOperation + -- +.compoundTerm + -|> fun f right left -> { CompoundExpr.Source = f.Source; Value = f.Value left right } + %% +.(compoundTerm |>> fun t -> { CompoundExpr.Source = t.Source; Value = CompoundTerm t }) + -- +.(compoundNext * qty.[0..]) + -|> Seq.fold (|>) + +let private orderDirection = + %[ + %% kw "DESC" -|> Descending + %% kw "ASC" -|> Ascending + preturn Ascending + ] + +let private orderingTerm = + %% +.expr + -- +.orderDirection + -- ws + -|> fun expr dir -> { By = expr; Direction = dir } + +let private orderBy = + %% kw "ORDER" + -- kw "BY" + -- +.(qty.[1..] / tws ',' * orderingTerm) + -|> fun by -> by.ToArray() + +let private limit = + let offset = + %% [%% ',' -- ws -|> (); kw "OFFSET"] + -- +.expr + -|> id + %% kw "LIMIT" + -- +.expr + -- +.(zeroOrOne * offset) + -|> fun limit offset -> { Limit = limit; Offset = offset } + +let selectStmtWithoutCTE = + %% +.withSource compoundExpr + -- +.(zeroOrOne * orderBy) + -- +.(zeroOrOne * limit) + -|> fun comp orderBy limit cte -> + { WithSource.Source = comp.Source + Value = + { With = cte + Compound = comp.Value + OrderBy = orderBy + Limit = limit + Info = () + } + } + +do + selectStmtImpl := + %% +.(zeroOrOne * withClause) + -? +.selectStmtWithoutCTE + -|> (|>) + +let private foreignKeyRule = + let eventRule = + %% kw "ON" + -- +.[ + %% kw "DELETE" -|> OnDelete + %% kw "UPDATE" -|> OnUpdate + ] + -- +.[ + %% kw "SET" -- +.[ %% kw "NULL" -|> SetNull; %% kw "DEFAULT" -|> SetDefault ] -|> id + %% kw "CASCADE" -|> Cascade + %% kw "RESTRICT" -|> Restrict + %% kw "NO" -- kw "ACTION" -|> NoAction + ] + -|> fun evt handler -> EventRule (evt, handler) + let matchRule = + %% kw "MATCH" + -- +.name + -- ws + -|> MatchRule + %[ eventRule; matchRule ] + + +let private foreignKeyDeferClause = + let initially = + %% kw "INITIALLY" -- +.[ %% kw "DEFERRED" -|> true; %% kw "IMMEDIATE" -|> false ] -|> id + %% +.(zeroOrOne * kw "NOT") + -? kw "DEFERRABLE" + -- +.(zeroOrOne * initially) + -|> fun not init -> { Deferrable = Option.isNone not; InitiallyDeferred = init } + +let private foreignKeyClause = + %% kw "REFERENCES" + -- +.objectName + -- +.parenthesizedColumnNames + -- +.(qty.[0..] * foreignKeyRule) + -- +.(zeroOrOne * foreignKeyDeferClause) + -|> fun table cols rules defer -> + { + ReferencesTable = table + ReferencesColumns = cols + Rules = rules.ToArray() + Defer = defer + } + +let private constraintName = + %% kw "CONSTRAINT" + -- +.name + -- ws + -|> id + +let private primaryKeyClause = + %% kw "PRIMARY" + -- kw "KEY" + -- +.orderDirection + -- ws + -- +.(zeroOrOne * tws (kw "AUTOINCREMENT")) + -|> fun dir auto -> + { + Order = dir + AutoIncrement = Option.isSome auto + } + +let private constraintType = + let signedToExpr (signed : SignedNumericLiteral WithSource) = + let expr = signed.Value.Value |> NumericLiteral |> LiteralExpr + let expr = { Expr.Source = signed.Source; Value = expr; Info = () } + if signed.Value.Sign < 0 then + { Expr.Source = expr.Source; Value = UnaryExpr { Operator = Negative; Operand = expr }; Info = () } + else expr + let defaultValue = + %[ + %% +.withSource signedNumericLiteral -|> signedToExpr + %% +.withSource literal -|> fun lit -> { Source = lit.Source; Value = LiteralExpr lit.Value; Info = () } + %% '(' -- ws -- +.expr -- ')' -|> id + // docs don't mention this, but it works + %% +.withSource name + -|> fun name -> + { Source = name.Source; Value = name.Value.ToString() |> StringLiteral |> LiteralExpr; Info = () } + ] + %[ + %% +.primaryKeyClause -|> PrimaryKeyConstraint + %% kw "NULL" -|> NullableConstraint + %% kw "UNIQUE" -|> UniqueConstraint + %% kw "DEFAULT" -- +.defaultValue -|> DefaultConstraint + %% kw "COLLATE" -- +.name -|> CollateConstraint + %% +.foreignKeyClause -|> ForeignKeyConstraint + ] + +let private columnConstraint = + %% +.(zeroOrOne * constraintName) + -- +.constraintType + -- ws + -|> fun name cty columnName -> + { Name = cty.DefaultName(columnName) + ColumnConstraintType = cty + } + +let private columnDef = + %% +.nameOrKeyword + -- ws + -- +.typeName + -- +.(columnConstraint * qty.[0..]) + -|> fun name typeName constraints -> + { Name = name + Type = typeName + Constraints = constraints |> Seq.map ((|>) name) |> Seq.toArray + } + +let private alterTableStmt = + let renameTo = + %% kw "RENAME" + -- kw "TO" + -- +.name + -|> RenameTo + let addColumn = + %% kw "ADD" + -- zeroOrOne * kw "COLUMN" + -- +.columnDef + -|> AddColumn + %% kw "ALTER" + -- kw "TABLE" + -- +.objectName + -- +.[ renameTo; addColumn ] + -|> fun table alteration -> { Table = table; Alteration = alteration } + +let private tableIndexConstraintType = + %[ + %% kw "PRIMARY" -- kw "KEY" -|> PrimaryKey + %% kw "UNIQUE" -|> Unique + ] + +let private indexedColumns = + %% '(' + -- ws + -- +.(qty.[1..] / tws ',' * (%% +.nameOrKeyword -- +.orderDirection -%> auto)) + -- ')' + -- ws + -|> fun vs -> vs.ToArray() + +let private tableIndexConstraint = + %% +.tableIndexConstraintType + -- +.indexedColumns + -|> fun cty cols -> + { Type = cty; IndexedColumns = cols } + +let private tableConstraintType = + let foreignKey = + %% kw "FOREIGN" + -- kw "KEY" + -- +.parenthesizedColumnNames + -- +.foreignKeyClause + -|> fun columns fk -> TableForeignKeyConstraint (columns, fk) + %[ + %% kw "CHECK" -- '(' -- ws -- +.expr -- ')' -|> TableCheckConstraint + foreignKey + %% +.tableIndexConstraint -|> TableIndexConstraint + ] + +let private tableConstraint = + %% +.(zeroOrOne * constraintName) + -- +.tableConstraintType + -- ws + -|> fun name cty -> + { Name = match name with | Some name -> name | None -> Name(cty.DefaultName()) + TableConstraintType = cty + } + +let private createTableDefinition = + let part = + %[ + %% +.tableConstraint -|> Choice1Of2 + %% +.columnDef -|> Choice2Of2 + ] + %% '(' + -- ws + -- +.(qty.[0..] /. tws ',' * part) + -- ')' + -- ws + -- +.(zeroOrOne * (%% kw "WITHOUT" -- kw "ROWID" -- ws -|> ())) + -|> fun parts without -> + { + Columns = + parts |> Seq.choose (function | Choice2Of2 cdef -> Some cdef | Choice1Of2 _ -> None) |> Seq.toArray + Constraints = + parts |> Seq.choose (function | Choice1Of2 ct -> Some ct | Choice2Of2 _ -> None) |> Seq.toArray + WithoutRowId = Option.isSome without + } + +let private createTableAs = + %[ %% kw "AS" -- +.selectStmt -|> CreateAsSelect + %% +.createTableDefinition -|> CreateAsDefinition + ] + +let private temporary = %(zeroOrOne * [kw "TEMPORARY"; kw "TEMP"]) + +let private createTableStmt = + %% kw "CREATE" + -- +.temporary + -? kw "TABLE" + -- +.objectName + -- +.createTableAs + -|> fun temp name createAs -> + { Temporary = Option.isSome temp + Name = name + As = createAs + } + +let private analyzeStmt = + %% kw "ANALYZE" + -- +.(zeroOrOne * objectName) + -|> id + +let private attachStmt = + %% kw "ATTACH" + -- zeroOrOne * kw "DATABASE" + -- +.expr + -- kw "AS" + -- +.nameOrKeyword + -|> fun ex schemaName -> ex, schemaName + +let private beginStmt = + %% kw "BEGIN" + -- zeroOrOne * kw "TRANSACTION" + -|> BeginStmt + +let private commitStmt = + %% [ kw "COMMIT"; kw "END" ] + -- zeroOrOne * kw "TRANSACTION" + -|> CommitStmt + +let private rollbackStmt = + %% kw "ROLLBACK" + -- zeroOrOne * kw "TRANSACTION" + -|> RollbackStmt + +let private createIndexStmt = + %% kw "CREATE" + -- +.(zeroOrOne * kw "UNIQUE") + -? kw "INDEX" + -- +.objectName + -- kw "ON" + -- +.objectName + -- +.indexedColumns + -- +.(zeroOrOne * (%% kw "WHERE" -- +.expr -|> id)) + -|> fun unique indexName tableName cols whereExpr -> + { Unique = Option.isSome unique + IndexName = indexName + TableName = tableName + IndexedColumns = cols + Where = whereExpr + } + +let private qualifiedTableName = + %% +.objectName + -- +.(zeroOrOne * indexHint) + -|> fun tableName hint -> + { TableName = tableName + IndexHint = hint + } + +let private deleteStmt = + %% kw "DELETE" + -- kw "FROM" + -- +.qualifiedTableName + -- +.(zeroOrOne * whereClause) + -- +.(zeroOrOne * orderBy) + -- +.(zeroOrOne * limit) + -|> fun fromTable where orderBy limit withClause -> + { With = withClause + DeleteFrom = fromTable + Where = where + OrderBy = orderBy + Limit = limit + } |> DeleteStmt + +let private updateOr = + %% kw "OR" + -- +.[ + %% kw "ROLLBACK" -|> UpdateOrRollback + %% kw "ABORT" -|> UpdateOrAbort + %% kw "REPLACE" -|> UpdateOrReplace + %% kw "FAIL" -|> UpdateOrFail + %% kw "IGNORE" -|> UpdateOrIgnore + ] + -|> id + +let private updateStmt = + let setColumn = + %% +.withSource name + -- ws + -- '=' + -- ws + -- +.expr + -|> fun name expr -> name, expr + %% kw "UPDATE" + -- +.(zeroOrOne * updateOr) + -- +.qualifiedTableName + -- kw "SET" + -- +.(qty.[1..] / tws ',' * setColumn) + -- +.(zeroOrOne * whereClause) + -- +.(zeroOrOne * orderBy) + -- +.(zeroOrOne * limit) + -|> fun updateOr table sets where orderBy limit withClause -> + { With = withClause + UpdateTable = table + Or = updateOr + Set = sets.ToArray() + Where = where + OrderBy = orderBy + Limit = limit + } |> UpdateStmt + +let private insertOr = + let orPart = + %% kw "OR" + -- +.[ + %% kw "REPLACE" -|> InsertOrReplace + %% kw "ROLLBACK" -|> InsertOrRollback + %% kw "ABORT" -|> InsertOrAbort + %% kw "FAIL" -|> InsertOrFail + %% kw "IGNORE" -|> InsertOrIgnore + ] + -|> id + %[ %% kw "REPLACE" -|> Some InsertOrReplace + %% kw "INSERT" -- +.(zeroOrOne * orPart) -|> id + ] + +let private insertStmt = + %% +.insertOr + -- kw "INTO" + -- +.objectName + -- +.(zeroOrOne * parenthesizedColumnNames) + -- +.[ + %% kw "DEFAULT" -- kw "VALUES" -|> None + %% +.selectStmt -|> Some + ] + -|> fun insert table cols data withClause -> + { With = withClause + Or = insert + InsertInto = table + Columns = cols + Data = data + } |> InsertStmt + +let private createViewStmt = + %% kw "CREATE" + -- +.temporary + -? kw "VIEW" + -- +.objectName + -- +.(zeroOrOne * parenthesizedColumnNames) + -- kw "AS" + -- +.selectStmt + -|> fun temp viewName cols asSelect -> + { Temporary = Option.isSome temp + ViewName = viewName + ColumnNames = cols + AsSelect = asSelect + } + +let private ifExists = + %[ %% kw "IF" -- kw "EXISTS" -|> true + preturn false + ] + +let private dropObjectType = + %[ %% kw "INDEX" -|> DropIndex + %% kw "TABLE" -|> DropTable + %% kw "VIEW" -|> DropView + ] + +let private dropObjectStmt = + %% kw "DROP" + -? +.dropObjectType + -- +.objectName + -|> fun dropType name -> + { Drop = dropType; ObjectName = name } + +let private cteStmt = + %% +.(zeroOrOne * withClause) + -- +.[ + deleteStmt + insertStmt + updateStmt + %% +.selectStmtWithoutCTE -|> + fun select withClause -> select withClause |> SelectStmt + ] + -|> (|>) + +let coreStmt = + %[ %% +.alterTableStmt -|> AlterTableStmt + %% +.createIndexStmt -|> CreateIndexStmt + %% +.createTableStmt -|> CreateTableStmt + %% +.createViewStmt -|> CreateViewStmt + %% +.dropObjectStmt -|> DropObjectStmt + cteStmt + beginStmt + commitStmt + rollbackStmt + ] + +let coreStmts = + %% ws + -- +.(qty.[0..] /. tws ';' * tws coreStmt) + -|> fun s -> s.ToArray() \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/DefaultBackend.fs b/Rezoom.SQL.Compiler/DefaultBackend.fs new file mode 100644 index 0000000..b97cf44 --- /dev/null +++ b/Rezoom.SQL.Compiler/DefaultBackend.fs @@ -0,0 +1,35 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Data +open System.Collections.Generic +open System.Globalization +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.BackendUtilities +open Rezoom.SQL.Compiler.Translators +open Rezoom.SQL.Mapping + +type DefaultBackend() = + static let initialModel = + let main, temp = Name("main"), Name("temp") + { Schemas = + [ Schema.Empty(main) + Schema.Empty(temp) + ] |> List.map (fun s -> s.SchemaName, s) |> Map.ofList + DefaultSchema = main + TemporarySchema = temp + Builtin = + { Functions = DefaultFunctions.extendedBy [||] + } + } + + interface IBackend with + member this.MigrationBackend = <@ fun conn -> DefaultMigrationBackend(conn) :> Migrations.IMigrationBackend @> + member this.InitialModel = initialModel + member this.ParameterTransform(columnType) = ParameterTransform.Default(columnType) + member this.ToCommandFragments(indexer, stmts) = + let translator = DefaultStatementTranslator(Name("RZSQL"), indexer) + translator.TotalStatements(stmts) + |> BackendUtilities.simplifyFragments + |> ResizeArray + :> _ IReadOnlyList + \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/DefaultExprTranslator.fs b/Rezoom.SQL.Compiler/DefaultExprTranslator.fs new file mode 100644 index 0000000..5f93a7c --- /dev/null +++ b/Rezoom.SQL.Compiler/DefaultExprTranslator.fs @@ -0,0 +1,302 @@ +namespace Rezoom.SQL.Compiler.Translators +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.BackendUtilities +open Rezoom.SQL.Mapping + +type DefaultExprTranslator(statement : StatementTranslator, indexer : IParameterIndexer) = + inherit ExprTranslator(statement, indexer) + override __.Literal = upcast DefaultLiteralTranslator() + override __.Name(name) = + "\"" + name.Value.Replace("\"", "\"\"") + "\"" + |> text + override __.TypeName(name) = + (Seq.singleton << text) <| + match name with + | BooleanTypeName -> "BOOL" + | IntegerTypeName Integer8 -> "INT8" + | IntegerTypeName Integer16 -> "INT16" + | IntegerTypeName Integer32 -> "INT32" + | IntegerTypeName Integer64 -> "INT64" + | FloatTypeName Float32 -> "FLOAT32" + | FloatTypeName Float64 -> "FLOAT64" + | StringTypeName(Some size) -> "STRING(" + string size + ")" + | StringTypeName(None) -> "STRING" + | BinaryTypeName(Some size) -> "BINARY(" + string size + ")" + | BinaryTypeName(None) -> "BINARY" + | DecimalTypeName -> "DECIMAL" + | DateTimeTypeName -> "DATETIME" + | DateTimeOffsetTypeName -> "DATETIMEOFFSET" + override __.BinaryOperator op = + CommandText <| + match op with + | Concatenate -> "||" + | Multiply -> "*" + | Divide -> "/" + | Modulo -> "%" + | Add -> "+" + | Subtract -> "-" + | BitShiftLeft -> "<<" + | BitShiftRight -> ">>" + | BitAnd -> "&" + | BitOr -> "|" + | LessThan -> "<" + | LessThanOrEqual -> "<=" + | GreaterThan -> ">" + | GreaterThanOrEqual -> ">=" + | Equal -> "=" + | NotEqual -> "<>" + | Is -> "IS" + | IsNot -> "IS NOT" + | And -> "AND" + | Or -> "OR" + override __.UnaryOperator op = + CommandText <| + match op with + | Negative -> "-" + | Not -> "NOT" + | BitNot -> "~" + | NotNull -> "NOT NULL" + | IsNull -> "IS NULL" + override __.SimilarityOperator op = + CommandText <| + match op with + | Like -> "LIKE" + | Glob -> "GLOB" + | Match -> "MATCH" + | Regexp -> "REGEXP" + override __.BindParameter par = indexer.ParameterIndex(par) |> Parameter + override this.ObjectName name = + seq { + match name.SchemaName with + | Some schema -> + yield text (schema.Value + ".") + | None -> () + yield this.Name(name.ObjectName) + } + override this.ColumnName col = + seq { + match col.Table with + | Some tbl -> + yield! this.ObjectName(tbl) + yield text "." + | None -> () + yield this.Name(col.ColumnName) + } + override this.Cast(castExpr) = + seq { + yield text "CAST(" + yield! this.Expr(castExpr.Expression, FirstClassValue) + yield ws + yield text "AS" + yield ws + yield! this.TypeName(castExpr.AsType) + yield text ")" + } + override this.Collate(expr, collation) = + seq { + yield! this.Expr(expr) + yield ws + yield text "COLLATE" + yield ws + yield this.Name(collation) + } + override this.Invoke(func) = + seq { + yield text func.FunctionName.Value + yield text "(" + match func.Arguments with + | ArgumentWildcard -> yield text "*" + | ArgumentList (distinct, args) -> + match distinct with + | Some distinct -> + yield text "DISTINCT" + yield ws + | None -> () + yield! args |> Seq.map this.Expr |> join "," + yield text ")" + } + override this.Similarity(sim : TSimilarityExpr) = + seq { + yield! this.Expr(sim.Input) + yield ws + if sim.Invert then + yield text "NOT" + yield ws + yield this.SimilarityOperator(sim.Operator) + yield ws + yield! this.Expr(sim.Pattern) + match sim.Escape with + | None -> () + | Some escape -> + yield ws + yield text "ESCAPE" + yield ws + yield! this.Expr(escape) + } + override this.Binary(bin) = + let context = if bin.Operator.IsLogicalOperator then Predicate else FirstClassValue + seq { + yield! this.Expr(bin.Left, context) + yield ws + yield this.BinaryOperator(bin.Operator) + yield ws + yield! this.Expr(bin.Right, context) + } + override this.Unary(un) = + let context = if un.Operator.IsLogicalOperator then Predicate else FirstClassValue + match un.Operator with + | Negative + | Not + | BitNot -> + seq { + yield this.UnaryOperator(un.Operator) + yield ws + yield! this.Expr(un.Operand, context) + } + | NotNull + | IsNull -> + seq { + yield! this.Expr(un.Operand, context) + yield ws + yield this.UnaryOperator(un.Operator) + } + override this.Between(between) = + seq { + yield! this.Expr(between.Input) + yield ws + if between.Invert then + yield text "NOT" + yield ws + yield text "BETWEEN" + yield ws + yield! this.Expr(between.Low) + yield ws + yield text "AND" + yield ws + yield! this.Expr(between.High) + } + override this.Table(tbl) = + seq { + yield! this.ObjectName(tbl.Table) + match tbl.Arguments with + | None -> () + | Some args -> + yield text "(" + yield! args |> Seq.map this.Expr |> join "," + yield text ")" + } + override this.In(inex) = + seq { + yield! this.Expr(inex.Input, FirstClassValue) + yield ws + if inex.Invert then + yield text "NOT" + yield ws + yield text "IN" + yield ws + match inex.Set.Value with + | InExpressions exprs -> + yield text "(" + yield! exprs |> Seq.map this.Expr |> join "," + yield text ")" + | InSelect select -> + yield text "(" + yield! statement.Select(select) + yield text ")" + | InTable tbl -> + yield! this.Table(tbl) + | InParameter par -> + yield this.BindParameter(par) + } + override this.Case(case) = + seq { + yield text "CASE" + yield ws + let whenContext = + match case.Input with + | None -> Predicate + | Some _ -> FirstClassValue + match case.Input with + | None -> () + | Some input -> + yield! this.Expr(input, FirstClassValue) + yield ws + for input, output in case.Cases do + yield text "WHEN" + yield ws + yield! this.Expr(input, whenContext) + yield ws + yield text "THEN" + yield ws + yield! this.Expr(output, FirstClassValue) + match case.Else.Value with + | None -> () + | Some els -> + yield ws + yield text "ELSE" + yield ws + yield! this.Expr(els, FirstClassValue) + yield ws + yield text "END" + } + override this.Raise(raise) = + let raiseMsg ty msg = + seq { + yield text "RAISE(" + yield text ty + yield text "," + yield this.Literal.StringLiteral(msg) + yield text ")" + } + match raise with + | RaiseIgnore -> Seq.singleton (text "RAISE(IGNORE)") + | RaiseRollback msg -> raiseMsg "ROLLBACK" msg + | RaiseAbort msg -> raiseMsg "ABORT" msg + | RaiseFail msg -> raiseMsg "FAIL" msg + override this.Exists(subquery) = + seq { + yield text "EXISTS(" + yield! statement.Select(subquery) + yield text ")" + } + override this.ScalarSubquery(subquery) = + seq { + yield text "(" + yield! statement.Select(subquery) + yield text ")" + } + override __.NeedsParens(expr) = + match expr with + | LiteralExpr _ + | BindParameterExpr _ + | ColumnNameExpr _ + | CastExpr _ + | FunctionInvocationExpr _ + | ScalarSubqueryExpr _ -> false + | _ -> true + override this.Expr(expr, _) = + let needsParens = this.NeedsParens(expr.Value) + seq { + if needsParens then yield text "(" + yield! + match expr.Value with + | LiteralExpr lit -> this.Literal.Literal(lit) |> Seq.singleton + | BindParameterExpr bind -> this.BindParameter(bind) |> Seq.singleton + | ColumnNameExpr name -> this.ColumnName(name) + | CastExpr cast -> this.Cast(cast) + | CollateExpr { Input = expr; Collation = collation } -> this.Collate(expr, collation) + | FunctionInvocationExpr func -> this.Invoke(func) + | SimilarityExpr sim -> this.Similarity(sim) + | BinaryExpr bin -> this.Binary(bin) + | UnaryExpr un -> this.Unary(un) + | BetweenExpr between -> this.Between(between) + | InExpr inex -> this.In(inex) + | ExistsExpr select -> this.Exists(select) + | CaseExpr case -> this.Case(case) + | ScalarSubqueryExpr subquery -> this.ScalarSubquery(subquery) + | RaiseExpr raise -> this.Raise(raise) + if needsParens then yield text ")" + } + member this.Expr(expr) = this.Expr(expr, FirstClassValue) + + diff --git a/Rezoom.SQL.Compiler/DefaultFunctions.fs b/Rezoom.SQL.Compiler/DefaultFunctions.fs new file mode 100644 index 0000000..34620fc --- /dev/null +++ b/Rezoom.SQL.Compiler/DefaultFunctions.fs @@ -0,0 +1,37 @@ +module Rezoom.SQL.Compiler.DefaultFunctions +open System +open System.Data +open System.Collections.Generic +open System.Globalization +open Rezoom.SQL +open Rezoom.SQL.Mapping +open Rezoom.SQL.Compiler.BackendUtilities +open Rezoom.SQL.Compiler.Translators +open Rezoom.SQL.Compiler.FunctionDeclarations + +/// Functions that are supported by EVERY database. Surprisingly there aren't many of these. +let common = + [| func "abs" [ numeric (infect a') ] a' + func "coalesce" [ nullable a'; vararg (nullable a'); infect a' ] a' + |] + +/// Erased functions that can always be supported, even if the DB doesn't have ANY functions. +let builtins = + [| // Used to prevent queries from being assumed idempotent, even though they otherwise seem to be. + ErasedFunction(Name("impure"), infect a', a', idem = false) :> FunctionType + // Force its argument to be assumed nullable. This can be used to pick which variable is nullable + // in cases where we would otherwise make both nullable. + // For example `coalesce(@a + @b, 1)` could be written `coalesce(nullable(@a) + @b, 1)` so @b would + // not have to be assumed nullable. + erased "nullable" (nullable a') (nullable a') + // Ignore the inferred type (but not inferred nullability) of its argument. + // Lets you override the typechecker and treat values like whatever you feel they should be. + erased "unsafe_coerce" (infect scalar) scalar + |] + +let extendedBy backendFunctions = + Seq.concat + [| builtins + common + backendFunctions + |] |> mapBy (fun f -> f.FunctionName) \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/DefaultLiteralTranslator.fs b/Rezoom.SQL.Compiler/DefaultLiteralTranslator.fs new file mode 100644 index 0000000..dcfa672 --- /dev/null +++ b/Rezoom.SQL.Compiler/DefaultLiteralTranslator.fs @@ -0,0 +1,42 @@ +namespace Rezoom.SQL.Compiler.Translators +open System +open System.Globalization +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.BackendUtilities +open Rezoom.SQL.Mapping + +type DefaultLiteralTranslator() = + inherit LiteralTranslator() + override __.NullLiteral = CommandText "NULL" + override __.BooleanLiteral b = CommandText <| if b then "TRUE" else "FALSE" + override __.IntegerLiteral i = CommandText (i.ToString(CultureInfo.InvariantCulture)) + override __.FloatLiteral f = CommandText (f.ToString("0.0##############", CultureInfo.InvariantCulture)) + override __.BlobLiteral(bytes) = + let hexPairs = bytes |> Array.map (fun b -> b.ToString("X2", CultureInfo.InvariantCulture)) + "x'" + String.Concat(hexPairs) + "'" + |> text + override __.StringLiteral(str) = + "'" + str.Replace("'", "''") + "'" + |> text + override __.DateTimeLiteral(dt) = + CommandText <| dt.ToString("yyyy'-'MM'-'dd'T'HH':'mm':'ss'.'fff") + override __.DateTimeOffsetLiteral(dt) = + CommandText <| dt.ToString("yyyy'-'MM'-'dd'T'HH':'mm':'ss'.'fffzzz") + override this.Literal literal = + match literal with + | NullLiteral -> this.NullLiteral + | BooleanLiteral t -> this.BooleanLiteral(t) + | StringLiteral str -> this.StringLiteral(str) + | BlobLiteral blob -> this.BlobLiteral(blob) + | DateTimeLiteral dt -> this.DateTimeLiteral(dt) + | DateTimeOffsetLiteral dt -> this.DateTimeOffsetLiteral(dt) + | NumericLiteral (IntegerLiteral i) -> this.IntegerLiteral(i) + | NumericLiteral (FloatLiteral f) -> this.FloatLiteral(f) + override this.SignedLiteral literal = + let literalValue = literal.Value |> NumericLiteral |> this.Literal + if literal.Sign >= 0 then Seq.singleton literalValue else + seq { + yield text "-" + yield literalValue + } + diff --git a/Rezoom.SQL.Compiler/DefaultStatementTranslator.fs b/Rezoom.SQL.Compiler/DefaultStatementTranslator.fs new file mode 100644 index 0000000..8705159 --- /dev/null +++ b/Rezoom.SQL.Compiler/DefaultStatementTranslator.fs @@ -0,0 +1,615 @@ +namespace Rezoom.SQL.Compiler.Translators +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.BackendUtilities +open Rezoom.SQL.Mapping + +type DefaultStatementTranslator(expectedVendorName : Name, indexer : IParameterIndexer) = + inherit StatementTranslator() + override this.Expr = upcast DefaultExprTranslator(this, indexer) + member this.Predicate(x) = this.Expr.Expr(x, Predicate) + member this.FirstClassValue(x) = this.Expr.Expr(x, FirstClassValue) + override __.OrderDirection(dir) = + match dir with + | Ascending -> text "ASC" + | Descending -> text "DESC" + override this.IndexHint(indexHint) = + seq { + match indexHint with + | NotIndexed -> + yield text "NOT INDEXED" + | (IndexedBy name) -> + yield text "INDEXED BY" + yield ws + yield this.Expr.Name(name) + } + member this.QualifiedTableName(qualified : TQualifiedTableName) = + seq { + yield! this.Expr.ObjectName(qualified.TableName) + match qualified.IndexHint with + | None -> () + | Some indexHint -> + yield ws + yield! this.IndexHint(indexHint) + } + override this.CTE(cte) = + seq { + yield this.Expr.Name(cte.Name) + yield ws + match cte.ColumnNames with + | None -> () + | Some names -> + yield text "(" + yield! names.Value |> Seq.map (srcValue >> this.Expr.Name) |> join1 ", " + yield text ") " + yield text "AS (" + yield! this.Select(cte.AsSelect) + yield text ")" + } + override this.With(withClause) = + seq { + yield text "WITH" + yield ws + if withClause.Recursive then + yield text "RECURSIVE" + yield ws + yield! withClause.Tables |> Seq.map this.CTE |> join "," + } + override this.Values(vals) = + seq { + yield text "VALUES" + yield! + vals |> Seq.map (fun row -> + seq { + yield text "(" + yield! row.Value |> Seq.map this.FirstClassValue |> join "," + yield text ")" + }) |> join "," + } + + override this.ResultColumn(expr, alias) = + seq { + yield! this.FirstClassValue(expr) + match alias with + | None -> () + | Some alias -> + yield ws + yield text "AS" + yield ws + yield this.Expr.Name(alias) + } + override this.ResultColumns(cols) = + seq { + match cols.Distinct with + | None + | Some AllColumns -> () + | Some DistinctColumns -> yield text "DISTINCT"; yield ws + yield! + seq { + for col in cols.Columns do + match col.Case with + | Column(expr, alias) -> + yield this.ResultColumn(expr, alias) + | ColumnNav _ -> + bug "Bug in typechecker: nav columns should've been expanded" + | _ -> + bug "Bug in typechecker: wildcards should've been expanded" + } |> join "," + } + override this.TableOrSubquery(tbl) = + seq { + match tbl.Table with + | Table (table, indexHint) -> + yield! this.Expr.Table(table) + match tbl.Alias with + | None -> () + | Some alias -> + yield ws + yield text "AS" + yield ws + yield this.Expr.Name(alias) + match indexHint with + | None -> () + | Some indexHint -> + yield ws + yield! this.IndexHint(indexHint) + | Subquery select -> + yield text "(" + yield! this.Select(select) + yield text ")" + match tbl.Alias with + | None -> () + | Some alias -> + yield ws + yield text "AS" + yield ws + yield this.Expr.Name(alias) + } + override this.TableExpr(texpr) = + match texpr.Value with + | TableOrSubquery tbl -> this.TableOrSubquery(tbl) + | Join join -> this.Join(join) + override __.JoinType(join) = + let rec joinText join = + match join with + | Inner -> "INNER JOIN" + | LeftOuter -> "LEFT OUTER JOIN" + | Cross -> "CROSS JOIN" + | Natural ty -> "NATURAL " + joinText ty + joinText join |> text + override this.Join(join) = + seq { + yield! this.TableExpr(join.LeftTable) + yield ws + yield this.JoinType(join.JoinType) + yield ws + yield! this.TableExpr(join.RightTable) + match join.Constraint with + | JoinOn expr -> + yield ws + yield text "ON" + yield ws + yield! this.Predicate(expr) + | JoinUnconstrained -> () + } + override this.SelectCore(select) = + seq { + yield text "SELECT" + yield ws + yield! this.ResultColumns(select.Columns) + match select.From with + | None -> () + | Some from -> + yield ws + yield text "FROM" + yield ws + yield! this.TableExpr(from) + match select.Where with + | None -> () + | Some where -> + yield ws + yield text "WHERE" + yield ws + yield! this.Predicate(where) + match select.GroupBy with + | None -> () + | Some groupBy -> + yield ws + yield text "GROUP BY" + yield ws + yield! groupBy.By |> Seq.map this.FirstClassValue |> join "," + match groupBy.Having with + | None -> () + | Some having -> + yield ws + yield text "HAVING" + yield ws + yield! this.Predicate(having) + } + override this.CompoundTerm(compound) = + match compound with + | Values vals -> this.Values(vals) + | Select select -> this.SelectCore(select) + override this.Compound(compound) = + let op name (expr : TCompoundExpr) (term : TCompoundTerm) = + seq { + yield! this.Compound(expr.Value) + yield ws + yield text name + yield ws + yield! this.CompoundTerm(term.Value) + } + match compound with + | CompoundTerm term -> this.CompoundTerm(term.Value) + | Union (expr, term) -> op "UNION" expr term + | UnionAll (expr, term) -> op "UNION ALL" expr term + | Intersect (expr, term) -> op "INTERSECT" expr term + | Except (expr, term) -> op "EXCEPT" expr term + override this.Limit(limit) = + seq { + yield text "LIMIT" + yield ws + yield! this.FirstClassValue(limit.Limit) + match limit.Offset with + | None -> () + | Some offset -> + yield ws + yield text "OFFSET" + yield ws + yield! this.FirstClassValue(offset) + } + override this.OrderingTerm(term) = + seq { + yield! this.FirstClassValue(term.By) + yield ws + yield this.OrderDirection(term.Direction) + } + override this.Select(select) = + let select = select.Value + seq { + match select.With with + | None -> () + | Some withClause -> + yield! this.With(withClause) + yield ws + yield! this.Compound(select.Compound.Value) + match select.OrderBy with + | None -> () + | Some orderBy -> + yield ws + yield text "ORDER BY" + yield ws + yield! orderBy |> Seq.map this.OrderingTerm |> join "," + match select.Limit with + | None -> () + | Some limit -> + yield ws + yield! this.Limit(limit) + } + override this.ForeignKeyRule(rule) = + seq { + match rule with + | MatchRule name -> + yield text "MATCH" + yield ws + yield this.Expr.Name(name) + | EventRule (evt, handler) -> + yield text "ON" + yield ws + yield + match evt with + | OnDelete -> text "DELETE" + | OnUpdate -> text "UPDATE" + yield ws + yield + match handler with + | SetNull -> text "SET NULL" + | SetDefault -> text "SET DEFAULT" + | Cascade -> text "CASCADE" + | Restrict -> text "RESTRICT" + | NoAction -> text "NO ACTION" + } + override this.ForeignKeyClause(clause) = + seq { + yield text "REFERENCES" + yield ws + yield! this.Expr.ObjectName(clause.ReferencesTable) + yield ws + yield text "(" + yield! clause.ReferencesColumns |> Seq.map (srcValue >> this.Expr.Name) |> join1 "," + yield text ")" + for rule in clause.Rules do + yield ws + yield! this.ForeignKeyRule(rule) + match clause.Defer with + | None -> () + | Some defer -> + if not defer.Deferrable then + yield text "NOT" + yield ws + yield text "DEFERRABLE" + match defer.InitiallyDeferred with + | None -> () + | Some initially -> + yield ws + yield text "INITIALLY" + yield ws + yield text (if initially then "DEFERRED" else "IMMEDIATE") + } + abstract member AutoIncrement : string + default __.AutoIncrement = "AUTOINCREMENT" + override this.ColumnConstraint(constr) = + seq { + yield text "CONSTRAINT" + yield ws + yield this.Expr.Name(constr.Name) + yield ws + match constr.ColumnConstraintType with + | NullableConstraint -> + yield text "NULL" + | PrimaryKeyConstraint pk -> + yield text "PRIMARY KEY" + yield ws + yield this.OrderDirection(pk.Order) + if pk.AutoIncrement then + yield ws + yield text this.AutoIncrement + | UniqueConstraint -> + yield text "UNIQUE" + | DefaultConstraint expr -> + yield text "DEFAULT(" + yield! this.FirstClassValue(expr) + yield text ")" + | CollateConstraint name -> + yield text "COLLATE" + yield ws + yield this.Expr.Name(name) + | ForeignKeyConstraint fk -> + yield! this.ForeignKeyClause(fk) + } + abstract member ColumnsNullableByDefault : bool + default __.ColumnsNullableByDefault = false + override this.ColumnDefinition(col) = + seq { + yield this.Expr.Name(col.Name) + yield ws + yield! this.Expr.TypeName(col.Type) + if this.ColumnsNullableByDefault && not col.Nullable then + yield! [| ws; text "CONSTRAINT"; ws; this.Expr.Name(col.Name + "_NOTNULL"); ws; text "NOT NULL" |] + for constr in col.Constraints do + yield ws + yield! this.ColumnConstraint(constr) + } + override this.CreateTableDefinition(create) = + seq { + yield text "(" + yield! create.Columns |> Seq.map this.ColumnDefinition |> join "," + yield text ")" + } + override this.CreateTable(create) = + seq { + yield text "CREATE" + yield ws + if create.Temporary then + yield text "TEMP" + yield ws + yield text "TABLE" + yield ws + yield! this.Expr.ObjectName(create.Name) + yield ws + match create.As with + | CreateAsSelect select -> + yield text "AS" + yield ws + yield! this.Select(select) + | CreateAsDefinition def -> + yield! this.CreateTableDefinition(def) + } + override this.AlterTable(alter) = + seq { + yield text "ALTER TABLE" + yield ws + yield! this.Expr.ObjectName(alter.Table) + yield ws + match alter.Alteration with + | RenameTo newName -> + yield text "RENAME TO" + yield ws + yield this.Expr.Name(newName) + | AddColumn columnDef -> + yield text "ADD COLUMN" + yield ws + yield! this.ColumnDefinition(columnDef) + } + override this.CreateView(create) = + seq { + yield text "CREATE" + yield ws + if create.Temporary then + yield text "TEMP" + yield ws + yield text "VIEW" + yield ws + yield! this.Expr.ObjectName(create.ViewName) + yield ws + match create.ColumnNames with + | None -> () + | Some names -> + yield text "(" + yield! names |> Seq.map (srcValue >> this.Expr.Name) |> join1 "," + yield text ")" + yield ws + yield text "AS" + yield ws + yield! this.Select(create.AsSelect) + } + override this.CreateIndex(create) = + seq { + yield text "CREATE" + yield ws + if create.Unique then + yield text "UNIQUE" + yield ws + yield text "INDEX" + yield ws + yield! this.Expr.ObjectName(create.IndexName) + yield ws + yield text "ON" + yield ws + yield! this.Expr.ObjectName(create.TableName) + yield text "(" + yield! + seq { + for name, dir in create.IndexedColumns -> + seq { + yield this.Expr.Name(name) + yield ws + yield this.OrderDirection(dir) + } + } |> join "," + yield text ")" + match create.Where with + | None -> () + | Some where -> + yield ws + yield text "WHERE" + yield ws + yield! this.Predicate(where) + } + override this.DropObject(drop) = + seq { + yield text "DROP" + yield ws + yield + match drop.Drop with + | DropIndex -> text "INDEX" + | DropTable -> text "TABLE" + | DropView -> text "VIEW" + yield ws + yield! this.Expr.ObjectName(drop.ObjectName) + } + override this.Insert(insert) = + seq { + match insert.With with + | None -> () + | Some withClause -> + yield! this.With(withClause) + yield ws + yield text "INSERT" + match insert.Or with + | None -> () + | Some insertOr -> + yield ws + yield + match insertOr with + | InsertOrRollback -> text "OR ROLLBACK" + | InsertOrAbort -> text "OR ABORT" + | InsertOrReplace -> text "OR REPLACE" + | InsertOrFail -> text "OR FAIL" + | InsertOrIgnore -> text "OR IGNORE" + yield ws + yield text "INTO" + yield ws + yield! this.Expr.ObjectName(insert.InsertInto) + match insert.Columns with + | None -> () + | Some columns -> + yield text "(" + yield! columns |> Seq.map (srcValue >> this.Expr.Name) |> join1 "," + yield text ")" + yield ws + match insert.Data with + | None -> yield text "DEFAULT VALUES" + | Some data -> + yield! this.Select(data) + } + override this.Update(update) = + seq { + match update.With with + | None -> () + | Some withClause -> + yield! this.With(withClause) + yield ws + yield text "UPDATE" + match update.Or with + | None -> () + | Some updateOr -> + yield ws + yield + match updateOr with + | UpdateOrRollback -> text "OR ROLLBACK" + | UpdateOrAbort -> text "OR ABORT" + | UpdateOrReplace -> text "OR REPLACE" + | UpdateOrFail -> text "OR FAIL" + | UpdateOrIgnore -> text "OR IGNORE" + yield ws + yield! this.QualifiedTableName(update.UpdateTable) + yield ws + yield text "SET" + yield ws + yield! + seq { + for name, value in update.Set -> + seq { + yield this.Expr.Name(name.Value) + yield ws + yield text "=" + yield ws + yield! this.FirstClassValue(value) + } + } |> join "," + match update.Where with + | None -> () + | Some where -> + yield ws + yield text "WHERE" + yield ws + yield! this.Predicate(where) + match update.OrderBy with + | None -> () + | Some orderBy -> + yield ws + yield text "ORDER BY" + yield ws + yield! orderBy |> Seq.map this.OrderingTerm |> join "," + match update.Limit with + | None -> () + | Some limit -> + yield ws + yield! this.Limit(limit) + } + override this.Delete(delete) = + seq { + match delete.With with + | None -> () + | Some withClause -> + yield! this.With(withClause) + yield ws + yield text "DELETE FROM" + yield ws + yield! this.QualifiedTableName(delete.DeleteFrom) + match delete.Where with + | None -> () + | Some where -> + yield ws + yield text "WHERE" + yield ws + yield! this.Predicate(where) + match delete.OrderBy with + | None -> () + | Some orderBy -> + yield ws + yield text "ORDER BY" + yield ws + yield! orderBy |> Seq.map this.OrderingTerm |> join "," + match delete.Limit with + | None -> () + | Some limit -> + yield ws + yield! this.Limit(limit) + } + override this.Begin = Seq.singleton (text "BEGIN") + override this.Commit = Seq.singleton (text "COMMIT") + override this.Rollback = Seq.singleton (text "ROLLBACK") + override this.Statement(stmt) = + match stmt with + | AlterTableStmt alter -> this.AlterTable(alter) + | CreateTableStmt create -> this.CreateTable(create) + | CreateViewStmt create -> this.CreateView(create) + | CreateIndexStmt create -> this.CreateIndex(create) + | DropObjectStmt drop -> this.DropObject(drop) + + | SelectStmt select -> this.Select(select) + | InsertStmt insert -> this.Insert(insert) + | UpdateStmt update -> this.Update(update) + | DeleteStmt delete -> this.Delete(delete) + + | BeginStmt -> this.Begin + | CommitStmt -> this.Commit + | RollbackStmt -> this.Rollback + override this.Statements(stmts) = + seq { + for stmt in stmts do + yield! this.Statement(stmt) + yield text ";" + } + override this.Vendor(vendor) = + if expectedVendorName <> vendor.VendorName.Value then + failAt vendor.VendorName.Source <| + Error.vendorMismatch vendor.VendorName.Value expectedVendorName + seq { + for fragment in vendor.Fragments do + match fragment with + | VendorRaw raw -> yield text raw + | VendorEmbeddedExpr expr -> yield! this.FirstClassValue(expr) + } + override this.TotalStatement(stmt) = + match stmt with + | CoreStmt stmt -> this.Statement(stmt) + | VendorStmt vendor -> this.Vendor(vendor) + override this.TotalStatements(stmts) = + seq { + for stmt in stmts do + yield! this.TotalStatement(stmt) + yield text ";" + } + diff --git a/Rezoom.SQL.Compiler/Error.fs b/Rezoom.SQL.Compiler/Error.fs new file mode 100644 index 0000000..2fa308c --- /dev/null +++ b/Rezoom.SQL.Compiler/Error.fs @@ -0,0 +1,90 @@ +module Rezoom.SQL.Compiler.Error + +let parseError msg = + sprintf "SQ000: %O" msg +let cannotUnify left right = + sprintf "SQ001: The types %O and %O cannot be unified" left right +let reservedKeywordAsName keyword = + sprintf "SQ002: Reserved keyword ``%O`` used as name" keyword +let noSuchFunction func = + sprintf "SQ003: No such function: ``%O``" func +let insufficientArguments func got expected = + sprintf "SQ004: Insufficient arguments to function ``%O`` (found %d, expected at least %d)" + func got expected +let excessiveArguments func got expected = + sprintf "SQ005: Too many arguments to function ``%O`` (found %d, expected at most %d)" + func got expected +let functionDoesNotPermitWildcard func = + sprintf "SQ006: Function ``%O`` cannot take a wildcard (*) argument" func +let jamesBond, jamesBondEasterEgg = + "SQ007: Expected martini shaken (found ``stirred``)", "CREATE VIEW TO A KILL" +let functionDoesNotPermitDistinct func = + sprintf "SQ008: Function ``%O`` cannot take a DISTINCT argument" func +let mismatchedColumnNameCount names cols = + sprintf "SQ009: %d columns named for a query for %d columns" names cols +let schemaNameInColumnReference name = + sprintf "SQ010: Unsupported schema name in column reference: ``%O``" name +let noSuchObject ty name = + sprintf "SQ011: No such %s: ``%O``" ty name +let noSuchTable name = noSuchObject "table" name +let objectNotATable name = + sprintf "SQ012: Object ``%O`` is not a table" name +let objectAlreadyExists name = + sprintf "SQ013: Object ``%O`` already exists" name +let objectIsNotA ty name = + sprintf "SQ014: Object ``%O`` is not a %s" name ty +let noSuchTableInFrom name = + sprintf "SQ015: No such table in FROM clause: ``%O``" name +let noSuchColumn name = + sprintf "SQ016: No such column: ``%O``" name +let noSuchColumnInFrom name = + sprintf "SQ017: No such column in FROM clause: ``%O``" name +let columnAlreadyExists name = + sprintf "SQ018: Column ``%O`` already exists" name +let noSuchColumnToSet tbl col = + sprintf "SQ019: No such column in table ``%O`` to set: ``%O``" tbl col +let noSuchSchema schema = + sprintf "SQ020: No such schema: ``%O``" schema +let ambiguousColumn name = + sprintf "SQ021: Ambiguous column: ``%O``" name +let ambiguousColumnBetween name tbl1 tbl2 = + sprintf "SQ022: Ambiguous column: ``%O`` (may refer to %O.%O or %O.%O)" + name tbl1 name tbl2 name +let tableNameAlreadyInScope name = + sprintf "SQ023: Table name already in scope: ``%O``" name +let columnReferenceWithoutFrom name = + sprintf "SQ024: Cannot reference column name ``%O`` in query without a FROM clause" name +let multipleColumnsForInSelect count = + sprintf "SQ025: Expected 1 column for IN(SELECT ...), but found %d" count +let multipleColumnsForScalarSubquery count = + sprintf "SQ026: Expected 1 column for scalar subquery, but found %d" count +let subqueryRequiresAnAlias = + sprintf "SQ027: This subquery must be given an alias" +let expressionRequiresAlias = + sprintf "SQ028: Expression-valued column requires an alias (what should the column name be?)" +let tableWildcardWithoutFromClause name = + sprintf "SQ029: SELECT statement must have a FROM clause to use ``%O.*``" name +let wildcardWithoutFromClause = + sprintf "SQ030: SELECT statement must have a FROM clause to use ``*``" +let navPropertyMissingKeys name = + sprintf "SQ031: The navigation property clause ``%O`` must contain at least one key column" name +let expectedKnownColumnCount got expected = + sprintf "SQ032: Expected %d columns in table, but found %d" expected got +let valuesRequiresKnownShape = + sprintf "SQ033: A VALUES() clause can only be used when column names are implied by the surrounding context" +let columnNotAggregated = + "SQ034: Can't reference column outside of an aggregate function" + + " because this query uses aggregate functions without a GROUP BY clause" +let columnNotGroupedBy = + "SQ035: Can't reference column outside of an aggregate function" + + " because the GROUP BY clause does not include this column" +let indexSchemasMismatch indexName tableName = + sprintf "SQ036: Can't create index ``%O`` in a different schema from its table ``%O``" indexName tableName +let vendorMismatch got expected = + sprintf "Vendor-specific code for ``%O`` cannot be compiled for backend ``%O``" got expected +let sameVendorDelimiters delim = + sprintf "SQ037: Opening and closing delimiters for vendor statement are identical ``%s``" delim +(* let exprMustBeNullable = + sprintf "SQ038: Expression is not nullable; but is required to be in this context" *) +let aggregateInWhereClause = + sprintf "SQ039: A WHERE clause cannot contain aggregates -- consider using a HAVING clause" \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/ExprTypeChecker.fs b/Rezoom.SQL.Compiler/ExprTypeChecker.fs new file mode 100644 index 0000000..5b70d01 --- /dev/null +++ b/Rezoom.SQL.Compiler/ExprTypeChecker.fs @@ -0,0 +1,331 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic +open Rezoom.SQL.Compiler.InferredTypes + +type IQueryTypeChecker = + abstract member Select : SelectStmt -> InfSelectStmt + abstract member CreateView : CreateViewStmt -> InfCreateViewStmt + +type private ExprTypeChecker(cxt : ITypeInferenceContext, scope : InferredSelectScope, queryChecker : IQueryTypeChecker) = + member this.Scope = scope + member this.ObjectName(objectName : ObjectName) = this.ObjectName(objectName, false) + member this.ObjectName(objectName : ObjectName, allowNotFound) : InfObjectName = + { SchemaName = objectName.SchemaName + ObjectName = objectName.ObjectName + Source = objectName.Source + Info = + if allowNotFound then Missing else + let inferView view = (concreteMapping cxt).CreateView(queryChecker.CreateView(view)) + match scope.ResolveObjectReference(objectName, inferView) with + | Found f -> f + | Ambiguous r + | NotFound r -> failAt objectName.Source r + } + + member this.ColumnName(source : SourceInfo, columnName : ColumnName) = + let tblAlias, tblInfo, name = scope.ResolveColumnReference(columnName) |> foundAt source + { Expr.Source = source + Value = + { Table = + match tblAlias with + | None -> None + | Some tblAlias -> + { Source = source + SchemaName = None + ObjectName = tblAlias + Info = TableLike tblInfo + } |> Some + ColumnName = columnName.ColumnName + } |> ColumnNameExpr + Info = name.Expr.Info + } + + member this.Literal(source : SourceInfo, literal : Literal) = + { Expr.Source = source + Value = LiteralExpr literal + Info = ExprInfo<_>.OfType(InferredType.OfLiteral(literal)) + } + + member this.BindParameter(source : SourceInfo, par : BindParameter) = + { Expr.Source = source + Value = BindParameterExpr par + Info = ExprInfo<_>.OfType(cxt.Variable(par)) + } + + member this.Binary(source : SourceInfo, binary : BinaryExpr) = + let left = this.Expr(binary.Left) + let right = this.Expr(binary.Right) + { Expr.Source = source + Value = + { Operator = binary.Operator + Left = left + Right = right + } |> BinaryExpr + Info = + { Type = cxt.Binary(source, binary.Operator, left.Info.Type, right.Info.Type) + Idempotent = left.Info.Idempotent && right.Info.Idempotent + Function = None + Column = None + } + } + + member this.Unary(source : SourceInfo, unary : UnaryExpr) = + let operand = this.Expr(unary.Operand) + { Expr.Source = source + Value = + { Operator = unary.Operator + Operand = operand + } |> UnaryExpr + Info = + { Type = cxt.Unary(source, unary.Operator, operand.Info.Type) + Idempotent = operand.Info.Idempotent + Function = None + Column = None + } + } + + member this.Cast(source : SourceInfo, cast : CastExpr) = + let input = this.Expr(cast.Expression) + let ty = InferredType.OfTypeName(cast.AsType, input.Info.Type) + { Expr.Source = source + Value = + { Expression = input + AsType = cast.AsType + } |> CastExpr + Info = + { Type = ty + Idempotent = input.Info.Idempotent + Function = None + Column = None + } + } + + member this.Collation(source : SourceInfo, collation : CollationExpr) = + let input = this.Expr(collation.Input) + ignore <| cxt.Unify(source, input.Info.Type, InferredType.String) + { Expr.Source = source + Value = + { Input = this.Expr(collation.Input) + Collation = collation.Collation + } |> CollateExpr + Info = + { Type = input.Info.Type + Idempotent = input.Info.Idempotent + Function = None + Column = None + } + } + + member this.FunctionArguments(args : FunctionArguments) = + match args with + | ArgumentWildcard -> ArgumentWildcard + | ArgumentList (distinct, args) -> + ArgumentList (distinct, args |> Array.map this.Expr) + + member this.FunctionInvocation(source : SourceInfo, func : FunctionInvocationExpr) = + match scope.Model.Builtin.Functions.TryFind(func.FunctionName) with + | None -> failAt source <| Error.noSuchFunction func.FunctionName + | Some funcType -> + let args, output = cxt.Function(source, funcType, this.FunctionArguments(func.Arguments)) + { Expr.Source = source + Value = + if funcType.Erased then + match args with + | ArgumentList (None, [| arg |]) -> arg.Value + | _ -> + bug <| sprintf "Bug in backend: erased function ``%O`` must take a single argument" + func.FunctionName + else + { FunctionName = func.FunctionName; Arguments = args } |> FunctionInvocationExpr + Info = + { Type = output + Idempotent = + funcType.Idempotent && + match args with + | ArgumentWildcard -> true + | ArgumentList (_, args) -> args |> Seq.forall (fun a -> a.Info.Idempotent) + Function = Some funcType + Column = None + } + } + + member this.Similarity(source : SourceInfo, sim : SimilarityExpr) = + let input = this.Expr(sim.Input) + let pattern = this.Expr(sim.Pattern) + let escape = Option.map this.Expr sim.Escape + let output = + let inputType = cxt.Unify(source, input.Info.Type, StringType) + let patternType = cxt.Unify(source, pattern.Info.Type, StringType) + match escape with + | None -> () + | Some escape -> ignore <| cxt.Unify(source, escape.Info.Type, StringType) + let unified = cxt.Unify(source, inputType, patternType) + InferredType.Dependent(unified, BooleanType) + { Expr.Source = source + Value = + { Invert = sim.Invert + Operator = sim.Operator + Input = input + Pattern = pattern + Escape = escape + } |> SimilarityExpr + Info = + { Type = output + Idempotent = input.Info.Idempotent && pattern.Info.Idempotent + Function = None + Column = None + } + } + + member this.Between(source : SourceInfo, between : BetweenExpr) = + let input = this.Expr(between.Input) + let low = this.Expr(between.Low) + let high = this.Expr(between.High) + { Expr.Source = source + Value = { Invert = between.Invert; Input = input; Low = low; High = high } |> BetweenExpr + Info = + { Type = cxt.Unify(source, [ input.Info.Type; low.Info.Type; high.Info.Type ]) + Idempotent = input.Info.Idempotent && low.Info.Idempotent && high.Info.Idempotent + Function = None + Column = None + } + } + + member this.TableInvocation(table : TableInvocation) = + { Table = this.ObjectName(table.Table) + Arguments = table.Arguments |> Option.map (rmap this.Expr) + } + + member this.In(source : SourceInfo, inex : InExpr) = + let input = this.Expr(inex.Input) + let set, idempotent = + match inex.Set.Value with + | InExpressions exprs -> + let exprs = exprs |> rmap this.Expr + let involvedInfos = + Seq.append (Seq.singleton input) exprs |> Seq.map (fun e -> e.Info) |> toReadOnlyList + ignore <| cxt.Unify(inex.Set.Source, involvedInfos |> Seq.map (fun e -> e.Type)) + InExpressions exprs, + (involvedInfos |> Seq.forall (fun i -> i.Idempotent)) + | InSelect select -> + let select = queryChecker.Select(select) + let columnCount = select.Value.Info.Columns.Count + if columnCount <> 1 then + failAt select.Source <| Error.multipleColumnsForInSelect columnCount + InSelect select, (input.Info.Idempotent && select.Value.Info.Idempotent) + | InTable table -> + let table = this.TableInvocation(table) + InTable table, input.Info.Idempotent + | InParameter par -> + cxt.UnifyList(inex.Set.Source, input.Info.Type.InferredType, par) + InParameter par, true + { Expr.Source = source + Value = + { Invert = inex.Invert + Input = this.Expr(inex.Input) + Set = { Source = inex.Set.Source; Value = set } + } |> InExpr + Info = + { Type = InferredType.Dependent(input.Info.Type, BooleanType) + Idempotent = input.Info.Idempotent + Function = None + Column = None + } + } + + member this.Case(source : SourceInfo, case : CaseExpr) = + let case = + { Input = Option.map this.Expr case.Input + Cases = + [| + for whenExpr, thenExpr in case.Cases -> + this.Expr(whenExpr), this.Expr(thenExpr) + |] + Else = + { Source = case.Else.Source + Value = Option.map this.Expr case.Else.Value + } + } + let outputType = + seq { + for _, thenExpr in case.Cases -> thenExpr.Info.Type + match case.Else.Value with + | None -> yield InferredType.OfLiteral(NullLiteral) + | Some els -> yield els.Info.Type + } |> fun s -> cxt.Unify(source, s) + cxt.Unify(source, + seq { + yield + match case.Input with + | None -> InferredType.Boolean + | Some input -> input.Info.Type + for whenExpr, _ in case.Cases -> whenExpr.Info.Type + }) |> ignore + let subExprs = + seq { + match case.Input with + | None -> () + | Some input -> yield input + for whenExpr, thenExpr in case.Cases do + yield whenExpr + yield thenExpr + match case.Else.Value with + | None -> () + | Some els -> yield els + } + { Expr.Source = source + Value = case |> CaseExpr + Info = + { Type = outputType + Idempotent = subExprs |> Seq.forall (fun e -> e.Info.Idempotent) + Function = None + Column = None + } + } + + member this.Exists(source : SourceInfo, exists : SelectStmt) = + let exists = queryChecker.Select(exists) + { Expr.Source = source + Value = ExistsExpr exists + Info = + { Type = InferredType.Boolean + Idempotent = exists.Value.Info.Idempotent + Function = None + Column = None + } + } + + member this.ScalarSubquery(source : SourceInfo, select : SelectStmt) = + let select = queryChecker.Select(select) + let tbl = select.Value.Info.Table.Query + if tbl.Columns.Count <> 1 then + failAt source <| Error.multipleColumnsForScalarSubquery tbl.Columns.Count + { Expr.Source = source + Value = ScalarSubqueryExpr select + Info = tbl.Columns.[0].Expr.Info + } + + member this.Expr(expr : Expr) : InfExpr = + let source = expr.Source + match expr.Value with + | LiteralExpr lit -> this.Literal(source, lit) + | BindParameterExpr par -> this.BindParameter(source, par) + | ColumnNameExpr name -> this.ColumnName(source, name) + | CastExpr cast -> this.Cast(source, cast) + | CollateExpr collation -> this.Collation(source, collation) + | FunctionInvocationExpr func -> this.FunctionInvocation(source, func) + | SimilarityExpr sim -> this.Similarity(source, sim) + | BinaryExpr bin -> this.Binary(source, bin) + | UnaryExpr un -> this.Unary(source, un) + | BetweenExpr between -> this.Between(source, between) + | InExpr inex -> this.In(source, inex) + | ExistsExpr select -> this.Exists(source, select) + | CaseExpr case -> this.Case(source, case) + | ScalarSubqueryExpr select -> this.ScalarSubquery(source, select) + | RaiseExpr raise -> { Source = source; Value = RaiseExpr raise; Info = ExprInfo<_>.OfType(InferredType.Scalar) } + + member this.Expr(expr : Expr, ty : CoreColumnType) = + let expr = this.Expr(expr) + ignore <| cxt.Unify(expr.Source, expr.Info.Type, ty) + expr \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/FunctionDeclarations.fs b/Rezoom.SQL.Compiler/FunctionDeclarations.fs new file mode 100644 index 0000000..810fcbc --- /dev/null +++ b/Rezoom.SQL.Compiler/FunctionDeclarations.fs @@ -0,0 +1,98 @@ +module Rezoom.SQL.Compiler.FunctionDeclarations +open System.Collections.Generic +open Rezoom.SQL.Compiler + +let private argumentTypeVariable name = + { TypeConstraint = ScalarTypeClass + TypeVariable = Some (Name(name)) + ForceNullable = false + InfectNullable = false + VarArg = None + } + +let a' = argumentTypeVariable "a" +let b' = argumentTypeVariable "b" +let c' = argumentTypeVariable "c" +let d' = argumentTypeVariable "d" + +let constrained ty arg = + { arg with + TypeConstraint = + match arg.TypeConstraint.Unify(ty) with + | Ok t -> t + | Error e -> bug e + } + +let numeric ty = ty |> constrained NumericTypeClass +let stringish ty = ty |> constrained StringishTypeClass + +let inline private concrete ty = + { TypeConstraint = ty + TypeVariable = None + ForceNullable = false + InfectNullable = false + VarArg = None + } + +let scalar = concrete ScalarTypeClass +let boolean = concrete BooleanType +let string = concrete StringType +let num = concrete NumericTypeClass +let fractional = concrete FractionalTypeClass +let float64 = concrete (FloatType Float64) +let float32 = concrete (FloatType Float32) +let integral = concrete IntegralTypeClass +let int64 = concrete (IntegerType Integer64) +let int32 = concrete (IntegerType Integer32) +let int16 = concrete (IntegerType Integer16) +let int8 = concrete (IntegerType Integer8) +let binary = concrete BinaryType +let datetime = concrete DateTimeType +let datetimeoffset = concrete DateTimeOffsetType +let decimal = concrete DecimalType + +let nullable arg = + { arg with + ForceNullable = true + } + +let optional arg = + { arg with + VarArg = Some { MinArgCount = 0; MaxArgCount = Some 1 } + } + +let vararg arg = + { arg with + VarArg = Some { MinArgCount = 0; MaxArgCount = None } + } + +let varargN count arg = + { arg with + VarArg = Some { MinArgCount = 0; MaxArgCount = Some count } + } + +let infect arg = + { arg with + InfectNullable = true + } + +type NonAggregateFunction(name, args, ret, idem) = + inherit FunctionType(name, args, ret, idem) + override __.Aggregate(_) = None + +type AggregateFunction(name, args, ret, allowWildcard, allowDistinct) = + inherit FunctionType(name, args, ret, idem = true) + override __.Aggregate(_) = + Some { AllowWildcard = allowWildcard; AllowDistinct = allowWildcard } + +let inline proc name args ret = NonAggregateFunction(Name(name), List.toArray args, ret, idem = true) :> FunctionType +let inline func name args ret = NonAggregateFunction(Name(name), List.toArray args, ret, idem = false) :> FunctionType +let inline aggregate name args ret = AggregateFunction(Name(name), List.toArray args, ret, false, true) :> FunctionType +let inline aggregateW name args ret = AggregateFunction(Name(name), List.toArray args, ret, true, true) :> FunctionType + +type ErasedFunction(name, arg, ret, idem) = + inherit FunctionType(name, [| arg |], ret, idem) + override __.Erased = true + override __.Aggregate(_) = None + +let inline erased name arg ret = ErasedFunction(Name(name), arg, ret, idem = true) :> FunctionType \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/InferredTypes.fs b/Rezoom.SQL.Compiler/InferredTypes.fs new file mode 100644 index 0000000..7828d49 --- /dev/null +++ b/Rezoom.SQL.Compiler/InferredTypes.fs @@ -0,0 +1,307 @@ +module private Rezoom.SQL.Compiler.InferredTypes +open Rezoom.SQL +open System +open System.Collections.Generic + +type TypeVariableId = int + +type CoreInferredType = + | TypeKnown of CoreColumnType + | TypeVariable of TypeVariableId + +type InferredNullable = + | NullableUnknown + | NullableKnown of bool + | NullableVariable of TypeVariableId + | NullableEither of InferredNullable * InferredNullable + | NullableDueToJoin of InferredNullable // outer joins make nulls that wouldn't otherwise happen + member this.JoinInducedNullabilityDepth() = + match this with + | NullableDueToJoin wrap -> 1 + wrap.JoinInducedNullabilityDepth() + | NullableEither (l, r) -> max (l.JoinInducedNullabilityDepth()) (r.JoinInducedNullabilityDepth()) + | _ -> 0 + /// Remove layers of nullability induced by an outer join. + member this.StripJoinInducedNullability(depth) = + if depth <= 0 then this else + match this with + | NullableEither (l, r) -> + NullableEither (l.StripJoinInducedNullability(depth), r.StripJoinInducedNullability(depth)) + | NullableDueToJoin n -> n.StripJoinInducedNullability(depth - 1) + | _ -> this + static member Any(nulls) = + nulls |> Seq.fold (fun l r -> InferredNullable.Either(l, r)) NullableUnknown + static member Either(left, right) = + match left, right with + | (NullableUnknown | NullableKnown false), x + | x, (NullableUnknown | NullableKnown false) -> x + | NullableKnown true as t, _ -> t + | _, (NullableKnown true as t) -> t + | NullableVariable x as v, NullableVariable y when x = y -> v + | l, r -> NullableEither(l, r) + member this.Simplify() = + match this with + | NullableUnknown -> NullableKnown false + | NullableKnown false + | NullableKnown true + | NullableVariable _ -> this + | NullableDueToJoin n -> NullableDueToJoin (n.Simplify()) + | NullableEither (l, r) -> + match l.Simplify(), r.Simplify() with + | NullableKnown true, _ + | _, NullableKnown true -> NullableKnown true + | NullableKnown false, x -> x + | x, NullableKnown false -> x + | l, r -> NullableEither(l, r) + +type InferredType = + { InferredType : CoreInferredType + InferredNullable : InferredNullable + } + member this.StripNullDueToJoin(depth) = + { this with InferredNullable = this.InferredNullable.StripJoinInducedNullability(depth) } + static member Of(col) = { InferredNullable = NullableKnown col.Nullable; InferredType = TypeKnown col.Type } + static member Of(core) = { InferredNullable = NullableUnknown; InferredType = TypeKnown core } + static member Float = InferredType.Of(FractionalTypeClass) + static member Integer = InferredType.Of(IntegralTypeClass) + static member Number = InferredType.Of(NumericTypeClass) + static member String = InferredType.Of(StringType) + static member Boolean = InferredType.Of(BooleanType) + static member DateTime = InferredType.Of(DateTimeType) + static member DateTimeOffset = InferredType.Of(DateTimeOffsetType) + static member Blob = InferredType.Of(BinaryType) + static member Scalar = InferredType.Of(ScalarTypeClass) + static member Dependent(ifNull : InferredType, outputType : CoreColumnType) = + { InferredNullable = ifNull.InferredNullable + InferredType = TypeKnown outputType + } + static member OfLiteral(literal : Literal) = + match literal with + | NullLiteral -> { InferredNullable = NullableKnown true; InferredType = TypeKnown ScalarTypeClass } + | BooleanLiteral _ -> InferredType.Boolean + | StringLiteral _ -> InferredType.String + | BlobLiteral _ -> InferredType.Blob + | NumericLiteral (IntegerLiteral _) -> InferredType.Number + | NumericLiteral (FloatLiteral _) -> InferredType.Float + | DateTimeLiteral _ -> InferredType.DateTime + | DateTimeOffsetLiteral _ -> InferredType.DateTimeOffset + static member OfTypeName(typeName : TypeName, inputType : InferredType) = + let affinity = CoreColumnType.OfTypeName(typeName) + { InferredNullable = inputType.InferredNullable + InferredType = TypeKnown affinity + } + +type InfExprType = ExprType +type InfExpr = Expr +type InfInExpr = InExpr +type InfCollationExpr = CollationExpr +type InfBetweenExpr = BetweenExpr +type InfSimilarityExpr = SimilarityExpr +type InfBinaryExpr = BinaryExpr +type InfUnaryExpr = UnaryExpr +type InfObjectName = ObjectName +type InfColumnName = ColumnName +type InfInSet = InSet +type InfCaseExpr = CaseExpr +type InfCastExpr = CastExpr +type InfFunctionArguments = FunctionArguments +type InfFunctionInvocationExpr = FunctionInvocationExpr + +type InfWithClause = WithClause +type InfCommonTableExpression = CommonTableExpression +type InfCompoundExprCore = CompoundExprCore +type InfCompoundExpr = CompoundExpr +type InfCompoundTermCore = CompoundTermCore +type InfCompoundTerm = CompoundTerm +type InfCreateTableDefinition = CreateTableDefinition +type InfCreateTableStmt = CreateTableStmt +type InfSelectCore = SelectCore +type InfJoinConstraint = JoinConstraint +type InfJoin = Join +type InfLimit = Limit +type InfGroupBy = GroupBy +type InfOrderingTerm = OrderingTerm +type InfResultColumn = ResultColumn +type InfResultColumns = ResultColumns +type InfTableOrSubquery = TableOrSubquery +type InfTableExprCore = TableExprCore +type InfTableExpr = TableExpr +type InfTableInvocation = TableInvocation +type InfSelectStmt = SelectStmt +type InfColumnConstraint = ColumnConstraint +type InfColumnDef = ColumnDef +type InfAlterTableStmt = AlterTableStmt +type InfAlterTableAlteration = AlterTableAlteration +type InfCreateIndexStmt = CreateIndexStmt +type InfTableIndexConstraintClause = TableIndexConstraintClause +type InfTableConstraint = TableConstraint +type InfCreateViewStmt = CreateViewStmt +type InfQualifiedTableName = QualifiedTableName +type InfDeleteStmt = DeleteStmt +type InfDropObjectStmt = DropObjectStmt +type InfUpdateStmt = UpdateStmt +type InfInsertStmt = InsertStmt +type InfStmt = Stmt +type InfVendorStmt = VendorStmt +type InfTotalStmt = TotalStmt + +type ITypeInferenceContext = + abstract member AnonymousVariable : unit -> CoreInferredType + abstract member Variable : BindParameter -> InferredType + /// Unify the two types (ensure they are compatible and add constraints) + /// and produce the most specific type. + abstract member Unify : SourceInfo * CoreInferredType * CoreInferredType -> CoreInferredType + abstract member UnifyList : SourceInfo * elem : CoreInferredType * list : BindParameter -> unit + abstract member ForceNullable : SourceInfo * InferredNullable -> unit + abstract member Concrete : InferredType -> ColumnType + abstract member Parameters : BindParameter seq + +type InferredQueryColumn() = + static member OfColumn(fromAlias : Name option, column : SchemaColumn) = + { Expr = + { Source = SourceInfo.Invalid + Info = { ExprInfo<_>.OfType(InferredType.Of(column.ColumnType)) with Column = Some column } + Value = ColumnNameExpr { ColumnName = column.ColumnName; Table = None } + } + ColumnName = column.ColumnName + FromAlias = fromAlias + } + +let foundAt source nameResolution = + match nameResolution with + | Found x -> x + | NotFound err + | Ambiguous err -> failAt source err + +let inferredOfTable (table : SchemaTable) = + { Columns = + table.Columns + |> Seq.map (function KeyValue(_, c) -> InferredQueryColumn.OfColumn(Some table.TableName, c)) + |> toReadOnlyList + } + +type InferredFromClause = + { /// The tables named in the "from" clause of the query, if any. + /// These are keyed on the alias of the table, if any, or the table name. + FromVariables : IReadOnlyDictionary + } + static member FromSingleObject(tableName : InfObjectName) = + let d = Dictionary() + d.Add(Name(""), tableName.Info) + { FromVariables = d :> IReadOnlyDictionary<_, _> + } + member this.ResolveTable(tableName : ObjectName) = + match tableName.SchemaName with + // We don't currently support referencing columns like "main.users.id". Use table aliases instead! + | Some schemaName -> Ambiguous <| Error.schemaNameInColumnReference tableName + | None -> + let succ, query = this.FromVariables.TryGetValue(tableName.ObjectName) + if succ then Found query + else NotFound <| Error.noSuchTableInFrom tableName.ObjectName + member this.ResolveColumnReference(name : ColumnName) = + match name.Table with + | None -> + let matches = + seq { + for KeyValue(tableAlias, objectInfo) in this.FromVariables do + let table = objectInfo.Table + match table.Query.ColumnByName(name.ColumnName) with + | Found column -> + yield Ok ((if tableAlias.Value = "" then None else Some tableAlias), table, column) + | NotFound _ -> () + | Ambiguous err -> yield Error err + } |> toReadOnlyList + if matches.Count = 1 then + match matches.[0] with + | Ok triple -> Found triple + | Error e -> Ambiguous e + elif matches.Count <= 0 then + NotFound <| Error.noSuchColumnInFrom name + else + Ambiguous <| Error.ambiguousColumn name + | Some tableName -> + match this.ResolveTable(tableName) with + | Found objectInfo -> + let table = objectInfo.Table + match table.Query.ColumnByName(name.ColumnName) with + | Found column -> Found (Some tableName.ObjectName, table, column) + | NotFound err -> NotFound err + | Ambiguous err -> Ambiguous err + | NotFound err -> NotFound err + | Ambiguous err -> Ambiguous err + +and InferredSelectScope = + { /// If this scope is that of a subquery, the parent query's scope can also be used + /// to resolve column and CTE names. + ParentScope : InferredSelectScope option + /// The model this select is running against. + /// This includes tables and views that are part of the database, and may be used to resolve + /// table names in the "from" clause of the query. + Model : Model + /// Any CTEs defined by the query. + /// These may be referenced in the "from" clause of the query. + CTEVariables : Map + FromClause : InferredFromClause option + SelectClause : InferredType QueryExprInfo option + } + + static member Root(model) = + { ParentScope = None + Model = model + CTEVariables = Map.empty + FromClause = None + SelectClause = None + } + + member private this.ResolveObjectReferenceBySchema + (schema : Schema, name : Name, inferView : CreateViewStmt -> TCreateViewStmt) = + match schema.Objects |> Map.tryFind name with + | Some (SchemaTable tbl) -> + { Table = TableReference tbl; Query = inferredOfTable(tbl) } |> TableLike |> Found + | Some (SchemaView view) -> + let def = inferView view.CreateDefinition + let query = def.AsSelect.Value.Info.Query.Map(InferredType.Of) + { Table = ViewReference(view, def); Query = query } |> TableLike |> Found + | None -> NotFound <| Error.noSuchTable name + + /// Resolve a reference to a table which may occur as part of a TableExpr. + /// This will resolve against the database model and CTEs, but not table aliases defined in the FROM clause. + member this.ResolveObjectReference(name : ObjectName, inferView) = + match name.SchemaName with + | None -> + match this.CTEVariables.TryFind(name.ObjectName) with + | Some cte -> { Table = CTEReference name.ObjectName; Query = cte } |> TableLike |> Found + | None -> + match this.ParentScope with + | Some parent -> parent.ResolveObjectReference(name, inferView) + | None -> + let schema = this.Model.Schemas.[this.Model.DefaultSchema] + this.ResolveObjectReferenceBySchema(schema, name.ObjectName, inferView) + | Some schema -> + let schema = this.Model.Schemas.[schema] + this.ResolveObjectReferenceBySchema(schema, name.ObjectName, inferView) + + /// Resolve a column reference, which may be qualified with a table alias. + /// This resolves against the tables referenced in the FROM clause, and the columns explicitly named + /// in the SELECT clause, if any. + member this.ResolveColumnReference(name : ColumnName) = + let findFrom() = + let thisLevel = + match this.FromClause with + | None -> NotFound <| Error.columnReferenceWithoutFrom name + | Some fromClause -> fromClause.ResolveColumnReference(name) + match this.ParentScope, thisLevel with + | Some parent, NotFound _ -> + parent.ResolveColumnReference(name) + | _ -> thisLevel + match name.Table, this.SelectClause with + | None, Some selected -> + match selected.ColumnByName(name.ColumnName) with + | Found column -> + Found (None, { Table = SelectClauseReference name.ColumnName; Query = selected }, column) + | Ambiguous reason -> Ambiguous reason + | NotFound _ -> findFrom() + | _ -> findFrom() + +let concreteMapping (inference : ITypeInferenceContext) = + ASTMapping + ((fun t -> t.Map(inference.Concrete)), fun e -> e.Map(inference.Concrete)) \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/Model.fs b/Rezoom.SQL.Compiler/Model.fs new file mode 100644 index 0000000..ed4ab97 --- /dev/null +++ b/Rezoom.SQL.Compiler/Model.fs @@ -0,0 +1,302 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Data +open System.Data.Common +open System.Collections.Generic + +type DatabaseBuiltin = + { Functions : Map + } + +type Model = + { Schemas : Map + DefaultSchema : Name + TemporarySchema : Name + Builtin : DatabaseBuiltin + } + member this.Schema(name : Name option) = + this.Schemas |> Map.tryFind (name |? this.DefaultSchema) + +and Schema = + { SchemaName : Name + Objects : Map + } + static member Empty(name) = + { SchemaName = name + Objects = Map.empty + } + member this.ContainsObject(name : Name) = this.Objects.ContainsKey(name) + +and SchemaObject = + | SchemaTable of SchemaTable + | SchemaView of SchemaView + +and SchemaIndex = + { SchemaName : Name + TableName : Name + IndexName : Name + Columns : Name Set + } + +and SchemaConstraint = + { SchemaName : Name + TableName : Name + ConstraintName : Name + Columns : Name Set + } + +and SchemaTable = + { SchemaName : Name + TableName : Name + Columns : Map + Indexes : Map + Constraints : Map + } + member this.WithAdditionalColumn(col : ColumnDef<_, _>) = + match this.Columns |> Map.tryFind col.Name with + | Some _ -> Error <| Error.columnAlreadyExists col.Name + | None -> + let isPrimaryKey = + col.Constraints + |> Seq.exists( + function | { ColumnConstraintType = PrimaryKeyConstraint _ } -> true | _ -> false) + let newCol = + { SchemaName = this.SchemaName + TableName = this.TableName + PrimaryKey = isPrimaryKey + ColumnName = col.Name + ColumnType = ColumnType.OfTypeName(col.Type, col.Nullable) + } + Ok { this with Columns = this.Columns |> Map.add newCol.ColumnName newCol } + static member OfCreateDefinition(schemaName, tableName, def : CreateTableDefinition<_, _>) = + let tablePkColumns = + seq { + for constr in def.Constraints do + match constr.TableConstraintType with + | TableIndexConstraint { Type = PrimaryKey; IndexedColumns = indexed } -> + for name, _ in indexed -> name + | _ -> () + } |> Set.ofSeq + let tableColumns = + seq { + for column in def.Columns -> + let isPrimaryKey = + tablePkColumns.Contains(column.Name) + || column.Constraints |> Seq.exists(function + | { ColumnConstraintType = PrimaryKeyConstraint _ } -> true + | _ -> false) + { SchemaName = schemaName + TableName = tableName + PrimaryKey = isPrimaryKey + ColumnName = column.Name + ColumnType = ColumnType.OfTypeName(column.Type, column.Nullable) + } + } + { SchemaName = schemaName + TableName = tableName + Columns = tableColumns |> mapBy (fun c -> c.ColumnName) + Indexes = Map.empty + Constraints = + seq { + for constr, names in def.AllConstraints() -> + constr, + { SchemaName = schemaName + TableName = tableName + ConstraintName = constr + Columns = names + } + } |> Map.ofSeq + } + +and SchemaColumn = + { SchemaName : Name + TableName : Name + ColumnName : Name + /// True if this column is part of the table's primary key. + PrimaryKey : bool + ColumnType : ColumnType + } + +and SchemaView = + { SchemaName : Name + ViewName : Name + CreateDefinition : CreateViewStmt + } + member this.Definition = this.CreateDefinition.AsSelect + +and ExprInfo<'t> = + { /// The inferred type of this expression. + Type : 't + /// Does this expression return the same value each time it's run? + Idempotent : bool + /// If this expression is a function call, the function that it calls. + Function : FunctionType option + /// If this expression accesses a column of a table in the schema, the column's information. + Column : SchemaColumn option + } + member this.PrimaryKey = + match this.Column with + | None -> false + | Some c -> c.PrimaryKey + static member OfType(t : 't) = + { Type = t + Idempotent = true + Function = None + Column = None + } + member this.Map(f : 't -> _) = + { Type = f this.Type + Idempotent = this.Idempotent + Function = this.Function + Column = this.Column + } + +and ColumnExprInfo<'t> = + { Expr : Expr<'t ObjectInfo, 't ExprInfo> + FromAlias : Name option // table alias this was selected from, if any + ColumnName : Name + } + member this.Map(f : 't -> _) = + { Expr = + let mapping = ASTMapping<'t ObjectInfo, 't ExprInfo, _, _>((fun t -> t.Map(f)), fun e -> e.Map(f)) + mapping.Expr(this.Expr) + FromAlias = this.FromAlias + ColumnName = this.ColumnName + } + +and QueryExprInfo<'t> = + { Columns : 't ColumnExprInfo IReadOnlyList } + member this.Idempotent = + this.Columns |> Seq.forall (fun e -> e.Expr.Info.Idempotent) + member this.ColumnsWithNames(names) = + let mine = this.Columns |> toDictionary (fun c -> c.ColumnName) + let filtered = + seq { + for { WithSource.Source = source; Value = name } in names do + let succ, found = mine.TryGetValue(name) + if succ then yield found + else failAt source <| Error.noSuchColumn name + } |> toReadOnlyList + { Columns = filtered } + member this.ColumnByName(name) = + let matches = + this.Columns + |> Seq.filter (fun c -> c.ColumnName = name) + |> Seq.truncate 2 + |> Seq.toList + match matches with + | [] -> NotFound <| Error.noSuchColumn name + | [ single ] -> Found single + | { FromAlias = Some a1 } :: { FromAlias = Some a2 } :: _ when a1 <> a2 -> + Ambiguous <| Error.ambiguousColumnBetween name a1 a2 + | _ -> Ambiguous <| Error.ambiguousColumn name + member this.RenameColumns(names : Name IReadOnlyList) = + if names.Count <> this.Columns.Count then + Error <| Error.mismatchedColumnNameCount names.Count this.Columns.Count + else + let newColumns = + (this.Columns, names) + ||> Seq.map2 (fun col newName -> { col with ColumnName = newName }) + |> toReadOnlyList + Ok { Columns = newColumns } + member this.Append(right : 't QueryExprInfo) = + { Columns = appendLists this.Columns right.Columns } + member this.Map(f : 't -> _) = + { Columns = this.Columns |> Seq.map (fun c -> c.Map(f)) |> toReadOnlyList } + +and TableReference = + | TableReference of SchemaTable + | ViewReference of SchemaView * TCreateViewStmt + | CTEReference of Name + | FromClauseReference of Name + | SelectClauseReference of Name + | SelectResults + | CompoundTermResults + +and TableLikeExprInfo<'t> = + { Table : TableReference + Query : QueryExprInfo<'t> + } + member this.Map(f : 't -> _) = + { Table = this.Table + Query = this.Query.Map(f) + } + +and ObjectInfo<'t> = + | TableLike of 't TableLikeExprInfo + | Index of SchemaIndex + | Missing + member this.Idempotent = + match this with + | TableLike t -> t.Query.Idempotent + | Index _ + | Missing -> true + member this.Table = + match this with + | TableLike t -> t + | other -> bug <| sprintf "Bug: expected table, but found reference to %A" other + member this.Query = this.Table.Query + member this.Columns = this.Query.Columns + member this.Map<'t1>(f : 't -> 't1) : ObjectInfo<'t1> = + match this with + | TableLike t -> TableLike (t.Map(f)) + | Index i -> Index i + | Missing -> Missing + + +and TSelectStmt = SelectStmt +and TCreateViewStmt = CreateViewStmt + +type TExprType = ExprType +type TExpr = Expr +type TInExpr = InExpr +type TCollationExpr = CollationExpr +type TBetweenExpr = BetweenExpr +type TSimilarityExpr = SimilarityExpr +type TBinaryExpr = BinaryExpr +type TUnaryExpr = UnaryExpr +type TObjectName = ObjectName +type TColumnName = ColumnName +type TInSet = InSet +type TCaseExpr = CaseExpr +type TCastExpr = CastExpr +type TFunctionInvocationExpr = FunctionInvocationExpr + +type TWithClause = WithClause +type TCommonTableExpression = CommonTableExpression +type TCompoundExprCore = CompoundExprCore +type TCompoundExpr = CompoundExpr +type TCompoundTermCore = CompoundTermCore +type TCompoundTerm = CompoundTerm +type TForeignKeyClause = ForeignKeyClause +type TCreateTableDefinition = CreateTableDefinition +type TCreateTableStmt = CreateTableStmt +type TSelectCore = SelectCore +type TJoinConstraint = JoinConstraint +type TJoin = Join +type TLimit = Limit +type TGroupBy = GroupBy +type TOrderingTerm = OrderingTerm +type TResultColumn = ResultColumn +type TResultColumns = ResultColumns +type TTableOrSubquery = TableOrSubquery +type TTableExprCore = TableExprCore +type TTableExpr = TableExpr +type TTableInvocation = TableInvocation + +type TColumnConstraint = ColumnConstraint +type TColumnDef = ColumnDef +type TAlterTableStmt = AlterTableStmt +type TAlterTableAlteration = AlterTableAlteration +type TCreateIndexStmt = CreateIndexStmt +type TTableIndexConstraintClause = TableIndexConstraintClause +type TTableConstraint = TableConstraint +type TQualifiedTableName = QualifiedTableName +type TDeleteStmt = DeleteStmt +type TDropObjectStmt = DropObjectStmt +type TUpdateStmt = UpdateStmt +type TInsertStmt = InsertStmt +type TStmt = Stmt +type TVendorStmt = VendorStmt +type TTotalStmt = TotalStmt +type TTotalStmts = TTotalStmt IReadOnlyList \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/ModelChange.fs b/Rezoom.SQL.Compiler/ModelChange.fs new file mode 100644 index 0000000..393cc7e --- /dev/null +++ b/Rezoom.SQL.Compiler/ModelChange.fs @@ -0,0 +1,161 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic +open Rezoom.SQL.Compiler.InferredTypes + +type private ModelChange(model : Model, inference : ITypeInferenceContext) = + member private this.CreateTableColumns(model, schemaName : Name, tableName : Name, asSelect : InfSelectStmt) = + let query = asSelect.Value.Info.Query + [| for column in query.Columns -> + { SchemaName = schemaName + TableName = tableName + ColumnName = column.ColumnName + PrimaryKey = column.Expr.Info.PrimaryKey + ColumnType = inference.Concrete(column.Expr.Info.Type) // unfortunate but necessary + } + |] + member private this.CreateTable(create : InfCreateTableStmt) = + let defaultSchema = if create.Temporary then model.TemporarySchema else model.DefaultSchema + let schema = create.Name.SchemaName |? defaultSchema + match model.Schemas.TryFind(schema) with + | None -> failAt create.Name.Source <| Error.noSuchSchema schema + | Some schema -> + let tableName = create.Name.ObjectName + match schema.Objects |> Map.tryFind tableName with + | Some _ -> failAt create.Name.Source <| Error.objectAlreadyExists create.Name + | None -> + let table = + match create.As with + | CreateAsSelect select -> + { SchemaName = schema.SchemaName + TableName = create.Name.ObjectName + Columns = + this.CreateTableColumns(model, schema.SchemaName, tableName, select) + |> mapBy (fun c -> c.ColumnName) + Indexes = Map.empty + Constraints = Map.empty + } + | CreateAsDefinition def -> + SchemaTable.OfCreateDefinition(schema.SchemaName, tableName, def) + let schema = + { schema with Objects = schema.Objects |> Map.add table.TableName (SchemaTable table) } + Some { model with Schemas = model.Schemas |> Map.add schema.SchemaName schema } + member this.AlterTable(alter : InfAlterTableStmt) = + match model.Schema(alter.Table.SchemaName) with + | None -> failAt alter.Table.Source <| Error.noSuchSchema alter.Table + | Some schema -> + let tblName = alter.Table.ObjectName + match schema.Objects |> Map.tryFind tblName with + | None -> failAt alter.Table.Source <| Error.noSuchTable alter.Table + | Some (SchemaTable tbl) -> + match alter.Alteration with + | RenameTo newName -> + match schema.Objects |> Map.tryFind newName with + | None -> + let objects = + schema.Objects |> Map.remove tblName |> Map.add newName (SchemaTable tbl) + let schema = { schema with Objects = objects } + Some { model with Schemas = model.Schemas |> Map.add schema.SchemaName schema } + | Some existing -> + failAt alter.Table.Source <| Error.objectAlreadyExists newName + | AddColumn col -> + let newTbl = tbl.WithAdditionalColumn(col) |> resultAt alter.Table.Source + let constraints = + col.Constraints + |> Seq.fold (fun state con -> + Map.add con.Name + { SchemaName = tbl.SchemaName + TableName = tbl.TableName + ConstraintName = con.Name + Columns = Set.singleton col.Name + } state) newTbl.Constraints + let newTbl = { newTbl with Constraints = constraints } + let schema = { schema with Objects = schema.Objects |> Map.add tbl.TableName (SchemaTable newTbl) } + Some { model with Schemas = model.Schemas |> Map.add schema.SchemaName schema } + | Some _ -> failAt alter.Table.Source <| Error.objectNotATable alter.Table + member this.CreateView(create : InfCreateViewStmt) = + let viewName = create.ViewName.ObjectName + match model.Schema(create.ViewName.SchemaName) with + | None -> failAt create.ViewName.Source <| Error.noSuchSchema create.ViewName + | Some schema -> + match schema.Objects |> Map.tryFind viewName with + | Some _ -> + failAt create.ViewName.Source <| Error.objectAlreadyExists create.ViewName + | None -> + let view = + { SchemaName = schema.SchemaName + ViewName = viewName + CreateDefinition = ASTMapping.Stripper().CreateView(create) + } |> SchemaView + let schema = { schema with Objects = schema.Objects |> Map.add create.ViewName.ObjectName view } + Some { model with Schemas = model.Schemas |> Map.add schema.SchemaName schema } + member this.DropObject(drop : InfDropObjectStmt) = + let objName = drop.ObjectName.ObjectName + let typeName = + match drop.Drop with + | DropTable -> "table" + | DropView -> "view" + | DropIndex -> "index" + match model.Schema(drop.ObjectName.SchemaName) with + | None -> failAt drop.ObjectName.Source <| Error.noSuchSchema objName + | Some schema -> + let dropped() = + let droppedSchema = { schema with Objects = schema.Objects |> Map.remove objName } + Some { model with Schemas = model.Schemas |> Map.add schema.SchemaName droppedSchema } + match schema.Objects |> Map.tryFind objName with + | None -> + failAt drop.ObjectName.Source <| Error.noSuchObject typeName objName + | Some o -> + match drop.Drop, o with + | DropTable, SchemaTable _ + | DropView, SchemaView _ -> + dropped() + | _ -> + failAt drop.ObjectName.Source <| Error.objectIsNotA typeName objName + member this.CreateIndex(create : InfCreateIndexStmt) = + match model.Schema(create.IndexName.SchemaName), model.Schema(create.TableName.SchemaName) with + | Some schema, Some tableSchema when schema.SchemaName = tableSchema.SchemaName -> + let table = + match schema.Objects |> Map.tryFind create.TableName.ObjectName with + | None -> failAt create.TableName.Source <| Error.noSuchTable create.TableName + | Some (SchemaTable table) -> table + | Some _ -> failAt create.TableName.Source <| Error.objectIsNotA "table" create.TableName + if schema.ContainsObject(create.IndexName.ObjectName) then + failAt create.IndexName.Source <| Error.objectAlreadyExists create.IndexName + let index = + { SchemaName = schema.SchemaName + TableName = table.TableName + IndexName = create.IndexName.ObjectName + Columns = create.IndexedColumns |> Seq.map fst |> Set.ofSeq + } + let table = + { table with + Indexes = Map.add index.IndexName index table.Indexes + } + let schema = + { schema with + Objects = schema.Objects |> Map.add table.TableName (SchemaTable table) + } + Some { model with Schemas = model.Schemas |> Map.add schema.SchemaName schema } + | Some _, Some _ -> + failAt create.IndexName.Source <| Error.indexSchemasMismatch create.IndexName create.TableName + | None, Some _ -> + failAt create.IndexName.Source <| Error.noSuchSchema create.IndexName + | _, None -> + failAt create.TableName.Source <| Error.noSuchSchema create.TableName + + member this.Stmt(stmt : InfStmt) = + match stmt with + | AlterTableStmt alter -> this.AlterTable(alter) + | CreateTableStmt create -> this.CreateTable(create) + | CreateViewStmt create -> this.CreateView(create) + | CreateIndexStmt create -> this.CreateIndex(create) + | DropObjectStmt drop -> this.DropObject(drop) + | BeginStmt + | CommitStmt + | DeleteStmt _ + | InsertStmt _ + | RollbackStmt + | SelectStmt _ + | UpdateStmt _ -> None + diff --git a/Rezoom.SQL.Compiler/Name.fs b/Rezoom.SQL.Compiler/Name.fs new file mode 100644 index 0000000..4f8a2a3 --- /dev/null +++ b/Rezoom.SQL.Compiler/Name.fs @@ -0,0 +1,37 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic + +type Name(str : string) = + member inline private __.String = str + member inline private __.InlineEquals(other : Name) = + str.Equals(other.String, StringComparison.OrdinalIgnoreCase) + member inline private __.InlineCompareTo(other : Name) = + String.Compare(str, other.String, StringComparison.OrdinalIgnoreCase) + member this.Value = str + member this.Equals(other) = this.InlineEquals(other) + member this.CompareTo(other) = this.InlineCompareTo(other) + override __.ToString() = str + override this.Equals(other : obj) = + match other with + | :? Name as name -> this.InlineEquals(name) + | _ -> false + override this.GetHashCode() = + StringComparer.OrdinalIgnoreCase.GetHashCode(str) + interface IEquatable with + member this.Equals(name) = this.InlineEquals(name) + interface IComparable with + member this.CompareTo(other) = this.InlineCompareTo(other) + interface IComparable with + member this.CompareTo(other) = + match other with + | :? Name as name -> this.InlineCompareTo(name) + | _ -> invalidArg "other" "Argument is not a Name" + + static member op_Explicit(name : Name) = name.String + static member op_Explicit(name : string) = Name(name) + static member (+) (name : Name, str : string) = Name(name.String + str) + static member (+) (str : string, name : Name) = Name(str + name.String) + static member (+) (name1 : Name, name2 : Name) = Name(name1.String + name2.String) + + diff --git a/Rezoom.SQL.Compiler/Parser.fs b/Rezoom.SQL.Compiler/Parser.fs new file mode 100644 index 0000000..7b5363e --- /dev/null +++ b/Rezoom.SQL.Compiler/Parser.fs @@ -0,0 +1,169 @@ +// Parses all AST statements. + +module private Rezoom.SQL.Compiler.Parser +open System +open System.Collections.Generic +open System.Globalization +open FParsec +open FParsec.Pipes +open FParsec.Pipes.Precedence +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.CoreParser + +// Vendor statements allow embedding raw SQL written for a specific backend. +// Sometimes the set of SQL we can typecheck is insufficient. +// In their simplest form, vendor statements look like: +// VENDOR TSQL { +// -- literally any text could go in here +// -- as an example I'm using T-SQL's delete/join syntax which is not supported by the typechecked language. +// delete f from Foos f +// join Bars b on b.FooId = f.Id +// and f.Name like {@name} +// } + +// Notice that the vendor statement includes the name of the backend it was written for. +// This is intended to ease transitions to other backends, because when you change the targeted +// backend in your config file, the compiler will shriek and point to all the places you're using +// vendor-specific SQL for the old backend. + +// The delimiter can be any sequence of punctuation, in order to permit vendor statements +// that contain funky characters. For example, the above could also be written like: +// VENDOR TSQL {<[ +// delete f from Foos f +// join Bars b on b.FooId = f.Id +// and f.Name like {<[@name]>} +// ]>} + +// Notice: +// 1. The closing delimiter is a "flipped" version of the opening delimiter, which is a string reverse +// with some characters also switched out for their "inverse", e.g. [ and ]. +// 2. The same delimiter pairs are used to wrap parameters used within the vendor statement. +// We must separate parameters from the raw SQL text since when batching SQL statements, the parameter +// names are generated dynamically at runtime. + +// By default, since we can't statically determine anything about vendor statements, +// they have some unfortunate properties. +// 1. They are considered to invalidate the cache for all tables. (we assume the worst) +// 2. The parameters used with them get no type constraints added, so if they are not used +// in another statement, they will all have `obj` type. +// 3. They are asssumed to have no result sets of interest. + +// In order to avoid this, you should specify a second set of statements in the typechecked SQL dialect. +// These will never execute, and therefore need not match the actual semantics of the vendor statements, but +// the typechecker will use them to assume facts about the parameters and cache effects of the vendor statements. +// You do this with the IMAGINE clause. For example: +// VENDOR TSQL { +// select Name, count(*) over (partition by Name) as Count from Users group by Name +// } IMAGINE { +// -- inform the typechecker that we depend on the Users table and have a result set +// -- of type (Name string, Count int) +// select Name, 1 as Count from Users +// } + +let isDelimiterCharacter = + function + | '[' | ']' + | '{' | '}' + | '(' | ')' + | '<' | '>' + | '/' | '\\' + | ':' | '.' | ',' | '?' | '|' + | '!' | '@' | '#' | '$' | '%' + | '^' | '&' | '*' | '-' | '+' + | '=' | '~' -> true + | _ -> false + +let private flipChar = + function + | '[' -> ']' + | ']' -> '[' + | '{' -> '}' + | '}' -> '{' + | '(' -> ')' + | ')' -> '(' + | '<' -> '>' + | '>' -> '<' + | '/' -> '\\' + | '\\' -> '/' + | c -> c + +let private flipDelimiter (delim : string) = + delim.ToCharArray() + |> Array.rev + |> Array.map flipChar + |> String + +let private vendorStmtStart = + %% ci "VENDOR" + -- ws1 + -- +.withSource name + -- ws1 + -- +.many1Satisfy isDelimiterCharacter + -|> fun vendorName delim -> vendorName, delim + +type private Delimiter = + | OpenDelimiter + | CloseDelimiter + +let private vendorFragments openDelim closeDelim = + let delim = openDelim <|> closeDelim + let exprWithClose = ws >>. expr .>> ws .>> closeDelim + let onExpr str e next = + VendorRaw str + :: VendorEmbeddedExpr e + :: next + let self, selfRef = createParserForwardedToRef() + let onOpen str = + pipe2 exprWithClose self (onExpr str) + let onClose str = + preturn [ VendorRaw str ] + let onEither (str, delim) = + match delim with + | OpenDelimiter -> onOpen str + | CloseDelimiter -> onClose str + selfRef := + manyCharsTillApply anyChar delim + (fun str delim -> str, delim) + >>= onEither + self |>> List.toArray + +let private vendorStmt = + vendorStmtStart + >>= fun (vendorName, openDelim) -> + let closeDelim = flipDelimiter openDelim + if closeDelim = openDelim then + fail (Error.sameVendorDelimiters closeDelim) + else + let openDelim = pstring openDelim >>% OpenDelimiter + let closeDelim = pstring closeDelim >>% CloseDelimiter + let body = + vendorFragments openDelim closeDelim + |>> fun frags imaginary -> + { VendorName = vendorName + Fragments = frags + ImaginaryStmts = imaginary + } |> VendorStmt + let imaginary = + pstringCI "IMAGINE" + >>. ws1 + >>. openDelim + >>. coreStmts + .>> closeDelim + pipe2 (body .>> ws) (opt imaginary) (<|) + +let stmt = vendorStmt <|> (coreStmt |>> CoreStmt) + +let stmts = + %% ws + -- +.(qty.[0..] /. tws ';' * tws stmt) + -|> fun s -> s.ToArray() + +let parseStatements sourceDescription source = + if source = Error.jamesBondEasterEgg then failwith Error.jamesBond + match runParserOnString (stmts .>> eof) () sourceDescription source with + | Success (statements, _, _) -> statements + | Failure (_, err, _) -> + let sourceInfo = SourceInfo.OfPosition(translatePosition err.Position) + use writer = new System.IO.StringWriter() + err.WriteTo(writer, (fun _ _ _ _ -> ())) + failAt sourceInfo <| Error.parseError writer diff --git a/Rezoom.SQL.Compiler/ReadWriteReferences.fs b/Rezoom.SQL.Compiler/ReadWriteReferences.fs new file mode 100644 index 0000000..35aaccc --- /dev/null +++ b/Rezoom.SQL.Compiler/ReadWriteReferences.fs @@ -0,0 +1,212 @@ +module private Rezoom.SQL.Compiler.ReadWriteReferences +open System +open System.Collections.Generic +open Rezoom.SQL + +type private ReferenceType = + | ReadReference + | WriteReference + +type private ReferenceFinder() = + let referencedViews = HashSet() + let references = + { new IEqualityComparer with + member __.Equals(left, right) = + left.SchemaName = right.SchemaName + && left.TableName = right.TableName + member __.GetHashCode(tbl) = + (tbl.SchemaName, tbl.TableName).GetHashCode() + } |> Dictionary + let addReference table refType = + let succ, existing = references.TryGetValue(table) + let updated = + Set.add refType <| + if succ then existing + else Set.empty + references.[table] <- updated + member __.References = + seq { + for kv in references do + yield kv.Key, kv.Value + } + member this.ReferenceObject(reference : ReferenceType, name : TObjectName) = + match name.Info with + | TableLike { Table = TableReference schemaTable } -> + addReference schemaTable reference + | TableLike { Table = ViewReference(schemaView, createDef) } -> + if referencedViews.Add(schemaView.SchemaName, schemaView.ViewName) then + this.Select(createDef.AsSelect) + | _ -> () + member this.ReferenceColumn(reference : ReferenceType, column : TColumnName) = + match column.Table with + | None -> () + | Some tbl -> this.ReferenceObject(reference, tbl) + member this.Binary(binary : TBinaryExpr) = + this.Expr(binary.Left) + this.Expr(binary.Right) + member this.Unary(unary : TUnaryExpr) = this.Expr(unary.Operand) + member this.Cast(cast : TCastExpr) = this.Expr(cast.Expression) + member this.Collation(collation : TCollationExpr) = this.Expr(collation.Input) + member this.FunctionInvocation(func : TFunctionInvocationExpr) = + match func.Arguments with + | ArgumentList (_, exprs) -> + for expr in exprs do this.Expr(expr) + | _ -> () + member this.Similarity(sim : TSimilarityExpr) = + this.Expr(sim.Input) + this.Expr(sim.Pattern) + Option.iter this.Expr sim.Escape + member this.Between(between : TBetweenExpr) = + this.Expr(between.Input) + this.Expr(between.Low) + this.Expr(between.High) + member this.In(inex : TInExpr) = + this.Expr(inex.Input) + match inex.Set.Value with + | InExpressions exprs -> for expr in exprs do this.Expr(expr) + | InSelect select -> this.Select(select) + | InTable table -> this.TableInvocation(table) + | InParameter _ -> () + member this.Case(case : TCaseExpr) = + Option.iter this.Expr case.Input + for whenExpr, thenExpr in case.Cases do + this.Expr(whenExpr) + this.Expr(thenExpr) + Option.iter this.Expr case.Else.Value + member this.ExprType(expr : TExprType) : unit = + match expr with + | ColumnNameExpr name -> this.ReferenceColumn(ReadReference, name) + | CastExpr cast -> this.Cast(cast) + | CollateExpr collation -> this.Collation(collation) + | FunctionInvocationExpr func -> this.FunctionInvocation(func) + | SimilarityExpr sim -> this.Similarity(sim) + | BinaryExpr bin -> this.Binary(bin) + | UnaryExpr un -> this.Unary(un) + | BetweenExpr between -> this.Between(between) + | InExpr inex -> this.In(inex) + | ExistsExpr select -> this.Select(select) + | CaseExpr case -> this.Case(case) + | ScalarSubqueryExpr select -> this.Select(select) + | RaiseExpr _ + | LiteralExpr _ + | BindParameterExpr _ -> () + member this.Expr(expr : TExpr) = this.ExprType(expr.Value) + member this.TableInvocation(table : TTableInvocation) = + this.ReferenceObject(ReadReference, table.Table) + match table.Arguments with + | Some args -> for arg in args do this.Expr(arg) + | None -> () + member this.CTE(cte : TCommonTableExpression) = this.Select(cte.AsSelect) + member this.WithClause(withClause : TWithClause) = for table in withClause.Tables do this.CTE(table) + member this.OrderingTerm(orderingTerm : TOrderingTerm) = this.Expr(orderingTerm.By) + member this.Limit(limit : TLimit) = + this.Expr(limit.Limit) + Option.iter this.Expr limit.Offset + member this.ResultColumn(resultColumn : TResultColumn) = + match resultColumn.Case with + | Column (expr, _) -> this.Expr(expr) + | _ -> failwith "BUG: result column wildcards should've been expanded by now" + member this.ResultColumns(resultColumns : TResultColumns) = + for col in resultColumns.Columns do this.ResultColumn(col) + member this.TableOrSubquery(table : TTableOrSubquery) = + match table.Table with + | Table (tinvoc, _) -> this.TableInvocation(tinvoc) + | Subquery select -> this.Select(select) + member this.JoinConstraint(constr : TJoinConstraint) = + match constr with + | JoinOn expr -> this.Expr(expr) + | JoinUnconstrained -> () + member this.Join(join : TJoin) = + this.TableExpr(join.LeftTable) + this.TableExpr(join.RightTable) + this.JoinConstraint(join.Constraint) + member this.TableExpr(table : TTableExpr) = + match table.Value with + | TableOrSubquery sub -> this.TableOrSubquery(sub) + | Join join -> this.Join(join) + member this.GroupBy(groupBy : TGroupBy) = + for by in groupBy.By do this.Expr(by) + Option.iter this.Expr groupBy.Having + member this.SelectCore(select : TSelectCore) = + this.ResultColumns(select.Columns) + Option.iter this.TableExpr select.From + Option.iter this.Expr select.Where + Option.iter this.GroupBy select.GroupBy + member this.CompoundTerm(term : TCompoundTerm) = + match term.Value with + | Values rows -> + for row in rows do + for col in row.Value do + this.Expr(col) + | Select select -> this.SelectCore(select) + member this.Compound(compound : TCompoundExpr) = + match compound.Value with + | CompoundTerm term -> this.CompoundTerm(term) + | Union (expr, term) + | UnionAll (expr, term) + | Intersect (expr, term) + | Except (expr, term) -> + this.Compound(expr) + this.CompoundTerm(term) + member this.Select(select : TSelectStmt) : unit = + Option.iter this.WithClause select.Value.With + this.Compound(select.Value.Compound) + Option.iter this.Limit select.Value.Limit + match select.Value.OrderBy with + | Some terms -> for term in terms do this.OrderingTerm(term) + | None -> () + member this.Delete(delete : TDeleteStmt) = + this.ReferenceObject(WriteReference, delete.DeleteFrom.TableName) + Option.iter this.WithClause delete.With + Option.iter this.Expr delete.Where + Option.iter this.Limit delete.Limit + match delete.OrderBy with + | Some terms -> for term in terms do this.OrderingTerm(term) + | None -> () + member this.Insert(insert : TInsertStmt) = + this.ReferenceObject(WriteReference, insert.InsertInto) + Option.iter this.WithClause insert.With + Option.iter this.Select insert.Data + member this.Update(update : TUpdateStmt) = + this.ReferenceObject(WriteReference, update.UpdateTable.TableName) + Option.iter this.WithClause update.With + Option.iter this.Expr update.Where + Option.iter this.Limit update.Limit + for _, setTo in update.Set do this.Expr(setTo) + match update.OrderBy with + | Some terms -> for term in terms do this.OrderingTerm(term) + | None -> () + member this.Stmt(stmt : TStmt) = + match stmt with + | DeleteStmt delete -> this.Delete(delete) + | InsertStmt insert -> this.Insert(insert) + | SelectStmt select -> this.Select(select) + | UpdateStmt update -> this.Update(update) + | AlterTableStmt _ + | CreateIndexStmt _ + | CreateTableStmt _ + | CreateViewStmt _ + | DropObjectStmt _ + | BeginStmt + | CommitStmt + | RollbackStmt -> () + +type References = + { TablesRead : SchemaTable IReadOnlyList + TablesWritten : SchemaTable IReadOnlyList + } + +let references (stmts : TStmt seq) = + let finder = ReferenceFinder() + for stmt in stmts do finder.Stmt(stmt) + let tablesRead = ResizeArray() + let tablesWritten = ResizeArray() + let dependencyTargets = ResizeArray() + for table, set in finder.References do + for depTy in set do + match depTy with + | ReadReference -> tablesRead.Add(table) + | WriteReference -> tablesWritten.Add(table) + { TablesRead = tablesRead :> _ IReadOnlyList + TablesWritten = tablesWritten :> _ IReadOnlyList + } \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/Rezoom.SQL.Compiler.fsproj b/Rezoom.SQL.Compiler/Rezoom.SQL.Compiler.fsproj new file mode 100644 index 0000000..2bb703a --- /dev/null +++ b/Rezoom.SQL.Compiler/Rezoom.SQL.Compiler.fsproj @@ -0,0 +1,128 @@ + + + + + Debug + AnyCPU + 2.0 + 87fcd04a-1f90-4d53-a428-cf5f5c532a22 + Library + Rezoom.SQL.Compiler + Rezoom.SQL.Compiler + v4.6 + 4.4.0.0 + true + Rezoom.SQL.Compiler + + + true + full + false + false + bin\Debug\ + DEBUG;TRACE + 3 + bin\Debug\Rezoom.SQL.Compiler.XML + + + pdbonly + true + true + bin\Release\ + TRACE + 3 + bin\Release\Rezoom.SQL.Compiler.XML + + + 11 + + + + + $(MSBuildExtensionsPath32)\..\Microsoft SDKs\F#\3.0\Framework\v4.0\Microsoft.FSharp.Targets + + + + + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\FSharp\Microsoft.FSharp.Targets + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ..\packages\FParsec.1.0.2\lib\net40-client\FParsec.dll + True + + + ..\packages\FParsec-Pipes.0.3.1.0\lib\net45\FParsec-Pipes.dll + True + + + ..\packages\FParsec.1.0.2\lib\net40-client\FParsecCS.dll + True + + + + True + + + + + + + + + Rezoom.SQL.Mapping + {6b6a06c5-157a-4fe3-8b4c-2a1ae6a15333} + True + + + Rezoom + {d98acbeb-a039-4340-a7c5-6ed2b677268b} + True + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/SQLite.fs b/Rezoom.SQL.Compiler/SQLite.fs new file mode 100644 index 0000000..98676b3 --- /dev/null +++ b/Rezoom.SQL.Compiler/SQLite.fs @@ -0,0 +1,135 @@ +namespace Rezoom.SQL.Compiler.SQLite +open System +open System.Data +open System.Collections.Generic +open System.Globalization +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.BackendUtilities +open Rezoom.SQL.Compiler.Translators +open Rezoom.SQL.Mapping + +type private SQLiteLiteral() = + inherit DefaultLiteralTranslator() + override __.BooleanLiteral(t) = CommandText <| if t then "1" else "0" + +type private SQLiteExpression(statement : StatementTranslator, indexer) = + inherit DefaultExprTranslator(statement, indexer) + let literal = SQLiteLiteral() + override __.Literal = upcast literal + override __.TypeName(name) = + (Seq.singleton << text) <| + match name with + | BooleanTypeName + | IntegerTypeName Integer8 + | IntegerTypeName Integer16 + | IntegerTypeName Integer32 + | IntegerTypeName Integer64 -> "INTEGER" + | FloatTypeName Float32 + | FloatTypeName Float64 -> "FLOAT" + | StringTypeName(_) -> "VARCHAR" + | BinaryTypeName(_) -> "BLOB" + | DecimalTypeName + | DateTimeTypeName + | DateTimeOffsetTypeName -> failwith <| sprintf "Unsupported type ``%A``" name + +type private SQLiteStatement(indexer : IParameterIndexer) as this = + inherit DefaultStatementTranslator(Name("SQLITE"), indexer) + let expr = SQLiteExpression(this :> StatementTranslator, indexer) + override __.Expr = upcast expr + override __.ColumnsNullableByDefault = true + +module private SQLiteFunctions = + open Rezoom.SQL.Compiler.FunctionDeclarations + let private minmax name = + { new FunctionType(Name(name), [| infect a'; vararg (infect a') |], a', idem = true) with + override __.Aggregate(arg) = + match arg with + | ArgumentWildcard -> None + | ArgumentList (_, exprs) -> + if exprs.Length = 1 then + Some { AllowWildcard = false; AllowDistinct = false } + else + None + } + let functions = + let numeric ty = ty |> constrained NumericTypeClass + [| // core functions from https://www.sqlite.org/lang_corefunc.html + proc "changes" [] int64 + func "char" [ vararg string ] string + func "glob" [ infect string; infect string ] boolean + func "hex" [ binary ] string + func "ifnull" [ nullable a'; infect a' ] a' + func "instr" [ infect (stringish a'); infect a' ] int64 + proc "last_insert_rowid" [] int64 + func "length" [ infect (stringish scalar) ] int64 + func "like" [ infect string; infect string; optional (infect string) ] boolean + func "likelihood" [ boolean; float64 ] boolean + func "likely" [ boolean ] boolean + // no load_extension + func "lower" [ infect string ] string + func "ltrim" [ infect string; optional (infect string) ] string + minmax "max" + minmax "min" + func "nullif" [ a'; a' ] (nullable a') + func "printf" [ infect string; vararg scalar ] string + func "quote" [ scalar ] string + proc "random" [] int64 + proc "randomblob" [] binary + func "replace" [ infect string; infect string; infect string ] string + func "round" [ infect float64; optional (infect integral) ] float64 + func "rtrim" [ infect string; optional (infect string) ] string + func "soundex" [ infect string ] string + func "sqlite_compileoption_get" [ integral ] string + func "sqlite_compileoption_used" [ infect string ] boolean + func "sqlite_source_id" [] string + func "sqlite_version" [] string + func "substr" [ infect string; infect integral; optional (infect integral) ] string + proc "total_changes" [] int64 + func "trim" [ infect string; optional (infect integral) ] string + func "typeof" [ scalar ] string + func "unicode" [ infect string ] int64 + func "unlikely" [ boolean ] boolean + func "upper" [ infect string ] string + func "zeroblob" [ integral ] binary + + // aggregate functions from https://www.sqlite.org/lang_aggfunc.html + aggregate "avg" [ numeric a' ] (nullable float64) + aggregateW "count" [ scalar ] int64 + aggregate "group_concat" [ infect string; optional string ] string + aggregate "sum" [ numeric a' ] a' + aggregate "total" [ numeric a' ] a' + + // date and time functions from https://www.sqlite.org/lang_datefunc.html + // for now we use strings to represent dates -- maybe should formalize this by using the datetime type + // even though its underlying representation will be a string + func "date" [ string; vararg string ] (nullable string) + func "time" [ string; vararg string ] (nullable string) + func "datetime" [ string; vararg string ] (nullable string) + func "julianday" [ string; vararg string ] (nullable string) + func "strftime" [ string; string; vararg string ] (nullable string) + |] |> DefaultFunctions.extendedBy + +type SQLiteBackend() = + static let initialModel = + let main, temp = Name("main"), Name("temp") + { Schemas = + [ Schema.Empty(main) + Schema.Empty(temp) + ] |> List.map (fun s -> s.SchemaName, s) |> Map.ofList + DefaultSchema = main + TemporarySchema = temp + Builtin = + { Functions = SQLiteFunctions.functions + } + } + interface IBackend with + member this.MigrationBackend = <@ fun conn -> DefaultMigrationBackend(conn) :> Migrations.IMigrationBackend @> + member this.InitialModel = initialModel + member this.ParameterTransform(columnType) = ParameterTransform.Default(columnType) + member this.ToCommandFragments(indexer, stmts) = + let translator = SQLiteStatement(indexer) + translator.TotalStatements(stmts) + |> BackendUtilities.simplifyFragments + |> ResizeArray + :> _ IReadOnlyList + \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/SourceTypes.fs b/Rezoom.SQL.Compiler/SourceTypes.fs new file mode 100644 index 0000000..1efd55a --- /dev/null +++ b/Rezoom.SQL.Compiler/SourceTypes.fs @@ -0,0 +1,136 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic + +/// The position in the source query that a syntactic element appeared. +type SourcePosition = + { Index : int + Line : int + Column : int + } + static member Invalid = + { Index = -1 + Line = -1 + Column = -1 + } + +type ParsingException(msg, pos : SourcePosition) = + inherit Exception(msg) + member this.Position = pos + +/// The span of (start, end) positions in the source file +/// that a syntactic element occupies. +type SourceInfo = + { StartPosition : SourcePosition + EndPosition : SourcePosition + } + static member Invalid = + { StartPosition = SourcePosition.Invalid + EndPosition = SourcePosition.Invalid + } + static member private ContextLength = 6 // words of context to show on each side + static member private ContextBefore(source : string) = + let mutable i = source.Length - 1 + let mutable inWord = false + let mutable boundaryCount = 0 + while i >= 0 && boundaryCount < SourceInfo.ContextLength do + if source.[i] = '\r' || source.[i] = '\n' then + boundaryCount <- SourceInfo.ContextLength + else + let inWordNow = Char.IsLetterOrDigit(source.[i]) + if inWord <> inWordNow && not inWordNow then + boundaryCount <- boundaryCount + 1 + inWord <- inWordNow + i <- i - 1 + i <- min (i + 1) (source.Length - 1) + source.Substring(i, source.Length - i) + static member private ContextAfter(source : string) = + let mutable i = 0 + let mutable inWord = false + let mutable boundaryCount = 0 + while i < source.Length && boundaryCount < SourceInfo.ContextLength do + if source.[i] = '\r' || source.[i] = '\n' then + boundaryCount <- SourceInfo.ContextLength + else + let inWordNow = Char.IsLetterOrDigit(source.[i]) + if inWord <> inWordNow && not inWordNow then + boundaryCount <- boundaryCount + 1 + inWord <- inWordNow + i <- i + 1 + i <- max 0 (i - 1) + source.Substring(0, i) + + static member private Emphasize(source : string) = + let trimmed = source.TrimEnd('\r', '\n', ' ', '\t') + let missing = source.Substring(trimmed.Length, source.Length - trimmed.Length) + " ⇨ " + trimmed + " ⇦ " + missing + member this.ShowInSource(source : string) = + if + this.StartPosition.Index < 0 + || this.EndPosition.Index < 0 + || this.StartPosition.Index >= int source.Length + || this.EndPosition.Index > int source.Length + then + "(no known source (possibly generated code))" + else + let context = 20 + let before = SourceInfo.ContextBefore(source.Substring(0, this.StartPosition.Index)) + let after = SourceInfo.ContextAfter(source.Substring(this.EndPosition.Index)) + let middle = source.Substring(this.StartPosition.Index, this.EndPosition.Index - this.StartPosition.Index) + before + SourceInfo.Emphasize(middle) + after + static member OfPosition(pos : SourcePosition) = + { StartPosition = pos + EndPosition = pos + } + static member Between(left : SourceInfo, right : SourceInfo) = + { StartPosition = min left.EndPosition right.EndPosition + EndPosition = max left.StartPosition right.StartPosition + } + +/// `'a` with the positions in source that it spanned. +[] +[] +type WithSource<'a> = + { /// The position in source of the syntactic element + Source : SourceInfo + /// The syntactic element + Value : 'a + } + member this.Map(f) = { Source = this.Source; Value = f this.Value } + member this.Equals(other) = EqualityComparer<'a>.Default.Equals(this.Value, other.Value) + override this.Equals(other) = + match other with + | :? WithSource<'a> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = (box this.Value).GetHashCode() + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + +type SourceInfoException(msg : string, pos : SourceInfo) = + inherit Exception(msg) + member this.SourceInfo = pos + +type SourceException(msg : string, pos : SourceInfo, source, fileName) = + inherit Exception + ( msg.TrimEnd('.') + "." + + Environment.NewLine + + fileName + + "(" + + string pos.StartPosition.Line + + "," + + string pos.StartPosition.Column + + "):" + Environment.NewLine + + pos.ShowInSource(source) + ) + member __.FileName = fileName + member __.Reason = msg + member __.FullSourceContext = source + member __.SourceContext = pos.ShowInSource(source) + +[] +module SourceInfoModule = + let inline catchSource fileName source f = + try f() + with + | :? SourceInfoException as exn -> + raise (SourceException(exn.Message, exn.SourceInfo, source, fileName)) \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/TSQL.fs b/Rezoom.SQL.Compiler/TSQL.fs new file mode 100644 index 0000000..447c619 --- /dev/null +++ b/Rezoom.SQL.Compiler/TSQL.fs @@ -0,0 +1,472 @@ +namespace Rezoom.SQL.Compiler.TSQL +open System +open System.Data +open System.Data.Common +open System.Collections.Generic +open System.Globalization +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.BackendUtilities +open Rezoom.SQL.Compiler.Translators +open Rezoom.SQL.Mapping + +module private TSQLFunctions = + open Rezoom.SQL.Compiler.FunctionDeclarations + type CustomTranslator = ExprTranslator -> TFunctionInvocationExpr -> Fragments + let private noArgProc name ret = + proc name [] ret, Some <| fun _ _ -> [| text <| name.ToUpperInvariant() |] :> _ seq + let private atAtProc name ret = + proc name [] ret, Some <| fun _ _ -> [| text <| "@@" + name.ToUpperInvariant() |] :> _ seq + let private datePartWhitelist = + [| "year"; "yy"; "yyyy" + "quarter"; "qq"; "q" + "month"; "mm"; "m" + "dayofyear"; "dy"; "y" + "day"; "dd"; "d" + "week"; "wk"; "ww" + "weekday"; "dw" + "hour"; "hh" + "minute"; "mi"; "n" + "second"; "ss"; "s" + "millisecond"; "ms" + "microsecond"; "mcs" + "nanosecond"; "ns" + "tzoffset"; "tz" + "iso_week"; "isowk"; "isoww" + |] |> fun arr -> HashSet(arr, StringComparer.OrdinalIgnoreCase) + let private datePartFunc name otherArgs ret = + func name (string :: otherArgs) ret, + Some <| fun (expr : ExprTranslator) (invoc : TFunctionInvocationExpr) -> + seq { + yield text invoc.FunctionName.Value + yield text "(" + match invoc.Arguments with + | ArgumentList (None, args) when args.Length > 0 -> + match args.[0] with + | { Value = LiteralExpr (StringLiteral lit) } -> + if datePartWhitelist.Contains(lit) then + yield text lit + else + failAt args.[0].Source <| + sprintf "DATEPART argument must be one of %A" (List.ofSeq datePartWhitelist) + | _ -> + failAt args.[0].Source "DATEPART argument must be a string literal" + for i = 1 to args.Length - 1 do + yield text "," + yield! expr.Expr(args.[i], FirstClassValue) + | _ -> bug "Can't use datePartFunc with no args" + yield text ")" + } + let iifCustom = + func "iif" [ boolean; infect a'; infect a' ] a', + Some <| fun (expr : ExprTranslator) (invoc : TFunctionInvocationExpr) -> + match invoc.Arguments with + | ArgumentList (None, [| cond; ifTrue; ifFalse |]) -> + [| yield text "IIF(" + yield! expr.Expr(cond, Predicate) + yield text "," + yield! expr.Expr(ifTrue, FirstClassValue) + yield text "," + yield! expr.Expr(ifFalse, FirstClassValue) + yield text ")" + |] :> _ seq + | _ -> bug "Impossible arguments to iif" + let private aggregate name args ret = aggregate name args ret, None + let private aggregateW name args ret = aggregateW name args ret, None + let private func name args ret = func name args ret, None + let private proc name args ret = proc name args ret, None + let private i = integral + let private ii = infect i + let private date = datetime + let private specialFunctions = Dictionary() + let private addCustom (funcType : FunctionType, custom) = + match custom with + | None -> funcType + | Some custom -> + specialFunctions.[funcType.FunctionName] <- custom + funcType + let getCustom (funcName : Name) = + let succ, value = specialFunctions.TryGetValue(funcName) + if succ then Some value else None + let functions = + [| // aggregate functions + aggregate "avg" [ numeric a' ] (nullable a') + aggregateW "count" [ scalar ] int32 + aggregateW "count_big" [ scalar ] int64 + aggregate "grouping" [ scalar ] int8 + aggregate "grouping_id" [ vararg scalar ] int32 + aggregate "max" [ a' ] (nullable a') + aggregate "min" [ a' ] (nullable a') + aggregate "sum" [ numeric a' ] a' + aggregate "stdev" [ numeric scalar ] (nullable float64) + aggregate "stdevp" [ numeric scalar ] (nullable float64) + aggregate "var" [ numeric scalar ] (nullable float64) + aggregate "varp" [ numeric scalar ] (nullable float64) + // @@FUNCTIONNAME builtins + atAtProc "datefirst" int8 + atAtProc "dbts" binary + atAtProc "langid" int8 + atAtProc "language" string + atAtProc "lock_timeout" int32 + atAtProc "max_connections" int32 + atAtProc "max_precision" int8 + atAtProc "nestlevel" int32 + atAtProc "options" int32 + atAtProc "remserver" string + atAtProc "servername" string + atAtProc "servicename" string + atAtProc "spid" int8 + atAtProc "textsize" int32 + atAtProc "version" string + atAtProc "cursor_rows" int32 + atAtProc "fetch_status" int32 + atAtProc "identity" i + // identity + proc "scope_identity" [] i + // date/time functions from https://msdn.microsoft.com/en-us/library/ms186724.aspx + noArgProc "current_timestamp" datetime + proc "sysdatetime" [] datetime + proc "sysdatetimeoffset" [] datetimeoffset + proc "sysutcdatetime" [] datetime + proc "getdate" [] datetime + proc "getutcdate" [] datetime + datePartFunc "datename" [ infect datetime ] string + datePartFunc "dateadd" [ infect datetime ] string + datePartFunc "datediff" [ infect datetime; infect datetime ] int32 + datePartFunc "datediff_big" [ infect datetime; infect datetime ] int64 + datePartFunc "dateadd" [ infect i; infect datetime ] datetime + func "day" [ infect datetime ] i + func "month" [ infect datetime ] i + func "year" [ infect datetime ] i + func "datefromparts" [ ii; ii; ii ] date + func "datetime2fromparts" [ ii; ii; ii; ii; ii; ii; ii; ii ] datetime + func "datetimefromparts" [ ii; ii; ii; ii; ii; ii; ii ] datetime + func "datetimeoffsetfromparts" [ ii; ii; ii; ii; ii; ii; ii; ii; ii; ii ] datetimeoffset + func "smalldatetimefromparts" [ ii; ii; ii; ii; ii ] datetime + // math funcs from https://msdn.microsoft.com/en-us/library/ms177516.aspx + func "acos" [ infect fractional ] float64 + func "asin" [ infect fractional ] float64 + func "atan" [ infect fractional ] float64 + func "atn2" [ infect fractional; infect fractional ] float64 + func "ceiling" [ infect (numeric a') ] a' + func "cos" [ infect fractional] float64 + func "cot" [ infect fractional ] float64 + func "degrees" [ infect (numeric a') ] a' + func "exp" [ infect fractional ] float64 + func "floor" [ infect (numeric a') ] a' + func "log" [ infect num; infect (optional i) ] float64 + func "log10" [ infect num ] float64 + func "pi" [] float64 + func "power" [ infect (numeric a'); infect num ] a' + func "radians" [ infect (numeric a') ] a' + func "rand" [ infect (optional i) ] float64 + func "round" [ infect (numeric a'); infect i ] a' + func "sign" [ infect (numeric a') ] a' + func "sin" [ infect fractional ] float64 + func "sqrt" [ infect (numeric a') ] float64 + func "square" [ infect (numeric a') ] float64 + func "tan" [ infect fractional ] float64 + // JSON functions from https://msdn.microsoft.com/en-us/library/dn921900.aspx + func "isjson" [ infect string ] boolean + func "json_value" [ infect string; infect string ] string + func "json_query" [ infect string; infect string ] string + func "json_modify" [ infect string; infect string; infect string ] string + // logical funcs from https://msdn.microsoft.com/en-us/library/hh213226.aspx + func "choose" [ infect i; vararg (infect a') ] a' + iifCustom + // skip over "metadata functions" (for now) from https://msdn.microsoft.com/en-us/library/ms187812.aspx + // ... + // also "security functions" (for now) from https://msdn.microsoft.com/en-us/library/ms186236.aspx + // ... + // so onto string functions from https://msdn.microsoft.com/en-us/library/ms181984.aspx + func "ascii" [ infect string ] int32 + func "concat" [ string; string; vararg string ] string + func "format" [ infect scalar; infect string; optional (infect string) ] string + func "lower" [ infect string ] string + func "upper" [ infect string ] string + func "patindex" [ infect string; infect string ] integral + func "replicate" [ infect string; infect integral ] string + func "rtrim" [ infect string ] string + func "ltrim" [ infect string ] string + func "str" [ infect fractional; varargN 2 integral ] string + // func "string_split" [ infect string; infect string ] string_table // wait till we can do TVFs + func "translate" [ infect string; infect string; infect string ] string + func "char" [ infect integral ] string + func "concat_ws" [ infect string; scalar; scalar; vararg scalar ] string + func "left" [ infect string; infect integral ] string + func "right" [ infect string; infect integral ] string + func "quotename" [ infect string; optional (infect string) ] string + func "reverse" [ infect string ] string + func "soundex" [ infect string ] string + // func "string_agg" // wtf, how do we support this? it has its own special clause type... + func "stuff" [ infect (a' |> constrained StringishTypeClass); infect integral; infect integral; string ] a' + func "trim" [ infect string ] string // come on TSQL, "characters from"? cut it out... + func "charindex" [ infect string; infect string ; optional integral ] integral + func "difference" [ infect string; infect string ] int32 + func "len" [ infect string ] integral + func "nchar" [ infect integral ] string + func "replace" [ infect string; infect string; infect string ] string + func "space" [ infect integral ] string + func "string_escape" [ infect string; infect string ] string // TODO: enforce literal on 2nd arg? + func "substring" [ infect a' |> constrained StringishTypeClass; infect integral; infect integral ] a' + func "unicode" [ infect string ] int32 + // missing: system functions, system statistical functions, text and image functions + |] |> Array.map addCustom |> DefaultFunctions.extendedBy + +type private TSQLLiteral() = + inherit DefaultLiteralTranslator() + override __.BooleanLiteral(t) = CommandText <| if t then "1" else "0" + override __.BlobLiteral(bytes) = + let hexPairs = bytes |> Array.map (fun b -> b.ToString("X2", CultureInfo.InvariantCulture)) + "0x" + String.Concat(hexPairs) |> text + +type private TSQLExpression(statement : StatementTranslator, indexer) = + inherit DefaultExprTranslator(statement, indexer) + let literal = TSQLLiteral() + override __.Literal = upcast literal + override __.Name(name) = + "[" + name.Value.Replace("]", "]]") + "]" + |> text + override __.TypeName(name) = + (Seq.singleton << text) <| + match name with + | BooleanTypeName -> "BIT" + | IntegerTypeName Integer8 -> "TINYINT" + | IntegerTypeName Integer16 -> "SMALLINT" + | IntegerTypeName Integer32 -> "INT" + | IntegerTypeName Integer64 -> "BIGINT" + | FloatTypeName Float32 -> "FLOAT(24)" + | FloatTypeName Float64 -> "FLOAT(53)" + | StringTypeName(len) -> "NVARCHAR(" + string len + ")" + | BinaryTypeName(len) -> "VARBINARY(" + string len + ")" + | DecimalTypeName -> "NUMERIC(38, 19)" + | DateTimeTypeName -> "DATETIME2" + | DateTimeOffsetTypeName -> "DATETIMEOFFSET" + override __.BinaryOperator(op) = + CommandText <| + match op with + | Concatenate -> "+" + | Multiply -> "*" + | Divide -> "/" + | Modulo -> "%" + | Add -> "+" + | Subtract -> "-" + | BitAnd -> "&" + | BitOr -> "|" + | LessThan -> "<" + | LessThanOrEqual -> "<=" + | GreaterThan -> ">" + | GreaterThanOrEqual -> ">=" + | Equal -> "=" + | NotEqual -> "<>" + | And -> "AND" + | Or -> "OR" + | Is + | IsNot -> bug "should have been handled for TSQL before we got here" + | BitShiftLeft + | BitShiftRight -> failwithf "Not supported by TSQL: %A" op + override this.Binary(bin) = + match bin.Operator, bin.Right.Value with + | Is, LiteralExpr NullLiteral + | IsNot, LiteralExpr NullLiteral -> + seq { + yield! this.Expr(bin.Left, FirstClassValue) + yield ws + yield text "IS" + yield ws + if bin.Operator = IsNot then + yield text "NOT" + yield ws + yield text "NULL" + } + | Is, _ + | IsNot, _ -> + seq { + if bin.Operator = IsNot then + yield text "NOT" + yield ws + yield text "EXISTS(SELECT" + yield ws + yield! this.Expr(bin.Left, FirstClassValue) + yield ws + yield text "INTERSECT SELECT" + yield ws + yield! this.Expr(bin.Right, FirstClassValue) + yield text ")" + } + | _ -> base.Binary(bin) + override __.UnaryOperator(op) = + CommandText <| + match op with + | Negative -> "-" + | Not -> "NOT" + | NotNull -> "IS NOT NULL" + | IsNull -> "IS NULL" + | BitNot -> "~" + override __.SimilarityOperator(op) = + CommandText <| + match op with + | Like -> "LIKE" + | Glob + | Match + | Regexp -> failwithf "Not supported by TSQL: %A" op + /// Identifies expressions that are set up to use as predicates in T-SQL. + /// These expressions don't produce actual values. + /// For example, you can't `SELECT 1=1`, but you can do `SELECT 1 WHERE 1=1`. + /// Conversely, you can't `SELECT 1 WHERE tbl.BitColumn`, but you can do `SELECT tbl.BitColumn`. + static member private IsPredicateBoolean(expr : TExpr) = + expr.Info.Type.Type = BooleanType + && match expr.Value with + | SimilarityExpr _ + | BetweenExpr _ + | InExpr _ + | ExistsExpr _ + | BinaryExpr _ + | UnaryExpr _ -> true + | _ -> false + member private __.BaseExpr(expr, context) = base.Expr(expr, context) + override this.Expr(expr, context) = + match context with + | FirstClassValue -> + if TSQLExpression.IsPredicateBoolean(expr) then + seq { + yield text "CAST((CASE WHEN" + yield ws + yield! this.BaseExpr(expr, Predicate) + yield ws + yield text "THEN 1 ELSE 0 END) AS BIT)" + } + else + base.Expr(expr, context) + | Predicate -> + if TSQLExpression.IsPredicateBoolean(expr) then + base.Expr(expr, context) + else + seq { + yield text "((" + yield! this.BaseExpr(expr, FirstClassValue) + yield text ")<>0)" + } + override this.Invoke(func) = + match TSQLFunctions.getCustom func.FunctionName with + | Some custom -> custom (this :> ExprTranslator) func + | None -> base.Invoke(func) + +type private TSQLStatement(indexer : IParameterIndexer) as this = + inherit DefaultStatementTranslator(Name("TSQL"), indexer) + let expr = TSQLExpression(this :> StatementTranslator, indexer) + override __.Expr = upcast expr + override __.ColumnsNullableByDefault = true + member this.SelectCoreWithTop(select : TSelectCore, top) = + seq { + yield text "SELECT" + yield ws + match top with + | None -> () + | Some top -> + yield text "TOP" + yield ws + yield! this.FirstClassValue(top) + yield ws + yield! this.ResultColumns(select.Columns) + match select.From with + | None -> () + | Some from -> + yield ws + yield text "FROM" + yield ws + yield! this.TableExpr(from) + match select.Where with + | None -> () + | Some where -> + yield ws + yield text "WHERE" + yield ws + yield! this.Predicate(where) + match select.GroupBy with + | None -> () + | Some groupBy -> + yield ws + yield text "GROUP BY" + yield ws + yield! groupBy.By |> Seq.map this.FirstClassValue |> join "," + match groupBy.Having with + | None -> () + | Some having -> + yield ws + yield text "HAVING" + yield ws + yield! this.Predicate(having) + } + override this.SelectCore(select) = this.SelectCoreWithTop(select, None) + override this.Select(select) = + match select.Value.Limit with + | None -> base.Select(select) + | Some limit -> + // TSQL doesn't exactly support LIMIT so what shall we do? + match limit.Offset, select.Value.Compound.Value with + | None, CompoundTerm { Value = Select core } -> + // We can use TOP here + this.SelectCoreWithTop(core, Some limit.Limit) + | _ -> + this.Select(select) // Our override of LIMIT will turn this into an offset/fetch clause + override this.Limit(limit) = + seq { + yield text "OFFSET" + yield ws + match limit.Offset with + | Some offset -> + yield! this.FirstClassValue(offset) + | None -> + yield text "0" + yield ws + yield text "ROWS FETCH NEXT" + yield ws + yield! this.FirstClassValue(limit.Limit) + yield ws + yield text "ROWS ONLY" + } + override this.AutoIncrement = "IDENTITY(1,1)" + +type TSQLMigrationBackend(conn : DbConnection) = + inherit DefaultMigrationBackend(conn) + override __.Initialize() = + use cmd = conn.CreateCommand() + cmd.CommandText <- + """ + IF NOT EXISTS ( + SELECT * FROM sys.tables t + JOIN sys.schemas s ON t.schema_id = s.schema_id + WHERE s.name = 'dbo' and t.name = '__RZSQL_MIGRATIONS' + ) + CREATE TABLE __RZSQL_MIGRATIONS + ( MajorVersion int + , Name varchar(256) + , UNIQUE (MajorVersion, Name) + ); + """ + ignore <| cmd.ExecuteNonQuery() + +type TSQLBackend() = + static let initialModel = + let main, temp = Name("dbo"), Name("temp") + { Schemas = + [ Schema.Empty(main) + Schema.Empty(temp) + ] |> List.map (fun s -> s.SchemaName, s) |> Map.ofList + DefaultSchema = main + TemporarySchema = temp + Builtin = + { Functions = TSQLFunctions.functions + } + } + interface IBackend with + member this.MigrationBackend = <@ fun conn -> TSQLMigrationBackend(conn) :> Migrations.IMigrationBackend @> + member this.InitialModel = initialModel + member this.ParameterTransform(columnType) = ParameterTransform.Default(columnType) + member this.ToCommandFragments(indexer, stmts) = + let translator = TSQLStatement(indexer) + translator.TotalStatements(stmts) + |> BackendUtilities.simplifyFragments + |> ResizeArray + :> _ IReadOnlyList + \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/Translators.fs b/Rezoom.SQL.Compiler/Translators.fs new file mode 100644 index 0000000..d703a96 --- /dev/null +++ b/Rezoom.SQL.Compiler/Translators.fs @@ -0,0 +1,101 @@ +namespace Rezoom.SQL.Compiler.Translators +open System +open System.Data +open System.Collections.Generic +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping +open Rezoom.SQL.Compiler.BackendUtilities + +// MSSQL doesn't treat booleans as first-class values, and we don't want to have to rewrite the +// entire statement translator for it, so we pass this context around to hint to the ExprTranslator +// that it may need to fudge in a "CASE WHEN expr THEN 1 ELSE 0 END" to get a usable value. +type ExprTranslationContext = + /// The expression is expected to produce a first-class value + /// that can be passed to functions, returned from a select, etc. + | FirstClassValue + /// The expression is expected to produce a value suitable for a predicate like a "WHERE" clause or + /// condition within a "CASE" expression. + | Predicate + +[] +type LiteralTranslator() = + abstract member NullLiteral : Fragment + abstract member BooleanLiteral : t : bool -> Fragment + abstract member StringLiteral : str : string -> Fragment + abstract member BlobLiteral : bytes : byte array -> Fragment + abstract member IntegerLiteral : i : uint64 -> Fragment + abstract member FloatLiteral : f : float -> Fragment + abstract member DateTimeLiteral : dt : DateTime -> Fragment + abstract member DateTimeOffsetLiteral : dt : DateTimeOffset -> Fragment + abstract member Literal : literal : Literal -> Fragment + abstract member SignedLiteral : literal : SignedNumericLiteral -> Fragments + +[] +type StatementTranslator() = + abstract member Expr : ExprTranslator + abstract member OrderDirection : OrderDirection -> Fragment + abstract member IndexHint : IndexHint -> Fragments + abstract member CTE : cte : TCommonTableExpression -> Fragments + abstract member With : withClause : TWithClause -> Fragments + abstract member Values : vals : TExpr array WithSource array -> Fragments + abstract member ResultColumn : expr : TExpr * alias : Alias -> Fragments + abstract member ResultColumns : TResultColumns -> Fragments + abstract member TableOrSubquery : TTableOrSubquery -> Fragments + abstract member TableExpr : TTableExpr -> Fragments + abstract member JoinType : JoinType -> Fragment + abstract member Join : TJoin -> Fragments + abstract member SelectCore : select : TSelectCore -> Fragments + abstract member CompoundTerm : compound : TCompoundTermCore -> Fragments + abstract member Compound : compound : TCompoundExprCore -> Fragments + abstract member Limit : TLimit -> Fragments + abstract member OrderingTerm : TOrderingTerm -> Fragments + abstract member Select : select : TSelectStmt -> Fragments + abstract member ForeignKeyRule : rule : ForeignKeyRule -> Fragments + abstract member ForeignKeyClause : clause : TForeignKeyClause -> Fragments + abstract member ColumnConstraint : constr : TColumnConstraint -> Fragments + abstract member ColumnDefinition : col : TColumnDef -> Fragments + abstract member CreateTableDefinition : create : TCreateTableDefinition -> Fragments + abstract member CreateTable : create : TCreateTableStmt -> Fragments + abstract member AlterTable : alter : TAlterTableStmt -> Fragments + abstract member CreateView : create : TCreateViewStmt -> Fragments + abstract member CreateIndex : create : TCreateIndexStmt -> Fragments + abstract member DropObject : drop : TDropObjectStmt -> Fragments + abstract member Insert : insert : TInsertStmt -> Fragments + abstract member Update : update : TUpdateStmt -> Fragments + abstract member Delete : delete : TDeleteStmt -> Fragments + abstract member Begin : Fragments + abstract member Commit : Fragments + abstract member Rollback : Fragments + abstract member Statement : TStmt -> Fragments + abstract member Statements : TStmt seq -> Fragments + abstract member Vendor : TVendorStmt -> Fragments + abstract member TotalStatement : TTotalStmt -> Fragments + abstract member TotalStatements : TTotalStmt seq -> Fragments + +and [] ExprTranslator(statement : StatementTranslator, indexer : IParameterIndexer) = + abstract member Literal : LiteralTranslator + abstract member Name : name : Name -> Fragment + abstract member BinaryOperator : op : BinaryOperator -> Fragment + abstract member UnaryOperator : op : UnaryOperator -> Fragment + abstract member SimilarityOperator : op : SimilarityOperator -> Fragment + abstract member BindParameter : par : BindParameter -> Fragment + abstract member ObjectName : name : TObjectName -> Fragments + abstract member ColumnName : column : TColumnName -> Fragments + abstract member TypeName : TypeName -> Fragments + abstract member Cast : castExpr : TCastExpr -> Fragments + abstract member Collate : expr : TExpr * collation : Name -> Fragments + abstract member Invoke : func : TFunctionInvocationExpr -> Fragments + abstract member Similarity : sim : TSimilarityExpr -> Fragments + abstract member Binary : bin : TBinaryExpr -> Fragments + abstract member Unary : un : TUnaryExpr -> Fragments + abstract member Between : between : TBetweenExpr -> Fragments + abstract member Table : TTableInvocation -> Fragments + abstract member In : inex : TInExpr -> Fragments + abstract member Case : case : TCaseExpr -> Fragments + abstract member Raise : raise : Raise -> Fragments + abstract member Exists : subquery : TSelectStmt -> Fragments + abstract member ScalarSubquery : subquery : TSelectStmt -> Fragments + abstract member NeedsParens : TExprType -> bool + abstract member Expr : expr : TExpr * context : ExprTranslationContext -> Fragments + + diff --git a/Rezoom.SQL.Compiler/TypeChecker.fs b/Rezoom.SQL.Compiler/TypeChecker.fs new file mode 100644 index 0000000..13474b6 --- /dev/null +++ b/Rezoom.SQL.Compiler/TypeChecker.fs @@ -0,0 +1,602 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic +open Rezoom.SQL.Compiler.InferredTypes + +type InferredQueryShape = InferredType QueryExprInfo +type SelfQueryShape = + // this thing is for when we know ahead of time what the column names of a select statement are supposed to be + // so we don't want to require that they all be aliased manually. + { CTEName : Name option + KnownShape : InferredQueryShape option + } + static member Known(known) = { CTEName = None; KnownShape = known } + static member Known(known) = SelfQueryShape.Known(Some known) + static member Unknown = { CTEName = None; KnownShape = None } + +type private TypeChecker(cxt : ITypeInferenceContext, scope : InferredSelectScope) as this = + let exprChecker = ExprTypeChecker(cxt, scope, this) + member this.ObjectName(name) = exprChecker.ObjectName(name) + member this.ObjectName(name, allowNotFound) = exprChecker.ObjectName(name, allowNotFound) + member this.Expr(expr, knownType) = exprChecker.Expr(expr, knownType) + member this.Expr(expr) = exprChecker.Expr(expr) + member this.Scope = scope + member this.WithScope(scope) = TypeChecker(cxt, scope) + + member private this.TableOrSubqueryScope(tsub : TableOrSubquery) = + match tsub.Table with + | Table (tinvoc, index) -> + tsub.Alias |? tinvoc.Table.ObjectName, this.ObjectName(tinvoc.Table).Info + | Subquery select -> + match tsub.Alias with + | None -> failAt select.Source Error.subqueryRequiresAnAlias + | Some alias -> alias, this.Select(select, SelfQueryShape.Unknown).Value.Info + + member private this.TableExprScope + (dict : Dictionary, texpr : TableExpr, outerDepth) = + let add name objectInfo = + if dict.ContainsKey(name) then + failAt texpr.Source <| Error.tableNameAlreadyInScope name + else + dict.Add(name, objectInfo) + match texpr.Value with + | TableOrSubquery tsub -> + let alias, objectInfo = this.TableOrSubqueryScope(tsub) + let objectInfo = + if outerDepth > 0 then + let nullable = NullableDueToJoin |> Seq.replicate outerDepth |> Seq.reduce (>>) + objectInfo.Map(fun t -> { t with InferredNullable = nullable t.InferredNullable }) + else objectInfo + add alias objectInfo + outerDepth + | Join join -> + let leftDepth = this.TableExprScope(dict, join.LeftTable, outerDepth) + let depthIncrement = if join.JoinType.IsOuter then 1 else 0 + this.TableExprScope(dict, join.RightTable, leftDepth + depthIncrement) + + member private this.TableExprScope(texpr : TableExpr) = + let dict = Dictionary() + ignore <| this.TableExprScope(dict, texpr, outerDepth = 0) + { FromVariables = dict } + + member private this.TableOrSubquery(tsub : TableOrSubquery) = + let tbl, info = + match tsub.Table with + | Table (tinvoc, index) -> + let invoke = exprChecker.TableInvocation(tinvoc) + Table (invoke, index), invoke.Table.Info + | Subquery select -> + let select = this.Select(select, SelfQueryShape.Unknown) + Subquery select, select.Value.Info + { Table = tbl + Alias = tsub.Alias + Info = info + } + + member private this.TableExpr(constraintChecker : TypeChecker, texpr : TableExpr) = + { TableExpr.Source = texpr.Source + Value = + match texpr.Value with + | TableOrSubquery tsub -> TableOrSubquery <| this.TableOrSubquery(tsub) + | Join join -> + { JoinType = join.JoinType + LeftTable = this.TableExpr(constraintChecker, join.LeftTable) + RightTable = this.TableExpr(constraintChecker, join.RightTable) + Constraint = + match join.Constraint with + | JoinOn e -> constraintChecker.Expr(e, BooleanType) |> JoinOn + | JoinUnconstrained -> JoinUnconstrained + } |> Join + } + + member this.TableExpr(texpr : TableExpr) = + let checker = TypeChecker(cxt, { scope with FromClause = Some <| this.TableExprScope(texpr) }) + checker, this.TableExpr(checker, texpr) + + member this.ResultColumn(aliasPrefix : Name option, resultColumn : ResultColumn) = + let qualify (tableAlias : Name) fromTable (col : _ ColumnExprInfo) = + { Expr.Source = resultColumn.Source + Value = + { ColumnName = col.ColumnName + Table = + { Source = resultColumn.Source + ObjectName = tableAlias + SchemaName = None + Info = fromTable + } |> Some + } |> ColumnNameExpr + Info = col.Expr.Info + }, + match aliasPrefix with + | None -> None + | Some prefix -> Some (prefix + col.ColumnName) + match resultColumn.Case with + | ColumnsWildcard -> + match scope.FromClause with + | None -> failAt resultColumn.Source Error.wildcardWithoutFromClause + | Some from -> + seq { + for KeyValue(tableAlias, fromTable) in from.FromVariables do + for col in fromTable.Table.Query.Columns do + yield qualify tableAlias fromTable col + } + | TableColumnsWildcard tbl -> + match scope.FromClause with + | None -> failAt resultColumn.Source <| Error.tableWildcardWithoutFromClause tbl + | Some from -> + let succ, fromTable = from.FromVariables.TryGetValue(tbl) + if not succ then failAt resultColumn.Source <| Error.noSuchTableInFrom tbl + fromTable.Table.Query.Columns |> Seq.map (qualify tbl fromTable) + | Column (expr, alias) -> + match aliasPrefix with + | None -> (this.Expr(expr), alias) |> Seq.singleton + | Some prefix -> + let expr = this.Expr(expr) + match implicitAlias (expr.Value, alias) with + | None -> (expr, None) |> Seq.singleton + | Some a -> (expr, Some (prefix + a)) |> Seq.singleton + | ColumnNav nav -> + this.ColumnNav(aliasPrefix, resultColumn, nav) + + member this.ColumnNav(aliasPrefix : Name option, resultColumn : ResultColumn, nav : ResultColumnNav) = + let subAliasPrefix = + let prev = + match aliasPrefix with + | Some prefix -> prefix.Value + | None -> "" + Some <| Name(prev + nav.Name.Value + nav.Cardinality.Separator) + let columns = + seq { + for column in nav.Columns do + let producedColumns = this.ResultColumn(subAliasPrefix, column) + yield column, producedColumns |> ResizeArray + } |> ResizeArray + let keyExprs = + seq { + for source, producedColumns in columns do + match source.Case with + | ColumnNav _ -> () // ignore sub-nav props + | _ -> + for expr, _ in producedColumns do + if expr.Info.PrimaryKey then yield expr + } |> ResizeArray + if keyExprs.Count <= 0 then + failAt resultColumn.Source <| Error.navPropertyMissingKeys nav.Name + else + let minDepthOfImmediateKey = + keyExprs + |> Seq.map (fun e -> e.Info.Type.InferredNullable.JoinInducedNullabilityDepth()) + |> Seq.min + columns + |> Seq.collect snd + |> Seq.map (fun (expr, alias) -> // remove nullability introduced by outer joins + { expr with + Info = { expr.Info with Type = expr.Info.Type.StripNullDueToJoin(minDepthOfImmediateKey) } + }, alias) + + member this.ResultColumns(resultColumns : ResultColumns, knownShape : InferredQueryShape option) = + let columns = + resultColumns.Columns + |> Seq.collect + (fun rc -> + this.ResultColumn(None, rc) + |> Seq.map (fun (expr, alias) -> { Source = rc.Source; Case = Column (expr, alias); })) + |> Seq.toArray + match knownShape with + | Some shape -> + if columns.Length <> shape.Columns.Count then + if columns.Length <= 0 then failwith "BUG: impossible, parser shouldn't have accepted this" + let source = columns.[columns.Length - 1].Source + failAt source <| Error.expectedKnownColumnCount columns.Length shape.Columns.Count + for i = 0 to columns.Length - 1 do + let selected, alias as selectedCol = columns.[i].Case.AssumeColumn() + let shape = shape.Columns.[i] + cxt.UnifyLeftKnown(selected.Source, shape.Expr.Info.Type, selected.Info.Type) + match implicitAlias (selected.Value, alias) with + | Some a when a = shape.ColumnName -> () + | _ -> + columns.[i] <- { columns.[i] with Case = Column(selected, Some shape.ColumnName) } + | None -> + for column in columns do + let selected, _ = column.Case.AssumeColumn() + ignore <| cxt.Unify(selected.Source, selected.Info.Type.InferredType, TypeKnown ScalarTypeClass) + { Distinct = resultColumns.Distinct + Columns = columns + } + + member this.GroupBy(groupBy : GroupBy) = + { By = groupBy.By |> rmap this.Expr + Having = groupBy.Having |> Option.map this.Expr + } + + member this.SelectCore(select : SelectCore, knownShape : InferredQueryShape option) = + let checker, from = + match select.From with + | None -> this, None + | Some from -> + let checker, texpr = this.TableExpr(from) + checker, Some texpr + let columns = checker.ResultColumns(select.Columns, knownShape) + let infoColumns = + seq { + for column in columns.Columns do + match column.Case with + | Column (expr, alias) -> + yield + { Expr = expr + FromAlias = None + ColumnName = + match implicitAlias (expr.Value, alias) with + | None -> failAt column.Source Error.expressionRequiresAlias + | Some alias -> alias + } + // typechecker should've eliminated alternatives + | _ -> bug "All wildcards must be expanded -- this is a typechecker bug" + } |> toReadOnlyList + { Columns = columns + From = from + Where = Option.map checker.Expr select.Where + GroupBy = Option.map checker.GroupBy select.GroupBy + Info = + { Table = SelectResults + Query = { Columns = infoColumns } + } |> TableLike + } |> AggregateChecker.check + + member this.CTE(cte : CommonTableExpression) = + let knownShape = cte.ColumnNames |> Option.map (fun n -> cxt.AnonymousQueryInfo(n.Value)) + let select = this.Select(cte.AsSelect, { KnownShape = knownShape; CTEName = Some cte.Name }) + { Name = cte.Name + ColumnNames = cte.ColumnNames + AsSelect = select + Info = select.Value.Info + } + + member this.WithClause(withClause : WithClause) = + let mutable scope = scope + let clause = + { Recursive = withClause.Recursive + Tables = + [| for cte in withClause.Tables -> + let cte = TypeChecker(cxt, scope).CTE(cte) + scope <- + { scope with + CTEVariables = scope.CTEVariables |> Map.add cte.Name cte.Info.Table.Query + } + cte + |] + } + TypeChecker(cxt, scope), clause + + member this.OrderingTerm(orderingTerm : OrderingTerm) = + { By = this.Expr(orderingTerm.By) + Direction = orderingTerm.Direction + } + + member this.Limit(limit : Limit) = + { Limit = this.Expr(limit.Limit, IntegerType Integer64) + Offset = limit.Offset |> Option.map (fun e -> this.Expr(e, IntegerType Integer64)) + } + + member this.CompoundTerm(term : CompoundTerm, knownShape : InferredQueryShape option) : InfCompoundTerm = + let info, value = + match term.Value, knownShape with + | Values vals, Some shape -> + let vals = vals |> rmap (fun w -> { WithSource.Value = rmap this.Expr w.Value; Source = w.Source }) + let columns = + seq { + for rowIndex, row in vals |> Seq.indexed do + if row.Value.Length <> shape.Columns.Count then + failAt row.Source <| Error.expectedKnownColumnCount row.Value.Length shape.Columns.Count + for colVal, colShape in Seq.zip row.Value shape.Columns do + cxt.UnifyLeftKnown(row.Source, colShape.Expr.Info.Type, colVal.Info.Type) + if rowIndex > 0 then () else + yield + { Expr = colVal + FromAlias = None + ColumnName = colShape.ColumnName + } + } |> toReadOnlyList + TableLike + { Table = CompoundTermResults + Query = { Columns = columns } + }, Values vals + | Values vals, None -> + failAt term.Source Error.valuesRequiresKnownShape + | Select select, knownShape -> + let select = this.SelectCore(select, knownShape) + select.Info, Select select + { Source = term.Source + Value = value + Info = info + } + + member this.Compound(compound : CompoundExpr, knownShape : InferredQueryShape option) : InfCompoundExpr = + let nested leftCompound rightTerm = + match knownShape with + | Some _ as shape -> + this.Compound(leftCompound, knownShape), this.CompoundTerm(rightTerm, knownShape) + | None -> + let leftCompound = this.Compound(leftCompound, None) + leftCompound, this.CompoundTerm(rightTerm, Some leftCompound.Value.Info.Query) + { CompoundExpr.Source = compound.Source + Value = + match compound.Value with + | CompoundTerm term -> CompoundTerm <| this.CompoundTerm(term, knownShape) + | Union (expr, term) -> Union <| nested expr term + | UnionAll (expr, term) -> UnionAll <| nested expr term + | Intersect (expr, term) -> Intersect <| nested expr term + | Except (expr, term) -> Except <| nested expr term + } + + member this.CompoundTop(compound : CompoundExpr, selfShape : SelfQueryShape) : InfCompoundExpr = + match selfShape.CTEName with + | None -> this.Compound(compound, selfShape.KnownShape) + | Some cteName -> + let nested leftCompound recursiveFinalTerm = + let leftCompound = this.Compound(leftCompound, selfShape.KnownShape) + let leftQuery = leftCompound.Value.Info.Query + let rightChecker = + { scope with + CTEVariables = scope.CTEVariables |> Map.add cteName leftQuery + } |> this.WithScope + leftCompound, rightChecker.CompoundTerm(recursiveFinalTerm, Some leftQuery) + { CompoundExpr.Source = compound.Source + Value = + match compound.Value with + | CompoundTerm term -> CompoundTerm <| this.CompoundTerm(term, selfShape.KnownShape) + | Union (expr, term) -> Union <| nested expr term + | UnionAll (expr, term) -> UnionAll <| nested expr term + | Intersect (expr, term) -> Intersect <| nested expr term + | Except (expr, term) -> Except <| nested expr term + } + + member this.Select(select : SelectStmt, selfShape : SelfQueryShape) : InfSelectStmt = + { Source = select.Source + Value = + let select = select.Value + let checker, withClause = + match select.With with + | None -> this, None + | Some withClause -> + let checker, withClause = this.WithClause(withClause) + checker, Some withClause + let compound = checker.CompoundTop(select.Compound, selfShape) + { With = withClause + Compound = compound + OrderBy = Option.map (rmap checker.OrderingTerm) select.OrderBy + Limit = Option.map checker.Limit select.Limit + Info = compound.Value.Info + } + } + + member this.ForeignKey(foreignKey, creating : CreateTableStmt option) = + let referencesTable, columnNames = + match creating with + | Some tbl when tbl.Name = foreignKey.ReferencesTable -> // self-referencing + this.ObjectName(foreignKey.ReferencesTable, allowNotFound = true), + match tbl.As with + | CreateAsDefinition cdef -> cdef.Columns |> Seq.map (fun c -> c.Name) + | CreateAsSelect _ -> bug "Self-referencing constraints can't exist in a CREATE AS SELECT" + | _ -> + let name = this.ObjectName(foreignKey.ReferencesTable) + name, name.Info.Query.Columns |> Seq.map (fun c -> c.ColumnName) + for { Source = source; Value = referenceName } in foreignKey.ReferencesColumns do + if not (Seq.contains referenceName columnNames) then + failAt source <| Error.noSuchColumn referenceName + { ReferencesTable = referencesTable + ReferencesColumns = foreignKey.ReferencesColumns + Rules = foreignKey.Rules + Defer = foreignKey.Defer + } + + member this.ColumnConstraint(constr : ColumnConstraint, creating : CreateTableStmt option) = + { Name = constr.Name + ColumnConstraintType = + match constr.ColumnConstraintType with + | NullableConstraint -> NullableConstraint + | PrimaryKeyConstraint clause -> PrimaryKeyConstraint clause + | UniqueConstraint -> UniqueConstraint + | DefaultConstraint def -> DefaultConstraint <| this.Expr(def) + | CollateConstraint name -> CollateConstraint name + | ForeignKeyConstraint foreignKey -> ForeignKeyConstraint <| this.ForeignKey(foreignKey, creating) + } + + member this.ColumnDef(cdef : ColumnDef, creating : CreateTableStmt option) = + { Name = cdef.Name + Type = cdef.Type + Constraints = cdef.Constraints |> rmap (fun con -> this.ColumnConstraint(con, creating)) + } + + member this.Alteration(tableName : InfObjectName, alteration : AlterTableAlteration) = + match alteration with + | RenameTo name -> RenameTo name + | AddColumn cdef -> + let fake = + resultAt tableName.Source <| + match tableName.Info.Table.Table with + | TableReference schemaTable -> schemaTable.WithAdditionalColumn(cdef) + | _ -> Error <| Error.objectNotATable tableName + let from = + InferredFromClause.FromSingleObject + ({ tableName with + Info = + { Table = TableReference fake + Query = inferredOfTable(fake) + } |> TableLike }) + let this = this.WithScope({ scope with FromClause = Some from }) + AddColumn <| this.ColumnDef(cdef, None) + + member this.CreateIndex(createIndex : CreateIndexStmt) = + let tableName = this.ObjectName(createIndex.TableName) + let checker = + this.WithScope({ scope with FromClause = Some <| InferredFromClause.FromSingleObject(tableName) }) + { Unique = createIndex.Unique + IndexName = this.ObjectName(createIndex.IndexName) + TableName = tableName + IndexedColumns = createIndex.IndexedColumns + Where = createIndex.Where |> Option.map checker.Expr + } + + member this.TableIndexConstraint(constr : TableIndexConstraintClause) = + { Type = constr.Type + IndexedColumns = constr.IndexedColumns + } + + member this.TableConstraint(constr : TableConstraint, creating) = + { Name = constr.Name + TableConstraintType = + match constr.TableConstraintType with + | TableIndexConstraint clause -> + TableIndexConstraint <| this.TableIndexConstraint(clause) + | TableForeignKeyConstraint (names, foreignKey) -> + TableForeignKeyConstraint (names, this.ForeignKey(foreignKey, creating)) + | TableCheckConstraint expr -> TableCheckConstraint <| this.Expr(expr) + } + + member this.CreateTableDefinition + (tableName : InfObjectName, createTable : CreateTableDefinition, creating : CreateTableStmt) = + let fake = + SchemaTable.OfCreateDefinition + ( tableName.SchemaName |? scope.Model.DefaultSchema + , tableName.ObjectName + , createTable + ) + let from = + InferredFromClause.FromSingleObject + ({ tableName with + Info = + { Table = TableReference fake + Query = inferredOfTable(fake) + } |> TableLike }) + let this = this.WithScope({ scope with FromClause = Some from }) + let creating = Some creating + let columns = createTable.Columns |> rmap (fun col -> this.ColumnDef(col, creating)) + { Columns = columns + Constraints = createTable.Constraints |> rmap (fun con -> this.TableConstraint(con, creating)) + WithoutRowId = createTable.WithoutRowId + } + + member this.CreateTable(createTable : CreateTableStmt) = + let name = this.ObjectName(createTable.Name, true) + { Temporary = createTable.Temporary + Name = name + As = + match createTable.As with + | CreateAsSelect select -> CreateAsSelect <| this.Select(select, SelfQueryShape.Unknown) + | CreateAsDefinition def -> CreateAsDefinition <| this.CreateTableDefinition(name, def, createTable) + } + + member this.CreateView(createView : CreateViewStmt) = + let knownShape = createView.ColumnNames |> Option.map cxt.AnonymousQueryInfo + { Temporary = createView.Temporary + ViewName = this.ObjectName(createView.ViewName, true) + ColumnNames = createView.ColumnNames + AsSelect = this.Select(createView.AsSelect, SelfQueryShape.Known(knownShape)) + } + + member this.QualifiedTableName(qualified : QualifiedTableName) = + { TableName = this.ObjectName(qualified.TableName) + IndexHint = qualified.IndexHint + } + + member this.Delete(delete : DeleteStmt) = + let checker, withClause = + match delete.With with + | None -> this, None + | Some withClause -> + let checker, withClause = this.WithClause(withClause) + checker, Some withClause + let deleteFrom = checker.QualifiedTableName(delete.DeleteFrom) + let checker = + checker.WithScope + ({ checker.Scope with FromClause = InferredFromClause.FromSingleObject(deleteFrom.TableName) |> Some }) + { With = withClause + DeleteFrom = deleteFrom + Where = Option.map checker.Expr delete.Where + OrderBy = Option.map (rmap checker.OrderingTerm) delete.OrderBy + Limit = Option.map checker.Limit delete.Limit + } + + member this.DropObject(drop : DropObjectStmt) = + { Drop = drop.Drop + ObjectName = this.ObjectName(drop.ObjectName) + } + + member this.Insert(insert : InsertStmt) = // TODO: verify that we insert into all cols w/o default values + let checker, withClause = + match insert.With with + | None -> this, None + | Some withClause -> + let checker, withClause = this.WithClause(withClause) + checker, Some withClause + let table = checker.ObjectName(insert.InsertInto) + let knownShape = + match insert.Columns with + | None -> table.Info.Query + | Some cols -> table.Info.Query.ColumnsWithNames(cols) + let columns = + knownShape.Columns + |> Seq.map (fun c -> { WithSource.Source = c.Expr.Source; Value = c.ColumnName }) + |> Seq.toArray + { With = withClause + Or = insert.Or + InsertInto = table + Columns = Some columns // we *must* specify these because our order might not match DB's + Data = insert.Data |> Option.map (fun data -> checker.Select(data, SelfQueryShape.Known(knownShape))) + } + + member this.Update(update : UpdateStmt) = + let checker, withClause = + match update.With with + | None -> this, None + | Some withClause -> + let checker, withClause = this.WithClause(withClause) + checker, Some withClause + let updateTable = checker.QualifiedTableName(update.UpdateTable) + let checker = + checker.WithScope + ({ checker.Scope with FromClause = InferredFromClause.FromSingleObject(updateTable.TableName) |> Some }) + let setColumns = + [| let cols = updateTable.TableName.Info.Query + for name, expr in update.Set do + match cols.ColumnByName(name.Value) with + | Found col -> + let expr = checker.Expr(expr) + cxt.UnifyLeftKnown(name.Source, col.Expr.Info.Type, expr.Info.Type) + yield name, expr + | _ -> + failAt name.Source <| Error.noSuchColumnToSet updateTable.TableName name.Value + |] + { With = withClause + UpdateTable = updateTable + Or = update.Or + Set = setColumns + Where = Option.map checker.Expr update.Where + OrderBy = Option.map (rmap checker.OrderingTerm) update.OrderBy + Limit = Option.map checker.Limit update.Limit + } + + member this.Stmt(stmt : Stmt) = + match stmt with + | AlterTableStmt alter -> + AlterTableStmt <| + let tbl = this.ObjectName(alter.Table) + { Table = tbl + Alteration = this.Alteration(tbl, alter.Alteration) + } + | CreateIndexStmt index -> CreateIndexStmt <| this.CreateIndex(index) + | CreateTableStmt createTable -> CreateTableStmt <| this.CreateTable(createTable) + | CreateViewStmt createView -> CreateViewStmt <| this.CreateView(createView) + | DeleteStmt delete -> DeleteStmt <| this.Delete(delete) + | DropObjectStmt drop -> DropObjectStmt <| this.DropObject(drop) + | InsertStmt insert -> InsertStmt <| this.Insert(insert) + | SelectStmt select -> SelectStmt <| this.Select(select, SelfQueryShape.Unknown) + | UpdateStmt update -> UpdateStmt <| this.Update(update) + | BeginStmt -> BeginStmt + | CommitStmt -> CommitStmt + | RollbackStmt -> RollbackStmt + + interface IQueryTypeChecker with + member this.Select(select) = this.Select(select, SelfQueryShape.Unknown) + member this.CreateView(view) = this.CreateView(view) \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/TypeInferenceContext.fs b/Rezoom.SQL.Compiler/TypeInferenceContext.fs new file mode 100644 index 0000000..c303cf5 --- /dev/null +++ b/Rezoom.SQL.Compiler/TypeInferenceContext.fs @@ -0,0 +1,261 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Collections.Generic +open Rezoom.SQL.Compiler.InferredTypes + +type private TypeInferenceVariable(id : TypeVariableId) = + let inferredType = TypeVariable id + let mutable currentType = AnyTypeClass + member __.Id = id + member __.InferredType = inferredType + member __.CurrentType = currentType + member __.Unify(source, core : CoreColumnType) = + let unified = currentType.Unify(core) |> resultAt source + currentType <- unified + +type private NullabilityVariable(id : TypeVariableId) = + let inferredNullable = NullableVariable id + let mutable currentNullable = NullableUnknown + member __.Id = id + member __.InferredNullable = inferredNullable + member __.CurrentNullable = currentNullable + member __.ForceNullable() = + currentNullable <- NullableKnown true + +type private VariableTracker<'var>(init : TypeVariableId -> 'var, id : 'var -> TypeVariableId) = + let variablesByParameter = Dictionary() + let variablesById = ResizeArray<'var>() + let getVar id = + if id < 0 || id >= variablesById.Count then bug "Type variable not found" + variablesById.[id] + member this.BoundParameters = variablesByParameter.Keys + member this.NextVar() = + let var = init variablesById.Count + variablesById.Add(var) + var + member this.GetVar(id : TypeVariableId) = getVar id + member this.BindVar(bindParameter : BindParameter) = + let succ, v = variablesByParameter.TryGetValue(bindParameter) + if succ then getVar v else + let var = this.NextVar() + variablesByParameter.[bindParameter] <- id var + var + member this.Replace(id : TypeVariableId, var : 'var) = + variablesById.[id] <- var + +type private TypeInferenceContext() = + let typeVariables = VariableTracker(TypeInferenceVariable, fun v -> v.Id) + let nullVariables= VariableTracker(NullabilityVariable, fun v -> v.Id) + let deferredNullables = ResizeArray() + static member UnifyColumnTypes(left : ColumnType, right : ColumnType) = + result { + let nullable = max left.Nullable right.Nullable + let! ty = left.Type.Unify(right.Type) + return { Nullable = nullable; Type = ty } + } + member this.AnonymousVariable() = typeVariables.NextVar().InferredType + member private this.Variable(bindParameter) = + { InferredType = typeVariables.BindVar(bindParameter).InferredType + InferredNullable = nullVariables.BindVar(bindParameter).InferredNullable + } + member this.Unify(source, left, right) = + match left, right with + | TypeKnown lk, TypeKnown rk -> + lk.Unify(rk) |> resultAt source |> TypeKnown + | TypeVariable varId, TypeKnown knownType + | TypeKnown knownType, TypeVariable varId -> + let tvar = typeVariables.GetVar(varId) + tvar.Unify(source, knownType) + tvar.InferredType + | TypeVariable leftId, TypeVariable rightId -> + let left, right = typeVariables.GetVar(leftId), typeVariables.GetVar(rightId) + left.Unify(source, right.CurrentType) + typeVariables.Replace(rightId, left) + left.InferredType + member this.UnifyList(source, elem, list) = + let var = typeVariables.BindVar(list) + match elem with + | TypeVariable varId -> + var.Unify(source, ListType (typeVariables.GetVar(varId)).CurrentType) + | TypeKnown knownType -> + var.Unify(source, ListType knownType) + member this.ForceNullable(source, nullable : InferredNullable) = + match nullable.Simplify() with + | NullableDueToJoin _ + | NullableUnknown + | NullableKnown _ -> () // even NullableKnown false is OK, we just want to force the NullableVariables + | NullableVariable id -> nullVariables.GetVar(id).ForceNullable() + | NullableEither _ -> + let rec allVars v = + match v with + | NullableUnknown + | NullableKnown true + | NullableKnown false + | NullableDueToJoin _ -> Seq.empty + | NullableVariable id -> Seq.singleton id + | NullableEither (l, r) -> Seq.append (allVars l) (allVars r) + deferredNullables.Add(ResizeArray(allVars nullable)) + member this.ResolveNullable(nullable) = + if deferredNullables.Count > 0 then + let triviallySatisfied r = + let t = NullableKnown true + r |> Seq.exists (fun v -> nullVariables.GetVar(v).CurrentNullable = t) + ignore <| deferredNullables.RemoveAll(fun r -> triviallySatisfied r) // remove trivially satisfied reqs + for vs in deferredNullables do // remaining vars must all be forced null + for v in vs do + nullVariables.GetVar(v).ForceNullable() + deferredNullables.Clear() + match nullable with + | NullableUnknown -> false + | NullableDueToJoin _ -> true + | NullableKnown t -> t + | NullableVariable id -> this.ResolveNullable(nullVariables.GetVar(id).CurrentNullable) + | NullableEither (l, r) -> this.ResolveNullable(l) || this.ResolveNullable(r) + member this.Concrete(inferred) = + { Nullable = this.ResolveNullable(inferred.InferredNullable) + Type = + match inferred.InferredType with + | TypeKnown t -> t + | TypeVariable id -> typeVariables.GetVar(id).CurrentType + } + interface ITypeInferenceContext with + member this.AnonymousVariable() = this.AnonymousVariable() + member this.Variable(parameter) = this.Variable(parameter) + member this.UnifyList(source, elem, list) = this.UnifyList(source, elem, list) + member this.Unify(source, left, right) = this.Unify(source, left, right) + member this.ForceNullable(source, nullable) = this.ForceNullable(source, nullable) + member this.Concrete(inferred) = this.Concrete(inferred) + member __.Parameters = typeVariables.BoundParameters :> _ seq + +[] +module private TypeInferenceExtensions = + type ITypeInferenceContext with + member typeInference.Unify(source : SourceInfo, left : InferredType, right : CoreColumnType) = + { left with + InferredType = typeInference.Unify(source, left.InferredType, TypeKnown right) + } + member typeInference.Unify(source : SourceInfo, left : InferredType, right : InferredType) = + { InferredType = typeInference.Unify(source, left.InferredType, right.InferredType) + InferredNullable = InferredNullable.Either(left.InferredNullable, right.InferredNullable) + } + member typeInference.Unify(source : SourceInfo, types : CoreInferredType seq) = + types + |> Seq.fold + (fun s next -> typeInference.Unify(source, s, next)) + InferredType.Scalar.InferredType + member typeInference.Unify(source : SourceInfo, types : InferredType seq) = + { InferredType = typeInference.Unify(source, types |> Seq.map (fun t -> t.InferredType)) + InferredNullable = InferredNullable.Any(types |> Seq.map (fun t -> t.InferredNullable)) + } + /// Unify a known type (e.g. from a table we're inserting into or a declared CTE) + /// with an inferred type. The inferred type is forced nullable if the known type is nullable. + member typeInference.UnifyLeftKnown(source : SourceInfo, left : InferredType, right : InferredType) = + ignore <| typeInference.Unify(source, left.InferredType, right.InferredType) + if left.InferredNullable = NullableKnown true then + typeInference.ForceNullable(source, right.InferredNullable) + member typeInference.Concrete(inferred) = typeInference.Concrete(inferred) + member typeInference.Binary(source, op, left, right) = + match op with + | Concatenate -> typeInference.Unify(source, [ left; right; InferredType.String ]) + | Multiply + | Divide + | Add + | Subtract -> typeInference.Unify(source, [ left; right; InferredType.Number ]) + | Modulo + | BitShiftLeft + | BitShiftRight + | BitAnd + | BitOr -> typeInference.Unify(source, [ left; right; InferredType.Integer ]) + | LessThan + | LessThanOrEqual + | GreaterThan + | GreaterThanOrEqual + | Equal + | NotEqual -> + let operandType = typeInference.Unify(source, left, right) + InferredType.Dependent(operandType, BooleanType) + | Is + | IsNot -> + let operandType = typeInference.Unify(source, left, right) + typeInference.ForceNullable(source, left.InferredNullable) + typeInference.ForceNullable(source, right.InferredNullable) + InferredType.Dependent(operandType, BooleanType) + | And + | Or -> typeInference.Unify(source, [ left; right; InferredType.Boolean ]) + member typeInference.Unary(source, op, operandType) = + match op with + | Negative + | BitNot -> typeInference.Unify(source, operandType, InferredType.Number) + | Not -> typeInference.Unify(source, operandType, InferredType.Boolean) + | IsNull + | NotNull -> + typeInference.ForceNullable(source, operandType.InferredNullable) + InferredType.Boolean + member typeInference.AnonymousQueryInfo(columnNames) = + { Columns = + seq { + for { WithSource.Source = source; Value = name } in columnNames -> + let tyVar = + { InferredType = typeInference.AnonymousVariable() + InferredNullable = NullableUnknown + } |> ExprInfo.OfType + { ColumnName = name + FromAlias = None + Expr = + { Value = ColumnNameExpr { Table = None; ColumnName = name } + Source = source + Info = tyVar + } + } + } |> toReadOnlyList + } + member typeInference.Function(source : SourceInfo, func : FunctionType, invoc : InfFunctionArguments) = + let functionVars = Dictionary() + let aggregate = func.Aggregate(invoc) + let term (termType : FunctionTermType) = + match termType.TypeVariable with + | None -> TypeKnown termType.TypeConstraint + | Some name -> + let succ, tvar = functionVars.TryGetValue(name) + let tvar = + if succ then tvar else + let avar = typeInference.AnonymousVariable() + functionVars.[name] <- avar + avar + typeInference.Unify(source, tvar, TypeKnown termType.TypeConstraint) + match invoc with + | ArgumentWildcard -> + match aggregate with + | Some aggregate when aggregate.AllowWildcard -> + ArgumentWildcard, + { InferredType = term func.Returns + InferredNullable = + if func.Returns.ForceNullable then NullableKnown true else NullableUnknown + } + | _ -> failAt source <| Error.functionDoesNotPermitWildcard func.FunctionName + | ArgumentList (distinct, args) as argumentList -> + if Option.isSome distinct then + match aggregate with + | Some aggregate when aggregate.AllowDistinct -> () + | _ -> failAt source <| Error.functionDoesNotPermitDistinct func.FunctionName + let nulls = ResizeArray() + func.ValidateArgs(source, args, (fun a -> a.Source), fun arg termTy -> + let term = term termTy + ignore <| typeInference.Unify(arg.Source, arg.Info.Type.InferredType, term) + if termTy.ForceNullable then + typeInference.ForceNullable(arg.Source, arg.Info.Type.InferredNullable) + if termTy.InfectNullable then + nulls.Add(arg.Info.Type.InferredNullable)) + let returnType = + { InferredType = term func.Returns + InferredNullable = + if func.Returns.ForceNullable then NullableKnown true else InferredNullable.Any(nulls) + } + argumentList, returnType + + let inline implicitAlias column = + match column with + | _, (Some _ as a) -> a + | ColumnNameExpr c, None -> Some c.ColumnName + | _ -> None + diff --git a/Rezoom.SQL.Compiler/TypeSystem.fs b/Rezoom.SQL.Compiler/TypeSystem.fs new file mode 100644 index 0000000..7eea8dc --- /dev/null +++ b/Rezoom.SQL.Compiler/TypeSystem.fs @@ -0,0 +1,199 @@ +namespace Rezoom.SQL.Compiler +open System +open System.Data +open System.Data.Common +open System.Collections.Generic + +type CoreColumnType = + | BooleanType + | StringType + | IntegerType of IntegerSize + | FloatType of FloatSize + | DecimalType + | BinaryType + | DateTimeType + | DateTimeOffsetType + | StringishTypeClass + | NumericTypeClass + | IntegralTypeClass + | FractionalTypeClass + | ScalarTypeClass + | AnyTypeClass + | ListType of CoreColumnType + member this.ParentType = + match this with + | IntegerType Integer8 -> IntegralTypeClass + | IntegerType Integer16 -> IntegerType Integer8 + | IntegerType Integer32 -> IntegerType Integer16 + | IntegerType Integer64 -> IntegerType Integer32 + | FloatType Float32 -> FractionalTypeClass + | FloatType Float64 -> FloatType Float32 + | DecimalType -> FractionalTypeClass + | StringType + | BinaryType -> StringishTypeClass + | IntegralTypeClass + | FractionalTypeClass -> NumericTypeClass + | BooleanType + | DateTimeType + | DateTimeOffsetType + | NumericTypeClass + | StringishTypeClass -> ScalarTypeClass + | ScalarTypeClass + | AnyTypeClass -> AnyTypeClass + | ListType element -> + let elementParent = element.ParentType + if elementParent = element then AnyTypeClass + else ListType elementParent + member this.HasAncestor(candidate) = + if this = candidate then true else + let parent = this.ParentType + if parent = this then false + else parent.HasAncestor(candidate) + member left.Unify(right) = + if left.HasAncestor(right) then + Ok left + elif right.HasAncestor(left) then + Ok right + else + Error <| Error.cannotUnify left right + override this.ToString() = + match this with + | BooleanType -> "BOOL" + | StringType -> "STRING" + | IntegerType Integer8 -> "INT8" + | IntegerType Integer16 -> "INT16" + | IntegerType Integer32 -> "INT" + | IntegerType Integer64 -> "INT64" + | FloatType Float32 -> "FLOAT32" + | FloatType Float64 -> "FLOAT64" + | DecimalType -> "DECIMAL" + | BinaryType -> "BINARY" + | DateTimeType -> "DATETIME" + | DateTimeOffsetType -> "DATETIMEOFFSET" + | FractionalTypeClass -> "" + | IntegralTypeClass -> "" + | NumericTypeClass -> "" + | StringishTypeClass -> "" + | ScalarTypeClass -> "" + | AnyTypeClass -> "" + | ListType t -> "[" + string t + "]" + static member OfTypeName(typeName : TypeName) = + match typeName with + | StringTypeName _ -> StringType + | BinaryTypeName _ -> BinaryType + | IntegerTypeName sz -> IntegerType sz + | FloatTypeName sz -> FloatType sz + | DecimalTypeName -> DecimalType + | BooleanTypeName -> BooleanType + | DateTimeTypeName -> DateTimeType + | DateTimeOffsetTypeName -> DateTimeOffsetType + +type ColumnType = + { Type : CoreColumnType + Nullable : bool + } + static member OfTypeName(typeName : TypeName, nullable) = + { Type = CoreColumnType.OfTypeName(typeName) + Nullable = nullable + } + member private ty.TypeInfo(useOptional) = + let nullify (clrType : Type) = + if ty.Nullable then + if useOptional then typedefof<_ option>.MakeGenericType(clrType) + elif clrType.IsValueType then typedefof<_ Nullable>.MakeGenericType(clrType) + else clrType + else + clrType + match ty.Type with + | IntegerType Integer8 -> DbType.SByte, nullify typeof + | IntegerType Integer16 -> DbType.Int16, nullify typeof + | IntegralTypeClass + | IntegerType Integer32 -> DbType.Int32, nullify typeof + | IntegerType Integer64 -> DbType.Int64, nullify typeof + | FloatType Float32 -> DbType.Single, nullify typeof + | FloatType Float64 -> DbType.Double, nullify typeof + | BooleanType -> DbType.Boolean, nullify typeof + | FractionalTypeClass + | NumericTypeClass + | DecimalType -> DbType.Decimal, nullify typeof + | DateTimeType -> DbType.DateTime, nullify typeof + | DateTimeOffsetType -> DbType.DateTimeOffset, nullify typeof + | StringType -> DbType.String, nullify typeof + | BinaryType -> DbType.Binary, nullify typeof + | StringishTypeClass + | ScalarTypeClass + | AnyTypeClass -> DbType.Object, nullify typeof + | ListType t -> + let dbType, clrType = { Type = t; Nullable = ty.Nullable }.TypeInfo(useOptional) + dbType, clrType.MakeArrayType() + member ty.CLRType(useOptional) = snd <| ty.TypeInfo(useOptional) + member ty.DbType = fst <| ty.TypeInfo(false) + override ty.ToString() = + string ty.Type + (if ty.Nullable then "?" else "") + +type FunctionTermType = + { TypeConstraint : CoreColumnType + TypeVariable : Name option + ForceNullable : bool + InfectNullable : bool + VarArg : FunctionTermVarArg option + } +and FunctionTermVarArg = + { MinArgCount : int + MaxArgCount : int option + } + +type AggregateType = + { AllowWildcard : bool + AllowDistinct : bool + } + +[] +type FunctionType + ( name : Name + , parameters : FunctionTermType IReadOnlyList + , returns : FunctionTermType + , idem : bool + ) = + do + let numVarArgs = parameters |> Seq.filter (fun p -> Option.isSome p.VarArg) |> Seq.truncate 2 |> Seq.length + if numVarArgs > 1 then bug <| sprintf "Can't have more than one vararg to a function (%O)" name + member __.FunctionName = name + member __.Parameters = parameters + member __.Returns = returns + member __.Idempotent = idem + /// Whether this function (of one argument) is erased when translated, i.e. `f(x)` becomes just `x`. + abstract member Erased : bool + default __.Erased = false + abstract member Aggregate : FunctionArguments<'t, 'e> -> AggregateType option + member __.MinimumParameters = + parameters |> Seq.sumBy (fun p -> match p.VarArg with | None -> 1 | Some v -> v.MinArgCount) + member internal this.ValidateArgs + ( source : SourceInfo + , argList : 'a IReadOnlyList + , argSource : 'a -> SourceInfo + , validate : 'a -> FunctionTermType -> unit + ) = + let mutable i = 0 + for par in parameters do + if i >= argList.Count then + failAt source <| Error.insufficientArguments name argList.Count this.MinimumParameters + match par.VarArg with + | None -> + validate (argList.[i]) par + i <- i + 1 + | Some varg -> + let start = i + // we can consume arguments until we get to this index + let indexOfLastVarArg = argList.Count - (parameters.Count - i) + let indexOfLastVarArg = + match varg.MaxArgCount with + | None -> indexOfLastVarArg + | Some maxCount -> min indexOfLastVarArg (i + maxCount - 1) + while i <= indexOfLastVarArg do + validate (argList.[i]) par + i <- i + 1 + if i - start < varg.MinArgCount then + failAt source <| Error.insufficientArguments name i this.MinimumParameters + if i < argList.Count then + failAt (argSource argList.[i]) <| Error.excessiveArguments name argList.Count (i - 1) \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/UserModel.fs b/Rezoom.SQL.Compiler/UserModel.fs new file mode 100644 index 0000000..1f144aa --- /dev/null +++ b/Rezoom.SQL.Compiler/UserModel.fs @@ -0,0 +1,148 @@ +namespace Rezoom.SQL.Compiler +open System +open System.IO +open System.Text.RegularExpressions +open System.Collections.Generic +open Rezoom.SQL.Compiler +open Rezoom.SQL.Compiler.InferredTypes +open Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.Migrations + +module private UserModelLoader = + let private migrationPattern = + """ + ^V(? [0-9]+ ) + \. + (? \w+ ) + ( - (? \w+ ))? + \.SQL$ + """ |> fun pat -> Regex(pat, RegexOptions.IgnoreCase ||| RegexOptions.IgnorePatternWhitespace) + + let parseMigrationInfo path = + let rematch = migrationPattern.Match(path) + if not rematch.Success then None else + let majorVersion = rematch.Groups.["majorVersion"].Value |> int + let name = rematch.Groups.["name"].Value + let name2 = + let group = rematch.Groups.["name2"] + if group.Success then Some group.Value + else None + Some <| + match name2 with + | Some target -> + { ParentName = Some name + Name = target + MajorVersion = majorVersion + } + | None -> + { ParentName = None + Name = name + MajorVersion = majorVersion + } + + let loadMigrations migrationsFolder = + let builder = MigrationTreeListBuilder() + for path in Directory.GetFiles(migrationsFolder, "*.sql", SearchOption.AllDirectories) do + match parseMigrationInfo <| Path.GetFileName(path) with + | None -> () + | Some migrationName -> + let text = File.ReadAllText(path) + let parsed = CommandEffect.ParseSQL(path, text) + builder.Add(migrationName, parsed) + builder.ToTrees() + + let revalidateViews (model : Model) = + let inference = TypeInferenceContext() + let typeChecker = TypeChecker(inference, InferredSelectScope.Root(model)) + let concrete = concreteMapping inference + for KeyValue(_, schema) in model.Schemas do + for KeyValue(_, obj) in schema.Objects do + match obj with + | SchemaView view -> + let inferredDefinition = typeChecker.Select(view.CreateDefinition.AsSelect, SelfQueryShape.Unknown) + ignore <| concrete.Select(inferredDefinition) + | _ -> () + + let nextModel initialModel (migrationTrees : TotalStmts MigrationTree seq) = + let folder isRoot (model : Model) (migration : TotalStmts Migration) = + let effect = CommandEffect.OfSQL(model, migration.Source) + if not isRoot && effect.DestructiveUpdates.Value then + failwith <| sprintf + "The migration ``%s`` contains destructive statements. This requires a version bump." + migration.FileName + effect.Statements, effect.ModelChange |? model + let _, finalModel as pair = foldMigrations folder initialModel migrationTrees + revalidateViews finalModel + pair + + let stringizeMigrationTree (backend : IBackend) (migrationTrees : TTotalStmt IReadOnlyList MigrationTree seq) = + seq { + let indexer = + { new IParameterIndexer with + member __.ParameterIndex(par) = + failwith "Migrations cannot be parameterized" + } + for tree in migrationTrees -> + tree.Map(fun stmts -> + backend.ToCommandFragments(indexer, stmts) |> CommandFragment.Stringize) + } + + let tableIds (model : Model) = + seq { + let mutable i = 0 + for KeyValue(_, schema) in model.Schemas do + for KeyValue(_, obj) in schema.Objects do + match obj with + | SchemaTable tbl -> + yield (tbl.SchemaName, tbl.TableName), i + i <- i + 1 + | _ -> () + } |> Map.ofSeq + +open UserModelLoader + +type UserModel = + { ConnectionName : string + ConfigDirectory : string + Config : Config.Config + MigrationsDirectory : string + Backend : IBackend + Model : Model + TableIds : Map Lazy + Migrations : string MigrationTree IReadOnlyList + } + static member ConfigFileName = "rzsql.json" + static member Load(resolutionFolder : string, modelPath : string) = + let config, configDirectory = + if String.IsNullOrEmpty(modelPath) then // implicit based on location of dbconfig.json + let configPath = + Directory.GetFiles(resolutionFolder, "*.json", SearchOption.AllDirectories) + |> Array.tryFind (fun f -> f.EndsWith(UserModel.ConfigFileName, StringComparison.OrdinalIgnoreCase)) + match configPath with + | None -> Config.defaultConfig, resolutionFolder + | Some path -> + Config.parseConfigFile path, Path.GetDirectoryName(path) + else + let path = Path.Combine(resolutionFolder, modelPath) + if path.EndsWith(".json", StringComparison.OrdinalIgnoreCase) then + Config.parseConfigFile path, Path.GetDirectoryName(path) + else + let configPath = Path.Combine(path, UserModel.ConfigFileName) + if File.Exists(configPath) then + Config.parseConfigFile configPath, path + else + Config.defaultConfig, path + let migrationsDirectory = Path.Combine(configDirectory, config.MigrationsPath) |> Path.GetFullPath + let migrations = loadMigrations migrationsDirectory + let backend = config.Backend.ToBackend() + let migrations, model = nextModel backend.InitialModel migrations + let migrations = stringizeMigrationTree backend migrations |> toReadOnlyList + { ConnectionName = config.ConnectionName + MigrationsDirectory = migrationsDirectory + ConfigDirectory = Path.GetFullPath(configDirectory) + Config = config + Backend = backend + Model = model + TableIds = lazy tableIds model + Migrations = migrations + } \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/Utilities.fs b/Rezoom.SQL.Compiler/Utilities.fs new file mode 100644 index 0000000..72a634c --- /dev/null +++ b/Rezoom.SQL.Compiler/Utilities.fs @@ -0,0 +1,146 @@ +[] +module Rezoom.SQL.Compiler.Utilities +open System +open System.Collections +open System.Collections.Generic + +let inline (|?) opt def = defaultArg opt def + +let inline rmap (f : 'a -> 'b) (list : 'a array) = Array.map f list + +let toReadOnlyList (values : 'a seq) = + ResizeArray(values) :> IReadOnlyList<_> + +let toDictionary (key : 'a -> 'k) (values : 'a seq) = + let d = Dictionary() + for value in values do + d.[key value] <- value + d + +let srcMap f (w : 'a WithSource) = w.Map(f) +let srcValue (w : 'a WithSource) = w.Value + +[] +let emptyDictionary<'k, 'v> = + { new IReadOnlyDictionary<'k, 'v> with + member __.ContainsKey(_) = false + member __.Count = 0 + member __.GetEnumerator() : IEnumerator> = Seq.empty.GetEnumerator() + member __.GetEnumerator() : IEnumerator = upcast Seq.empty.GetEnumerator() + member __.Item with get(k) = raise <| KeyNotFoundException() + member __.TryGetValue(k, v) = false + member __.Keys = Seq.empty + member __.Values = Seq.empty + } + +let inline bug msg = failwith msg + +let inline failAt (source : SourceInfo) (msg : string) = + raise (SourceInfoException(msg, source)) + +type NameResolution<'a> = + | Found of 'a + | NotFound of string + | Ambiguous of string + +type Result<'x, 'err> = + | Ok of 'x + | Error of 'err + +type ResultBuilder() = + member inline this.Zero() = Ok () + member inline this.Bind(result : Result<'x, 'err>, next : 'x -> Result<'y, 'err>) = + match result with + | Error err -> Error err + | Ok x -> next x + member inline this.Combine(first : Result<'x, 'err>, next : unit -> Result<'y, 'err>) = + match first with + | Error err -> Error err + | Ok _ -> next() + member inline this.Return(x) = Ok x + member inline this.ReturnFrom(x : Result<_, _>) = x + member inline __.Delay(x : unit -> 'x) = x + member inline __.Run(x : unit -> 'x) = x() + +let result = ResultBuilder() + +let resultAt source result = + match result with + | Ok x -> x + | Error err -> failAt source err + +let resultOk source result = resultAt source result |> ignore + +let appendLists (left : 'x IReadOnlyList) (right : 'x IReadOnlyList) = + { new IReadOnlyList<'x> with + member __.Count = left.Count + right.Count + member __.GetEnumerator() : 'x IEnumerator = (Seq.append left right).GetEnumerator() + member __.GetEnumerator() : IEnumerator = upcast (Seq.append left right).GetEnumerator() + member __.Item + with get (index) = + let leftCount = left.Count + if index >= leftCount then right.[index - leftCount] + else left.[index] + } + +type AmbiguousKeyException(msg) = + inherit Exception(msg) + +let appendDicts (left : IReadOnlyDictionary<'k, 'v>) (right : IReadOnlyDictionary<'k, 'v>) = + { new IReadOnlyDictionary<'k, 'v> with + member __.ContainsKey(key) = left.ContainsKey(key) || right.ContainsKey(key) + member __.Count = left.Count + right.Count + member __.GetEnumerator() : IEnumerator> = (Seq.append left right).GetEnumerator() + member __.GetEnumerator() : IEnumerator = upcast (Seq.append left right).GetEnumerator() + member __.Item + with get(k) = + let lsucc, lv = left.TryGetValue(k) + let rsucc, rv = right.TryGetValue(k) + if lsucc && rsucc then + raise <| AmbiguousKeyException(sprintf "Key %O is ambiguous" k) + else if lsucc then + lv + else if rsucc then + rv + else raise <| KeyNotFoundException() + member __.TryGetValue(k, v) = + let lsucc, lv = left.TryGetValue(k) + let rsucc, rv = right.TryGetValue(k) + if lsucc && rsucc then + raise <| AmbiguousKeyException(sprintf "Key %O is ambiguous" k) + else if lsucc then + v <- lv + true + else if rsucc then + v <- rv + true + else false + member __.Keys = Seq.append left.Keys right.Keys + member __.Values = Seq.append left.Values right.Values + } + +let rec private insertionsOf y xs = + match xs with + | [] -> Seq.singleton [y] + | x :: rest -> + seq { + yield y :: xs + for rest in insertionsOf y rest -> x :: rest + } + +let rec permutations (xs : 'a list) = + match xs with + | [] -> Seq.singleton [] + | x :: rest -> permutations rest |> Seq.collect (insertionsOf x) + +/// Translates from FParsec's position type to our own. +let internal translatePosition (pos : FParsec.Position) = + { Index = int pos.Index; Line = int pos.Line; Column = int pos.Column } + +let mapBy keyFunction sequence = + sequence |> Seq.map (fun x -> keyFunction x, x) |> Map.ofSeq + +let inline (+@+) x y = + let h1 = match box x with | null -> 0 | _ -> x.GetHashCode() + let h2 = match box y with | null -> 0 | _ -> y.GetHashCode() + ((h1 <<< 5) + h1) ^^^ h2 \ No newline at end of file diff --git a/Rezoom.SQL.Compiler/packages.config b/Rezoom.SQL.Compiler/packages.config new file mode 100644 index 0000000..b5cb8a6 --- /dev/null +++ b/Rezoom.SQL.Compiler/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/AssemblyInfo.fs b/Rezoom.SQL.Mapping/AssemblyInfo.fs new file mode 100644 index 0000000..5e432a5 --- /dev/null +++ b/Rezoom.SQL.Mapping/AssemblyInfo.fs @@ -0,0 +1,41 @@ +namespace Rezoom.SQL.Mapping.AssemblyInfo + +open System.Reflection +open System.Runtime.CompilerServices +open System.Runtime.InteropServices + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[] +[] +[] +[] +[] +[] +[] +[] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [] +[] +[] + +do + () \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/Blueprint.fs b/Rezoom.SQL.Mapping/Blueprint.fs new file mode 100644 index 0000000..3db5c09 --- /dev/null +++ b/Rezoom.SQL.Mapping/Blueprint.fs @@ -0,0 +1,109 @@ +namespace Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.CodeGeneration +open System +open System.Collections.Generic +open System.Reflection + +type Setter = + /// We initialize this column by passing it to the composite's constructor. + | SetConstructorParameter of ParameterInfo + /// We initialize this column by setting a field post-construction. + | SetField of FieldInfo + /// We initialize this column by setting a property post-construction. + | SetProperty of PropertyInfo + +type Getter = + | GetField of FieldInfo + | GetProperty of PropertyInfo + member this.MemberInfo = + match this with + | GetField f -> f :> MemberInfo + | GetProperty p -> p :> MemberInfo + +type Column = + { + ColumnId : ColumnId + /// The name of this column. This is the basename of the SQL column name that + /// will represent it. This should always be treated case-insensitively. + Name : string + /// The blueprint for this column's type. + Blueprint : Blueprint Lazy + /// The way to set this column when initializing an instance of the composite type. + Setter : Setter + /// The way to get this column's value (could be used for analyzing expression trees). + Getter : Getter option + /// The column on this column's type that points to this. + ReverseRelationship : Column option Lazy + } + member this.Output = this.Blueprint.Value.Output + +and Composite = + { + Output : Type + /// The constructor to use when instantiating this composite type. + /// All parameters must be supplied by columns. + Constructor : ConstructorInfo + /// The identity columns for this composite type, if any. + Identity : Column IReadOnlyList + /// All the columns of this composite type (including the identity, if any). + /// Indexed by name, case insensitive. + Columns : IReadOnlyDictionary + } + member this.ColumnByGetter(mem : MemberInfo) = + this.Columns.Values |> Seq.tryFind (fun col -> + match col.Getter with + | None -> false + | Some getter -> getter.MemberInfo = mem) + member this.TableName = this.Output.Name + member this.ReferencesQueryParent = + this.Columns.Values + |> Seq.exists (fun c -> c.ReverseRelationship.Value |> Option.isSome) + +and Primitive = + { + Output : Type + /// A method converting an object to the output type. + Converter : RowConversionMethod + } + +and BlueprintShape = + | Primitive of Primitive + | Composite of Composite + +and ElementBlueprint = + { + Shape : BlueprintShape + /// The element type this blueprint specifies how to construct. + Output : Type + } + member internal this.IsOne(roots : Type HashSet) = + match this.Shape with + | Primitive _ -> true + | Composite c -> + c.Columns.Values + |> Seq.forall (fun c -> + let blueprint = c.Blueprint.Value + roots.Contains(blueprint.Output) + || roots.Add(blueprint.Output) && blueprint.IsOne(roots)) + +and Cardinality = + | One of ElementBlueprint + /// Carries an element type blueprint and a method converting an ICollection> + /// to the target collection type. + | Many of ElementBlueprint * ConversionMethod + member this.Element = + match this with + | One elem -> elem + | Many (elem, _) -> elem + +and Blueprint = + { + Cardinality : Cardinality + /// The type (possibly a collectiont ype) this blueprint specifies how to construct. + Output : Type + } + member internal this.IsOne(roots : Type HashSet) = + match this.Cardinality with + | One e -> e.IsOne(roots) + | Many _ -> false + member this.IsOne() = this.IsOne(new HashSet<_>([| this.Output |])) diff --git a/Rezoom.SQL.Mapping/BlueprintAttributes.fs b/Rezoom.SQL.Mapping/BlueprintAttributes.fs new file mode 100644 index 0000000..8bc08e0 --- /dev/null +++ b/Rezoom.SQL.Mapping/BlueprintAttributes.fs @@ -0,0 +1,27 @@ +namespace Rezoom.SQL.Mapping +open System + +/// Marks a constructor as being the one to use when creating entities from blueprints. +[] +[] +type BlueprintConstructorAttribute() = + inherit Attribute() + +/// Marks a property as being part of the primary key of its composite type. +[] +[] +type BlueprintKeyAttribute() = + inherit Attribute() + +/// Indicates that a property is represented with a different column name than its own member name. +[] +[] +type BlueprintColumnNameAttribute(name : string) = + inherit Attribute() + member __.Name = name + +/// Indicates that a class has no key properties and should not be de-duplicated. +[] +[] +type BlueprintNoKeyAttribute() = + inherit Attribute() \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/BlueprintModule.fs b/Rezoom.SQL.Mapping/BlueprintModule.fs new file mode 100644 index 0000000..62c862a --- /dev/null +++ b/Rezoom.SQL.Mapping/BlueprintModule.fs @@ -0,0 +1,244 @@ +[] +module Rezoom.SQL.Mapping.Blueprint +open Rezoom.SQL.Mapping.CodeGeneration +open LicenseToCIL +open System +open System.Collections +open System.Collections.Generic +open System.ComponentModel +open System.Reflection +open System.Text.RegularExpressions + +let private blueprintCache = new Dictionary() + +let private ciDictionary keyValues = + let dictionary = new Dictionary(StringComparer.OrdinalIgnoreCase) + for key, value in keyValues do + dictionary.[key] <- value // overwrite duplicates, last wins + dictionary + +/// Get the constructor that the blueprint for `ty` should use. +/// This is simply the constructor with the most parameters, +/// unless there is a constructor with `[]`, +/// in which case that one will be used. +let private pickConstructor (ty : Type) = + let constructors = ty.GetConstructors() + if Array.isEmpty constructors then failwithf "Type %O has no public constructors" ty + let constructorsWithInfo = + constructors + |> Array.map (fun cons -> + let hasAttr = not << isNull <| cons.GetCustomAttribute() + cons, cons.GetParameters(), hasAttr) + let attributed = + constructorsWithInfo + |> Seq.filter (fun (_, _, a) -> a) + |> Seq.truncate 2 + |> Seq.toList + match attributed with + | [] -> + constructorsWithInfo + |> Array.maxBy (fun (_, p, _) -> p.Length) + |> fun (cons, pars, _) -> cons, pars + | [(cons, pars, _)] -> cons, pars + | multiple -> + failwithf "Type %O has %d constructors with [] applied. Cannot disambiguate constructor." + ty + (List.length multiple) + +/// Pick, in order of most to least preferred: +/// - the column whose getter is annotated with [] +/// - the column named "ID" +/// - the column named "{TypeName}ID" +let private pickIdentity (ty : Type) (cols : IReadOnlyDictionary) = + let noIdentity = ty.GetCustomAttribute() + if isNull noIdentity then + let attributed = + seq { + for col in cols.Values do + match col.Getter with + | None -> () + | Some getter -> + let attr = getter.MemberInfo.GetCustomAttribute() + if not (isNull attr) then yield col + } |> Seq.toArray + match attributed with + | [||] -> + let succ, id = cols.TryGetValue("ID") + if succ then [| id |] else + let succ, id = cols.TryGetValue(ty.Name + "ID") + if succ then [| id |] else + Array.empty + | identity -> identity + else Array.empty + +let private swapParentChild (me : string) (them : string) (name : string) = + let swapper (m : Match) = + if m.Value.Equals("PARENT", StringComparison.OrdinalIgnoreCase) then "CHILD" + elif m.Value.Equals("CHILD", StringComparison.OrdinalIgnoreCase) then "PARENT" + elif m.Value.Equals(them, StringComparison.OrdinalIgnoreCase) then me + elif m.Value.Equals(me, StringComparison.OrdinalIgnoreCase) then them + else failwith "Impossible" + let re = Regex("PARENT|CHILD|" + Regex.Escape(me) + "|" + Regex.Escape(them), RegexOptions.IgnoreCase) + re.Replace(name, swapper) + +let private pickReverseRelationship (ty : Type) (columnName : string) (neighbor : Blueprint) = + match neighbor.Cardinality with + | One { Shape = Composite composite } -> + let swapped = swapParentChild ty.Name composite.Output.Name columnName + composite.Columns.Values + |> Seq.choose (fun manyCol -> + if manyCol.Name.IndexOf(swapped, StringComparison.OrdinalIgnoreCase) >= 0 then + match manyCol.Blueprint.Value.Cardinality with + | Many (manyElem, _) when manyElem.Output = ty -> Some manyCol + | _ -> None + else None) + |> Seq.tryHead + | Many ({ Shape = Composite composite }, _) -> + composite.Columns.Values + |> Seq.filter (fun oneCol -> composite.Output <> ty || oneCol.Name <> columnName) + |> Seq.choose (fun oneCol -> + match oneCol.ReverseRelationship.Value with + | Some manyCol when + manyCol.Name.Equals(columnName, StringComparison.OrdinalIgnoreCase) -> + match oneCol.Blueprint.Value.Cardinality with + | One elem when elem.Output = ty -> Some oneCol + | _ -> None + | _ -> None) + |> Seq.tryHead + | _ -> None + +let private pickName (name : string) (getter : Getter option) = + match getter with + | None -> name + | Some getter -> + let columnNameAttr = getter.MemberInfo.GetCustomAttribute() + if isNull columnNameAttr then name + else columnNameAttr.Name + +let rec private compositeShapeOfType ty = + let ctor, pars = pickConstructor ty + let props = + ty.GetProperties() |> Array.filter (fun p -> p.CanRead) + let fields = + ty.GetFields() + let gettersByName = + seq { // order is important: we want to prefer props over fields + for field in fields do + yield field.Name, (field.FieldType, GetField field) + for prop in props do + yield prop.Name, (prop.PropertyType, GetProperty prop) + } |> ciDictionary + let settersByName = + seq { // order is important: we want to prefer constructor pars over props over fields + for field in fields do + yield field.Name, (field.FieldType, SetField field) + for prop in props do + if prop.CanWrite then + yield prop.Name, (prop.PropertyType, SetProperty prop) + for par in pars do + yield par.Name, (par.ParameterType, SetConstructorParameter par) + } |> ciDictionary + let columns = + seq { + for index, KeyValue(name, (setterTy, setter)) in settersByName |> Seq.indexed -> + let succ, getter = gettersByName.TryGetValue(name) + let getter = + if not succ then None else + let getterTy, getter = getter + if getterTy.IsAssignableFrom(setterTy) then Some getter + else None + let blueprint = lazy ofType setterTy + let name = pickName name getter + name, { + ColumnId = index + Name = name + Blueprint = blueprint + Setter = setter + Getter = getter + ReverseRelationship = + lazy pickReverseRelationship ty name blueprint.Value + } + } |> List.ofSeq |> ciDictionary + { Output = ty + Constructor = ctor + Identity = pickIdentity ty columns + Columns = columns + } + +and private cardinalityOfType (ty : Type) = + // If our type is an interface, choose a concrete representative instead. + let ty = CollectionConverters.representativeForInterface ty + if ty.IsConstructedGenericType && ty.GetGenericTypeDefinition() = typedefof<_ option> then + // Sadly must special-case this since option doesn't implement IEnumerable + let elemTy = ty.GetGenericArguments().[0] + match CollectionConverters.converter ty null elemTy with + | None -> failwith "Can't handle optional" + | Some converter -> + Many (elementOfType elemTy, converter) + else + let ifaces = ty.GetInterfaces() + // For this to be a collection, it must implement IEnumerable. + if ifaces |> Array.contains (typeof) |> not then One (elementOfType ty) else + // Ok, really it needs to be a generic IEnumerable *of* something... + let possible = + ifaces + |> Seq.filter + (fun iface -> + iface.IsConstructedGenericType + && iface.GetGenericTypeDefinition() = typedefof<_ seq>) + |> Seq.truncate 2 + |> Seq.toList + match possible with + | [] -> One (elementOfType ty) + | [ienum] -> + // Also, we need to figure out some way to construct it. + let elemTy = + match ienum.GetGenericArguments() with + | [|e|] -> e + | _ -> failwith "Cannot run in bizzare universe where IEnumerable doesn't have one generic arg." + match CollectionConverters.converter ty ienum elemTy with + | None -> One (elementOfType ty) + | Some converter -> Many (elementOfType elemTy, converter) + | multiple -> + failwithf "Type %O has %d IEnumerable implementations. This confuses us." + ty + (List.length multiple) + +and private primitiveShapeOfType (ty : Type) = + PrimitiveConverters.converter ty + |> Option.map (fun converter -> { Output = ty; Converter = converter }) + +and private elementOfType (ty : Type) = + let shape = + match primitiveShapeOfType ty with + | Some p -> Primitive p + | None -> Composite (compositeShapeOfType ty) + { + Shape = shape + Output = ty + } + +and private ofTypeRaw (ty : Type) = + match primitiveShapeOfType ty with + | Some p -> + { + Cardinality = + { + Shape = Primitive p + Output = ty + } |> One + Output = ty + } + | None -> + { + Cardinality = cardinalityOfType ty + Output = ty + } + +and ofType ty = + lock blueprintCache <| fun () -> + let succ, existing = blueprintCache.TryGetValue(ty) + if succ then existing else + let blueprint = ofTypeRaw ty + blueprintCache.[ty] <- blueprint + blueprint diff --git a/Rezoom.SQL.Mapping/CILHelpers.fs b/Rezoom.SQL.Mapping/CILHelpers.fs new file mode 100644 index 0000000..86efce9 --- /dev/null +++ b/Rezoom.SQL.Mapping/CILHelpers.fs @@ -0,0 +1,21 @@ +[] +module private Rezoom.SQL.Mapping.CodeGeneration.CILHelpers +open LicenseToCIL +open LicenseToCIL.Stack +open LicenseToCIL.Ops + +let generalize (op : Op) : Op<'x S, 'x S> = + cil { + yield pretend + yield op + yield pretend + } + +let generalize2 (op : Op) : Op<'x S S, 'x S> = + cil { + yield pretend + yield op + yield pretend + } + + diff --git a/Rezoom.SQL.Mapping/CollectionConverters.fs b/Rezoom.SQL.Mapping/CollectionConverters.fs new file mode 100644 index 0000000..ed9a9a7 --- /dev/null +++ b/Rezoom.SQL.Mapping/CollectionConverters.fs @@ -0,0 +1,77 @@ +module Rezoom.SQL.Mapping.CodeGeneration.CollectionConverters +open Rezoom.SQL.Mapping +open LicenseToCIL +open System +open System.Collections.Generic + +let private interfaceRepresentatives : IDictionary Type> = + [| + [| + typedefof> + typedefof> + typedefof> + |], fun (elementTy : Type) -> elementTy.MakeArrayType() + [| + typedefof> + typedefof> + |], fun (elementTy : Type) -> typedefof<_ ResizeArray>.MakeGenericType(elementTy) + |] + |> Seq.collect + (fun (ifaces, representative) -> ifaces |> Seq.map (fun i -> i, representative)) + |> dict + +let representativeForInterface (ty : Type) = + if not ty.IsInterface then ty else + if not ty.IsConstructedGenericType then ty else + let def = ty.GetGenericTypeDefinition() + let args = ty.GetGenericArguments() + match args with + | [|elemTy|] -> + let succ, repr = interfaceRepresentatives.TryGetValue(def) + if not succ then ty + else repr elemTy + | _ -> ty + +type Converters<'elem> = + static member ToArray(collection : 'elem EntityReader ICollection) = + let arr = Array.zeroCreate collection.Count + let mutable i = 0 + for reader in collection do + arr.[i] <- reader.ToEntity() + i <- i + 1 + arr + static member ToResizeArray(collection : 'elem EntityReader ICollection) = + let resizeArr = new ResizeArray<'elem>(collection.Count) + for reader in collection do + resizeArr.Add(reader.ToEntity()) + resizeArr + static member ToList(collection : 'elem EntityReader ICollection) = + collection |> Seq.map (fun r -> r.ToEntity()) |> List.ofSeq + static member ToOption(collection : 'elem EntityReader ICollection) = + if collection.Count > 1 then + failwithf + "Multiple %ss found in results where a single optional %s was expected" + typeof<'elem>.Name + typeof<'elem>.Name + elif collection.Count <= 0 then + None + else + let reader = collection |> Seq.head + Some <| reader.ToEntity() + +let converter (ty : Type) (ienum : Type) (elem : Type) : ConversionMethod option = + let converter = typedefof>.MakeGenericType(elem) + let specializedMethod = + converter.GetMethods() + |> Array.tryFind (fun m -> m.ReturnType = ty) + match specializedMethod with + | Some m -> Some (Ops.call1 m) + | None -> + // fall back to passing the type an IEnumerable + let constructorOfIEnum = ty.GetConstructor([|ienum|]) + if isNull constructorOfIEnum then None else + let toArray = converter.GetMethod("ToArray") + cil { + yield Ops.call1 toArray + yield Ops.newobj1 constructorOfIEnum + } |> Some \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/ColumnMap.fs b/Rezoom.SQL.Mapping/ColumnMap.fs new file mode 100644 index 0000000..97c287f --- /dev/null +++ b/Rezoom.SQL.Mapping/ColumnMap.fs @@ -0,0 +1,101 @@ +namespace Rezoom.SQL.Mapping +open System +open System.Collections.Generic + +type ColumnType = + | Invalid = 0s + | Object = 1s // whatever it is goes through boxing + | String = 2s + | Byte = 3s + | Int16 = 4s + | Int32 = 5s + | Int64 = 6s + | SByte = 7s + | UInt16 = 8s + | UInt32 = 9s + | UInt64 = 10s + | Single = 11s + | Double = 12s + | Decimal = 13s + | DateTime = 14s + +[] +type ColumnInfo = + // must be mutable to be able to access with ldfld from generated code + val mutable public Index : int16 + val mutable public Type : ColumnType + new (index, rowValueType) = { Index = index; Type = rowValueType } + + static member IndexField = typeof.GetField("Index") + static member TypeField = typeof.GetField("Type") + member this.CLRType = + match this.Type with + | ColumnType.Invalid -> typeof + | ColumnType.Object -> typeof + | ColumnType.String -> typeof + | ColumnType.Byte -> typeof + | ColumnType.Int16 -> typeof + | ColumnType.Int32 -> typeof + | ColumnType.Int64 -> typeof + | ColumnType.SByte -> typeof + | ColumnType.UInt16 -> typeof + | ColumnType.UInt32 -> typeof + | ColumnType.UInt64 -> typeof + | ColumnType.Single -> typeof + | ColumnType.Double -> typeof + | ColumnType.Decimal -> typeof + | ColumnType.DateTime -> typeof + | _ -> invalidArg "type" "Unknown column type" + +[] +type ColumnMap(columns, subMaps) = + static let columnMethod = typeof.GetMethod("Column") + static let primaryColumnMethod = typeof.GetMethod("PrimaryColumn") + static let subMapMethod = typeof.GetMethod("SubMap") + new() = + let columns = Dictionary(StringComparer.OrdinalIgnoreCase) + let subMaps = Dictionary(StringComparer.OrdinalIgnoreCase) + ColumnMap(columns, subMaps) + member private this.GetOrCreateSubMap(name) = + let succ, sub = subMaps.TryGetValue(name) + if succ then sub else + let sub = ColumnMap() + subMaps.[name] <- sub + sub + member private this.SetColumn(name, info) = + columns.[name] <- info + member private this.Load(columnNames : (string * ColumnType) array) = + for i = 0 to columnNames.Length - 1 do + let mutable current = this + let name, rowValueType = columnNames.[i] + let path = name.Split('.', '$') + if path.Length > 1 then + current <- this + for j = 0 to path.Length - 2 do + current <- current.GetOrCreateSubMap(path.[j]) + current.SetColumn(Array.last path, ColumnInfo(int16 i, rowValueType)) + member this.Column(name) = + let succ, info = columns.TryGetValue(name) + if succ then info else ColumnInfo(-1s, ColumnType.Invalid) + member this.PrimaryColumn() = + columns.Values |> Seq.head + member this.SubMap(name) = + let succ, map = subMaps.TryGetValue(name) + if succ then map + else + let succ, info = columns.TryGetValue(name) + if succ then + let cols = Dictionary() + cols.[name] <- info + ColumnMap(cols, Dictionary()) + else null + member this.SubMaps = subMaps :> _ seq + member this.Columns = columns :> _ seq + static member Parse(columnNames) = + let map = ColumnMap() + map.Load(columnNames) + map + + static member internal PrimaryColumnMethod = primaryColumnMethod + static member internal ColumnMethod = columnMethod + static member internal SubMapMethod = subMapMethod diff --git a/Rezoom.SQL.Mapping/Command.fs b/Rezoom.SQL.Mapping/Command.fs new file mode 100644 index 0000000..0824002 --- /dev/null +++ b/Rezoom.SQL.Mapping/Command.fs @@ -0,0 +1,222 @@ +namespace Rezoom.SQL.Mapping +open System +open System.Data +open System.Collections.Generic +open Rezoom +open Rezoom.SQL.Mapping.CodeGeneration + +type CommandFragment = + /// A name which should be localized to this command for batching. + /// For example, if the command creates a temp table, the real name should be chosen dynamically + /// so it doesn't break when the command is batched with others that create the same-named temp table. + | LocalName of string + /// Chunk of raw SQL text. + | CommandText of string + /// References parameter by index. + | Parameter of int + /// At least one unit of whitespace. + | Whitespace + /// Converts a sequence of fragments *without parameters* to a string. + static member Stringize(fragments : CommandFragment seq) = + seq { + for fragment in fragments do + match fragment with + | LocalName name -> yield name + | CommandText text -> yield text + | Whitespace -> yield " " + | Parameter i -> yield ("@P" + string i) + } |> String.concat "" + +[] +type ResultSetProcessor() = + /// Start processing a result set. + abstract member BeginResultSet : IDataReader -> unit + /// Process a single row of the result set. + abstract member ProcessRow : unit -> unit + /// Obtain the result object after processing *all* result sets. + abstract member ObjectGetResult : unit -> obj + +[] +type ResultSetProcessor<'output>() = + inherit ResultSetProcessor() + abstract member GetResult : unit -> 'output + override this.ObjectGetResult() = this.GetResult() |> box + +type CommandData = + { ConnectionName : string + Identity : string + Fragments : CommandFragment IReadOnlyList + DependencyMask : BitMask + InvalidationMask : BitMask + Cacheable : bool + ResultSetCount : int option + } + +type CommandCategory = CommandCategory of connectionName : string + +[] +type Command(data : CommandData, parameters : (obj * DbType) IReadOnlyList) = + let category = CommandCategory data.ConnectionName + let cacheInfo = + { new CacheInfo() with + override __.Category = upcast category + override __.Identity = upcast data.Identity + override __.DependencyMask = data.DependencyMask + override __.InvalidationMask = data.InvalidationMask + override __.Cacheable = data.Cacheable + } + member __.ConnectionName = data.ConnectionName + member __.CacheInfo = cacheInfo + member __.Fragments = data.Fragments + member __.Parameters = parameters + /// The number of result sets this command will return, if it can be statically determined. + member __.ResultSetCount = data.ResultSetCount + + abstract member ObjectResultSetProcessor : unit -> ResultSetProcessor + +/// Represents multiple result sets as the output from a single command. +[] +type ResultSets() = + abstract member AllResultSets : obj seq + +type ResultSets<'a, 'b>(a : 'a, b : 'b) = + inherit ResultSets() + member __.ResultSet1 = a + member __.ResultSet2 = b + override __.AllResultSets = + Seq.ofArray [| box a; box b |] + +type ResultSets<'a, 'b, 'c>(a : 'a, b : 'b, c : 'c) = + inherit ResultSets() + member __.ResultSet1 = a + member __.ResultSet2 = b + member __.ResultSet3 = c + override __.AllResultSets = + Seq.ofArray [| box a; box b; box c |] + +type ResultSets<'a, 'b, 'c, 'd>(a : 'a, b : 'b, c : 'c, d : 'd) = + inherit ResultSets() + member __.ResultSet1 = a + member __.ResultSet2 = b + member __.ResultSet3 = c + member __.ResultSet4 = d + override __.AllResultSets = + Seq.ofArray [| box a; box b; box c; box d |] + +/// A command which can be expected to produce `'output` when run. +[] +type Command<'output>(data, parameters) = + inherit Command(data, parameters) + abstract member WithConnectionName : connectionName : string -> Command<'output> + abstract member ResultSetProcessor : unit -> ResultSetProcessor<'output> + override this.ObjectResultSetProcessor() = upcast this.ResultSetProcessor() + +type private ResultSetProcessor0<'a>() = + inherit ResultSetProcessor<'a>() + override __.BeginResultSet(_) = () + override __.ProcessRow() = () + override __.ObjectGetResult() = upcast Unchecked.defaultof<'a> + override __.GetResult() = Unchecked.defaultof<'a> + +type private Command0(data, parameters) = + inherit Command(data, parameters) + override __.WithConnectionName(connectionName) = + upcast Command0({ data with ConnectionName = connectionName}, parameters) + override __.ResultSetProcessor() = upcast ResultSetProcessor0() + +type private ResultSetProcessor1<'a>() = + inherit ResultSetProcessor<'a>() + let reader = ReaderTemplate<'a>.Template().CreateReader() + let mutable row = Unchecked.defaultof + let result = lazy reader.ToEntity() + override __.BeginResultSet(dataReader) = + reader.ProcessColumns(DataReader.columnMap(dataReader)) + row <- DataReader.DataReaderRow(dataReader) + override __.ProcessRow() = + reader.Read(row) + override __.GetResult() = result.Value + +type private Command1<'a>(data, parameters) = + inherit Command<'a>(data, parameters) + override __.WithConnectionName(connectionName) = + upcast Command1({ data with ConnectionName = connectionName}, parameters) + override __.ResultSetProcessor() = upcast ResultSetProcessor1<'a>() + +type private MultiResultSetProcessor(readers : EntityReader list) = + let mutable row = Unchecked.defaultof + let mutable readers = readers + let mutable first = true + member __.BeginResultSet(dataReader : IDataReader) = + if not first then + readers <- List.tail readers + else + first <- false + (List.head readers).ProcessColumns(DataReader.columnMap(dataReader)) + row <- DataReader.DataReaderRow(dataReader) + member __.ProcessRow() = + (List.head readers).Read(row) + +type private ResultSetProcessor2<'a, 'b>() = + inherit ResultSetProcessor>() + let aReader = ReaderTemplate<'a>.Template().CreateReader() + let bReader = ReaderTemplate<'b>.Template().CreateReader() + let proc = MultiResultSetProcessor([ aReader; bReader ]) + let result = lazy ResultSets<'a, 'b>(aReader.ToEntity(), bReader.ToEntity()) + override __.BeginResultSet(dataReader) = proc.BeginResultSet(dataReader) + override __.ProcessRow() = proc.ProcessRow() + override __.GetResult() = result.Value + +type private Command2<'a, 'b>(data, parameters) = + inherit Command>(data, parameters) + override __.WithConnectionName(connectionName) = + upcast Command2({ data with ConnectionName = connectionName}, parameters) + override __.ResultSetProcessor() = upcast ResultSetProcessor2<'a, 'b>() + +type private ResultSetProcessor3<'a, 'b, 'c>() = + inherit ResultSetProcessor>() + let aReader = ReaderTemplate<'a>.Template().CreateReader() + let bReader = ReaderTemplate<'b>.Template().CreateReader() + let cReader = ReaderTemplate<'c>.Template().CreateReader() + let proc = MultiResultSetProcessor([ aReader; bReader; cReader ]) + let result = lazy ResultSets<'a, 'b, 'c>(aReader.ToEntity(), bReader.ToEntity(), cReader.ToEntity()) + override __.BeginResultSet(dataReader) = proc.BeginResultSet(dataReader) + override __.ProcessRow() = proc.ProcessRow() + override __.GetResult() = result.Value + +type private Command3<'a, 'b, 'c>(data, parameters) = + inherit Command>(data, parameters) + override __.WithConnectionName(connectionName) = + upcast Command3({ data with ConnectionName = connectionName}, parameters) + override __.ResultSetProcessor() = upcast ResultSetProcessor3<'a, 'b, 'c>() + +type private ResultSetProcessor4<'a, 'b, 'c, 'd>() = + inherit ResultSetProcessor>() + let aReader = ReaderTemplate<'a>.Template().CreateReader() + let bReader = ReaderTemplate<'b>.Template().CreateReader() + let cReader = ReaderTemplate<'c>.Template().CreateReader() + let dReader = ReaderTemplate<'d>.Template().CreateReader() + let proc = MultiResultSetProcessor([ aReader; bReader; cReader; dReader ]) + let result = + lazy ResultSets<'a, 'b, 'c, 'd> + (aReader.ToEntity(), bReader.ToEntity(), cReader.ToEntity(), dReader.ToEntity()) + override __.BeginResultSet(dataReader) = proc.BeginResultSet(dataReader) + override __.ProcessRow() = proc.ProcessRow() + override __.GetResult() = result.Value + +type private Command4<'a, 'b, 'c, 'd>(data, parameters) = + inherit Command>(data, parameters) + override __.WithConnectionName(connectionName) = + upcast Command4({ data with ConnectionName = connectionName}, parameters) + override __.ResultSetProcessor() = upcast ResultSetProcessor4<'a, 'b, 'c, 'd>() + +type CommandConstructor() = + static member Command0(data, parameters) = + Command0(data, parameters) :> _ Command + static member Command1<'a>(data, parameters) = + Command1<'a>(data, parameters) :> _ Command + static member Command2<'a, 'b>(data, parameters) = + Command2<'a, 'b>(data, parameters) :> _ Command + static member Command3<'a, 'b, 'c>(data, parameters) = + Command3<'a, 'b, 'c>(data, parameters) :> _ Command + static member Command4<'a, 'b, 'c, 'd>(data, parameters) = + Command4<'a, 'b, 'c, 'd>(data, parameters) :> _ Command \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/CommandBatch.fs b/Rezoom.SQL.Mapping/CommandBatch.fs new file mode 100644 index 0000000..b3617dd --- /dev/null +++ b/Rezoom.SQL.Mapping/CommandBatch.fs @@ -0,0 +1,164 @@ +namespace Rezoom.SQL.Mapping +open System +open System.Data +open System.Data.Common +open System.Collections.Generic +open System.Text +open System.Threading +open System.Threading.Tasks +open FSharp.Control.Tasks.ContextInsensitive + +type private CommandBatchBuilder(conn : DbConnection) = + let maxParameters = + match conn.GetType().Namespace with + | "System.Data.SqlClient" -> 2100 // SQL server + | "System.Data.OracleClient" + | "Oracle.DataAccess.Client" -> 2000 // Oracle + | "Npgsql" -> 10000 // Postgres -- can support more but it's probably a bad idea + | "MySql.Data.MySqlClient" -> 10000 // MySQL -- can support more but it's probably a bad idea + | _ -> 999 // SQLite default, and probably the lowest of any DB + static let terminatorColumn i = "RZSQL_TERMINATOR_" + string i + static let terminator i = ";--'*/;SELECT NULL AS " + terminatorColumn i + static let parameterName i = "@RZSQL_" + string i + static let parameterNameArray i j = "@RZSQL_" + string i + "_" + string j + static let localName i name = "RZSQL_" + name + "_" + string i + let commands = ResizeArray() + let mutable parameterCount = 0 + let mutable evaluating = false + + let addCommand (builder : StringBuilder) (dbCommand : DbCommand) (commandIndex : int) (command : Command) = + let parameterOffset = dbCommand.Parameters.Count + let addParam name dbType value = + let dbParam = dbCommand.CreateParameter() + dbParam.ParameterName <- name + dbParam.DbType <- dbType + dbParam.Value <- value + ignore <| dbCommand.Parameters.Add(dbParam) + for i, (parameterValue, parameterType) in command.Parameters |> Seq.indexed do + match parameterValue with + | :? Array as arr -> + let mutable j = 0 + for elem in arr do + addParam (parameterNameArray (parameterOffset + i) j) parameterType elem + j <- j + 1 + | _ -> + addParam (parameterName (parameterOffset + i)) parameterType parameterValue + for fragment in command.Fragments do + let fragmentString = + match fragment with + | LocalName name -> localName commandIndex name + | CommandText str -> str + | Parameter i -> + match command.Parameters.[i] |> fst with + | :? Array as arr -> + let parNames = + seq { + for j = 0 to arr.Length - 1 do yield parameterNameArray (parameterOffset + i) j + } + "(" + String.concat "," parNames + ")" + | _ -> parameterName (parameterOffset + i) + | Whitespace -> " " + ignore <| builder.Append(fragmentString) + match command.ResultSetCount with + | Some _ -> () // no need to add terminator statement + | None when commandIndex + 1 >= commands.Count -> () + | None -> + builder.Append(terminator commandIndex) |> ignore + let buildCommand (dbCommand : DbCommand) = + let builder = StringBuilder() + for commandIndex, command in commands |> Seq.indexed do + addCommand builder dbCommand commandIndex command + dbCommand.CommandText <- builder.ToString() + + member __.BatchCommand(cmd : Command) = + let mutable count = 0 + for par in cmd.Parameters do + match fst par with + | :? Array as arr -> count <- count + arr.Length + | _ -> count <- count + 1 + if parameterCount + count > maxParameters then + Nullable() + else + let index = commands.Count + commands.Add(cmd) + parameterCount <- parameterCount + count + Nullable(index) + + member __.Evaluate() = + if evaluating then failwith "Already evaluating command" + else evaluating <- true + task { + use dbCommand = conn.CreateCommand() + buildCommand dbCommand + use! reader = dbCommand.ExecuteReaderAsync() + let reader = reader : DbDataReader + let processed = ResizeArray() + for i = 0 to commands.Count - 1 do + let cmd = commands.[i] + let processor = cmd.ObjectResultSetProcessor() + let mutable resultSetCount = match cmd.ResultSetCount with | Some 0 -> -1 | _ -> 0 + while resultSetCount >= 0 do + processor.BeginResultSet(reader) + let mutable hasRows = true + while hasRows do + let! hasRow = reader.ReadAsync() + if hasRow then + processor.ProcessRow() + else + hasRows <- false + resultSetCount <- resultSetCount + 1 + let! hasNextResult = reader.NextResultAsync() + match cmd.ResultSetCount with + | None -> // check for terminator + if not hasNextResult || reader.FieldCount = 1 && reader.GetName(0) = terminatorColumn i then + resultSetCount <- -1 + else + let! hasNextResult = reader.NextResultAsync() + if not hasNextResult then + resultSetCount <- -1 + | Some count -> + if resultSetCount = count then + resultSetCount <- -1 + elif not hasNextResult then + failwithf + "Command claimed it would produce %d result sets, but only yielded %d" + count resultSetCount + processed.Add(processor.ObjectGetResult()) + return processed + } + +type CommandBatch(conn : DbConnection) = + let builders = ResizeArray() + let evaluation = + lazy + task { + let arr = Array.zeroCreate builders.Count + for i = 0 to builders.Count - 1 do + let! resultSets = builders.[i].Evaluate() + arr.[i] <- resultSets + return arr + } + do + builders.Add(CommandBatchBuilder(conn)) + member __.Batch(cmd : #Command<'a>) = + let inline retrieveResult builderIndex resultsIndex = + fun (token : CancellationToken) -> + evaluation.Value.ContinueWith + ( (fun (t : _ ResizeArray array Task) -> + t.Result.[builderIndex].[resultsIndex] |> Unchecked.unbox : 'a) + , TaskContinuationOptions.ExecuteSynchronously + ) + let builderIndex = builders.Count - 1 + let resultsIndex = builders.[builderIndex].BatchCommand(cmd) + if resultsIndex.HasValue then + retrieveResult builderIndex resultsIndex.Value + else + let next = CommandBatchBuilder(conn) + let builderIndex = builderIndex + 1 + let resultsIndex = next.BatchCommand(cmd) + builders.Add(next) + if resultsIndex.HasValue then + retrieveResult builderIndex resultsIndex.Value + else + failwith "Command has too many parameters to run" + diff --git a/Rezoom.SQL.Mapping/CompositeColumnGenerator.fs b/Rezoom.SQL.Mapping/CompositeColumnGenerator.fs new file mode 100644 index 0000000..d565f61 --- /dev/null +++ b/Rezoom.SQL.Mapping/CompositeColumnGenerator.fs @@ -0,0 +1,115 @@ +namespace Rezoom.SQL.Mapping.CodeGeneration +open Rezoom.SQL.Mapping +open LicenseToCIL +open LicenseToCIL.Stack +open LicenseToCIL.Ops +open System +open System.Collections.Generic +open System.Reflection +open System.Reflection.Emit + +type private CompositeColumnGenerator(builder, column, composite : Composite) = + inherit EntityReaderColumnGenerator(builder) + let output = column.Blueprint.Value.Output + let staticTemplate = Generation.readerTemplateGeneric.MakeGenericType(output) + let entTemplate = typedefof<_ EntityReaderTemplate>.MakeGenericType(output) + let entReaderType = typedefof<_ EntityReader>.MakeGenericType(output) + let requiresSelf = composite.ReferencesQueryParent + let mutable entReader = null + override __.DefineConstructor() = + entReader <- builder.DefineField("_c_r_" + column.Name, entReaderType, FieldAttributes.Private) + zero + override __.DefineProcessColumns() = + cil { + let! ncase = deflabel // if submap is null + let! sub = tmplocal typeof + yield ldarg 1 // column map + yield ldstr column.Name + yield call2 ColumnMap.SubMapMethod + yield dup + yield stloc sub + yield brfalse's ncase + yield cil { + yield dup + yield call0 (staticTemplate.GetMethod("Template")) + yield callvirt1 (entTemplate.GetMethod("CreateReader")) + yield dup + yield ldloc sub + yield callvirt2'void Generation.processColumnsMethod + yield stfld entReader + } + yield mark ncase + } + override __.DefineImpartKnowledgeToNext() = + cil { + let! ncase = deflabel + yield ldarg 0 + yield ldfld entReader + yield brfalse's ncase + yield cil { + let! newReader = tmplocal entReaderType + yield ldarg 1 + yield castclass builder + yield ldarg 0 + yield ldfld entReader + yield call0 (staticTemplate.GetMethod("Template")) + yield callvirt1 (entTemplate.GetMethod("CreateReader")) + yield dup + yield stloc newReader + yield callvirt2'void (entReaderType.GetMethod("ImpartKnowledgeToNext")) + yield ldloc newReader + yield stfld entReader + } + yield mark ncase + } + override __.DefineRead(skipOnes) = + cil { + let! ncase = deflabel + yield dup + yield ldfld entReader + yield brfalse's ncase + yield cil { + yield dup + yield ldfld entReader + yield ldarg 1 + yield callvirt2'void Generation.readMethod + } + yield mark ncase + } + override __.DefineSetReverse() = + if column.ReverseRelationship.Value |> Option.isNone then zero else + cil { + let! skip = deflabel + yield ldarg 1 + yield ldc'i4 column.ColumnId + yield bne'un's skip + yield cil { + yield dup + yield ldarg 2 + yield castclass composite.Output + yield newobj1 (typedefof<_ ObjectEntityReader>.MakeGenericType(output).GetConstructor([|output|])) + yield stfld entReader + } + yield mark skip + } + override __.RequiresSelfReferenceToPush = requiresSelf + override __.DefinePush(self) = + cil { + let! ncase = deflabel + yield ldarg 0 + yield ldfld entReader + yield dup + yield brfalse's ncase + yield cil { + match column.ReverseRelationship.Value with + | None -> () + | Some rev -> + yield dup + yield ldc'i4 rev.ColumnId + yield ldloc self + if output.IsValueType then yield box'val output + yield callvirt3'void Generation.setReverseMethod + yield callvirt1 (entReaderType.GetMethod("ToEntity")) + } + yield mark ncase + } diff --git a/Rezoom.SQL.Mapping/Converters.fs b/Rezoom.SQL.Mapping/Converters.fs new file mode 100644 index 0000000..b6859cb --- /dev/null +++ b/Rezoom.SQL.Mapping/Converters.fs @@ -0,0 +1,10 @@ +namespace Rezoom.SQL.Mapping.CodeGeneration +open LicenseToCIL +open LicenseToCIL.Stack + +/// A conversion that assumes an obj is on the stack, and pushes a value of whatever type is being +/// converted to (depends on the context in which you see the conversion). +type ConversionMethod = Op + +/// Takes `Row` and `ColumnInfo` and pushes a value of whatever type if being converted to. +type RowConversionMethod = Op \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/DataReader.fs b/Rezoom.SQL.Mapping/DataReader.fs new file mode 100644 index 0000000..cbdc35e --- /dev/null +++ b/Rezoom.SQL.Mapping/DataReader.fs @@ -0,0 +1,49 @@ +module Rezoom.SQL.Mapping.DataReader +open System +open System.Data + +let private columnTypes = + [| + typeof, ColumnType.String + typeof, ColumnType.Byte + typeof, ColumnType.Int16 + typeof, ColumnType.Int32 + typeof, ColumnType.Int64 + typeof, ColumnType.SByte + typeof, ColumnType.UInt16 + typeof, ColumnType.UInt32 + typeof, ColumnType.UInt64 + typeof, ColumnType.Single + typeof, ColumnType.Double + typeof, ColumnType.Decimal + typeof, ColumnType.DateTime + |] |> dict + +let columnType (ty : Type) = + let succ, colTy = columnTypes.TryGetValue(ty) + if succ then colTy else ColumnType.Object + +let columnMap (reader : IDataReader) = + let cols = Array.zeroCreate reader.FieldCount + for i = 0 to reader.FieldCount - 1 do + cols.[i] <- reader.GetName(i), columnType (reader.GetFieldType(i)) + ColumnMap.Parse(cols) + +type DataReaderRow(reader : IDataReader) = + inherit Row() + override __.IsNull(i) = reader.IsDBNull(int i) + override __.GetObject(i) = reader.GetValue(int i) + override __.GetString(i) = reader.GetString(int i) + override __.GetByte(i) = reader.GetByte(int i) + override __.GetInt16(i) = reader.GetInt16(int i) + override __.GetInt32(i) = reader.GetInt32(int i) + override __.GetInt64(i) = reader.GetInt64(int i) + override __.GetSByte(i) = reader.GetValue(int i) |> Convert.ToSByte + override __.GetUInt16(i) = reader.GetValue(int i) |> Convert.ToUInt16 + override __.GetUInt32(i) = reader.GetValue(int i) |> Convert.ToUInt32 + override __.GetUInt64(i) = reader.GetValue(int i) |> Convert.ToUInt64 + override __.GetSingle(i) = reader.GetFloat(int i) + override __.GetDouble(i) = reader.GetDouble(int i) + override __.GetDecimal(i) = reader.GetDecimal(int i) + override __.GetDateTime(i) = reader.GetDateTime(int i) + diff --git a/Rezoom.SQL.Mapping/EntityReader.fs b/Rezoom.SQL.Mapping/EntityReader.fs new file mode 100644 index 0000000..f4df10b --- /dev/null +++ b/Rezoom.SQL.Mapping/EntityReader.fs @@ -0,0 +1,27 @@ +namespace Rezoom.SQL.Mapping + +type ColumnId = int + +[] +type EntityReader() = + abstract member ProcessColumns : ColumnMap -> unit + abstract member Read : Row -> unit + abstract member SetReverse : ColumnId * obj -> unit + +[] +type EntityReader<'ent>() = + inherit EntityReader() + abstract member ImpartKnowledgeToNext : EntityReader<'ent> -> unit + abstract member ToEntity : unit -> 'ent + +type ObjectEntityReader<'ent>(ent : 'ent) = + inherit EntityReader<'ent>() + override __.ImpartKnowledgeToNext(_) = () + override __.ProcessColumns(_) = () + override __.Read(_) = () + override __.SetReverse(_, _) = () + override __.ToEntity() = ent + +[] +type EntityReaderTemplate<'ent>() = + abstract member CreateReader : unit -> 'ent EntityReader \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/EntityReaderColumnGenerator.fs b/Rezoom.SQL.Mapping/EntityReaderColumnGenerator.fs new file mode 100644 index 0000000..284ac2d --- /dev/null +++ b/Rezoom.SQL.Mapping/EntityReaderColumnGenerator.fs @@ -0,0 +1,37 @@ +namespace Rezoom.SQL.Mapping.CodeGeneration +open Rezoom.SQL.Mapping +open LicenseToCIL +open LicenseToCIL.Stack +open LicenseToCIL.Ops +open System +open System.Collections.Generic +open System.Reflection +open System.Reflection.Emit + +type 'x THIS = 'x S +type 'x ENT = 'x S + +[] +type private EntityReaderColumnGenerator(builder : TypeBuilder) = + abstract member DefineConstructor : unit -> Op + abstract member DefineProcessColumns : unit -> Op + abstract member DefineImpartKnowledgeToNext : unit -> Op + abstract member DefineRead : skipOnes : Label -> Op + abstract member DefineSetReverse : unit -> Op + default __.DefineSetReverse() = zero + abstract member RequiresSelfReferenceToPush : bool + default __.RequiresSelfReferenceToPush = false + abstract member DefinePush : selfReference : Local -> Op<'x, 'x S> + +module private Generation = + // We'll need to reference this type in various column generator implementations, + // but don't want to use typedefof<_> and introduce explicit mutual recursion because + // that would require that we put all the implementations in one file. D: + let readerTemplateGeneric = + Assembly.GetExecutingAssembly().GetType("Rezoom.SQL.Mapping.CodeGeneration.ReaderTemplate`1") + let processColumnsMethod = + typeof.GetMethod("ProcessColumns") + let readMethod = + typeof.GetMethod("Read") + let setReverseMethod = + typeof.GetMethod("SetReverse") \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/Integration.fs b/Rezoom.SQL.Mapping/Integration.fs new file mode 100644 index 0000000..ba7f475 --- /dev/null +++ b/Rezoom.SQL.Mapping/Integration.fs @@ -0,0 +1,119 @@ +[] +module Rezoom.SQL.Mapping.Integration +open System +open System.Configuration +open System.Collections.Generic +open System.Data +open System.Data.Common +open System.Threading +open Rezoom + +[] +type ConnectionProvider() = + abstract member Open : name : string -> DbConnection + abstract member BeginTransaction : DbConnection -> DbTransaction + default __.BeginTransaction(conn) = conn.BeginTransaction() + +type DefaultConnectionProvider() = + inherit ConnectionProvider() + override __.Open(name) = + let connectionStrings = ConfigurationManager.ConnectionStrings + if isNull connectionStrings then + failwith "No element in config" + let connectionString = connectionStrings.[name] + if isNull connectionString then + failwith "No connection string by the expected name" + let provider = DbProviderFactories.GetFactory(connectionString.ProviderName) + let conn = provider.CreateConnection() + conn.ConnectionString <- connectionString.ConnectionString + conn.Open() + conn + +type private ExecutionLocalConnections(provider : ConnectionProvider) = + let connections = Dictionary() + member __.GetConnection(name) = + let succ, tuple = connections.TryGetValue(name) + if succ then fst tuple else + let conn = provider.Open(name) + let tran = provider.BeginTransaction(conn) + connections.Add(name, (conn, tran)) + conn + member __.Dispose(state) = + let mutable exn = null + for conn, tran in connections.Values do + try + match state with + | ExecutionSuccess -> tran.Commit() + | ExecutionFault -> // don't explicitly rollback, tran.Dispose() should handle it + try + tran.Dispose() + finally + conn.Dispose() + with + | e -> + if isNull exn then exn <- e + else exn <- AggregateException(exn, e) + connections.Clear() + if not (isNull exn) then raise exn + // don't implement IDisposable because we need exec. state to know how to end transactions + +type private ExecutionLocalConnectionsFactory() = + inherit ServiceFactory() + override __.ServiceLifetime = ServiceLifetime.ExecutionLocal + override __.CreateService(cxt) = + let provider = + match cxt.Configuration.TryGetConfig() with + | None -> DefaultConnectionProvider() :> ConnectionProvider + | Some provider -> provider + ExecutionLocalConnections(provider) + override __.DisposeService(state, svc) = svc.Dispose(state) + +type private StepLocalBatches(conns : ExecutionLocalConnections) = + let batches = Dictionary() + member __.GetBatch(name) = + let succ, batch = batches.TryGetValue(name) + if succ then batch else + let conn = conns.GetConnection(name) + let batch = CommandBatch(conn) + batches.Add(name, batch) + batch + +type private StepLocalBatchesFactory() = + inherit ServiceFactory() + override __.ServiceLifetime = ServiceLifetime.StepLocal + override __.CreateService(cxt) = StepLocalBatches(cxt.GetService()) + override __.DisposeService(_, _) = () + +type private CommandErrandArgument(parameters : (obj * DbType) IReadOnlyList) = + member __.Parameters = parameters + member __.Equals(other : CommandErrandArgument) = + Seq.forall2 (=) parameters other.Parameters + override __.GetHashCode() = + let mutable h = 0 + for par, _ in parameters do + h <- ((h <<< 5) + h) ^^^ par.GetHashCode() + h + override this.Equals(other : obj) = + match other with + | :? CommandErrandArgument as other -> this.Equals(other) + | _ -> false + +type CommandErrand<'a>(command : Command<'a>) = + inherit AsynchronousErrand<'a>() + let cacheArgument = CommandErrandArgument(command.Parameters) + override __.CacheInfo = command.CacheInfo + override __.CacheArgument = box cacheArgument + override __.SequenceGroup = null + override __.Prepare(cxt) = + let batches = cxt.GetService() + batches.GetBatch(command.ConnectionName).Batch(command) + override __.ToString() = + let all = CommandFragment.Stringize(command.Fragments) + let truncate = 80 + if all.Length < truncate then all else all.Substring(0, truncate - 3) + "..." + +type Command<'a> with + member this.ExecutePlan() = + CommandErrand(this) |> Plan.ofErrand + member this.ExecuteAsync(conn : DbConnection) = + CommandBatch(conn).Batch(this)(CancellationToken()) diff --git a/Rezoom.SQL.Mapping/ManyColumnGenerator.fs b/Rezoom.SQL.Mapping/ManyColumnGenerator.fs new file mode 100644 index 0000000..c61bb5d --- /dev/null +++ b/Rezoom.SQL.Mapping/ManyColumnGenerator.fs @@ -0,0 +1,101 @@ +namespace Rezoom.SQL.Mapping.CodeGeneration +open Rezoom.SQL.Mapping +open LicenseToCIL +open LicenseToCIL.Stack +open LicenseToCIL.Ops +open System +open System.Collections.Generic +open System.Reflection +open System.Reflection.Emit + +type private ManyColumnGenerator + ( builder + , column : Column option + , element : ElementBlueprint + , conversion : ConversionMethod + ) = + inherit EntityReaderColumnGenerator(builder) + let elemTy = element.Output + let staticTemplate = Generation.readerTemplateGeneric.MakeGenericType(elemTy) + let entTemplate = typedefof<_ EntityReaderTemplate>.MakeGenericType(elemTy) + let elemReaderTy = typedefof<_ EntityReader>.MakeGenericType(elemTy) + let listTy = typedefof<_ ResizeArray>.MakeGenericType(elemReaderTy) + let mutable entList = null + let mutable refReader = null + override __.DefineConstructor() = + let name = defaultArg (column |> Option.map (fun c -> c.Name)) "self" + entList <- builder.DefineField("_m_l_" + name, listTy, FieldAttributes.Private) + refReader <- builder.DefineField("_m_r_" + name, elemReaderTy, FieldAttributes.Private) + cil { + yield ldarg 0 + yield newobj0 (listTy.GetConstructor(Type.EmptyTypes)) + yield stfld entList + } + override __.DefineProcessColumns() = + cil { + let! skip = deflabel + yield ldarg 1 // col map + match column with + | Some column -> + yield ldstr column.Name + yield call2 ColumnMap.SubMapMethod + | None -> () + let! sub = tmplocal typeof + yield dup + yield stloc sub // col map + yield brfalse's skip + yield cil { + yield dup // this + yield call0 (staticTemplate.GetMethod("Template")) // this, template + yield callvirt1 (entTemplate.GetMethod("CreateReader")) // this, reader + yield dup // this, reader, reader + yield ldloc sub // this, reader, reader, submap + yield callvirt2'void Generation.processColumnsMethod // this, reader + yield stfld refReader // _ + } + yield mark skip + } + override __.DefineImpartKnowledgeToNext() = + cil { + yield ldarg 1 // that + yield ldarg 0 // that, this + yield ldfld refReader // that, oldReader + yield call0 (staticTemplate.GetMethod("Template")) // that, oldReader, template + yield callvirt1 (entTemplate.GetMethod("CreateReader")) // that, oldReader, newReader + let! newReader = deflocal elemReaderTy + yield dup + yield stloc newReader + // that, oldReader, newReader + yield callvirt2'void (elemReaderTy.GetMethod("ImpartKnowledgeToNext")) + // that + yield ldloc newReader + yield stfld refReader + } + override __.DefineRead(_) = + cil { + let! entReader = deflocal elemReaderTy + yield dup + yield ldfld refReader + yield ldarg 0 + yield ldfld entList // refReader, list + yield call0 (staticTemplate.GetMethod("Template")) + yield callvirt1 (entTemplate.GetMethod("CreateReader")) + yield dup // refReader, list, entReader, entReader + yield stloc entReader // refReader, list, entReader + yield call2'void (listTy.GetMethod("Add", [| elemReaderTy |])) + // refReader + yield ldloc entReader // refReader, entReader + yield callvirt2'void (elemReaderTy.GetMethod("ImpartKnowledgeToNext")) + // () + yield ldloc entReader // entReader + yield ldarg 1 // row + yield callvirt2'void Generation.readMethod // entReader.Read(row) + } + override __.RequiresSelfReferenceToPush = false + override __.DefinePush(self) = + cil { + let! ncase = deflabel + yield ldarg 0 + yield ldfld entList + yield generalize conversion + } diff --git a/Rezoom.SQL.Mapping/ManyEntityColumnGenerator.fs b/Rezoom.SQL.Mapping/ManyEntityColumnGenerator.fs new file mode 100644 index 0000000..319b248 --- /dev/null +++ b/Rezoom.SQL.Mapping/ManyEntityColumnGenerator.fs @@ -0,0 +1,321 @@ +namespace Rezoom.SQL.Mapping.CodeGeneration +open Rezoom.SQL.Mapping +open LicenseToCIL +open LicenseToCIL.Stack +open LicenseToCIL.Ops +open System +open System.Collections.Generic +open System.Reflection +open System.Reflection.Emit + +type ManyColumnGeneratorCode<'a> = + // may be called in code generated by ManyColumnGenerator + static member SetReverse(collection : 'a EntityReader ICollection, columnId : ColumnId, parent : obj) = + for reader in collection do + reader.SetReverse(columnId, parent) + +[] +[] +type FastTuple<'a, 'b>(item1 : 'a, item2 : 'b) = + struct + static let equalityA = EqualityComparer<'a>.Default + static let equalityB = EqualityComparer<'b>.Default + static let comparerA = Comparer<'a>.Default + static let comparerB = Comparer<'b>.Default + member __.Item1 = item1 + member __.Item2 = item2 + member this.Equals(other : FastTuple<'a, 'b>) = + equalityA.Equals(item1, other.Item1) + && equalityB.Equals(item2, other.Item2) + member this.CompareTo(other : FastTuple<'a, 'b>) = + let a = comparerA.Compare(item1, other.Item1) + if a <> 0 then a + else comparerB.Compare(item2, other.Item2) + override this.Equals(other : obj) = + match other with + | :? FastTuple<'a, 'b> as other -> this.Equals(other) + | _ -> false + override this.GetHashCode() = + let h1 = equalityA.GetHashCode(item1) + ((h1 <<< 5) + h1) ^^^ equalityB.GetHashCode(item2) + interface IEquatable> with + member this.Equals(other) = this.Equals(other) + interface IComparable> with + member this.CompareTo(other) = this.CompareTo(other) + end + +type private KeyColumns = + { + Type : Type + ColumnInfoFields : TypeBuilder -> string -> obj + ProcessColumns : Local -> obj -> Op // this, this -> this + ImpartToNext : obj -> Op // this, that -> this + Read : Local -> Label -> obj -> Op + } + static member private GetPrimitiveConverter(column : Column) = + match column.Blueprint.Value.Cardinality with + | Many _ -> failwith "Collection types are not supported as keys" + | One { Shape = Primitive prim } -> prim.Converter + | One { Shape = Composite _ } -> + failwith <| + "Composite types are not supported as keys." + + " Consider using KeyAttribute on multiple primitive columns instead." + static member TupleTypeDef(length : int) = + match length with + | 2 -> typedefof> + | 3 -> typedefof<_ * _ * _> + | 4 -> typedefof<_ * _ * _ * _> + | 5 -> typedefof<_ * _ * _ * _ * _> + | 6 -> typedefof<_ * _ * _ * _ * _ * _> + | 7 -> typedefof<_ * _ * _ * _ * _ * _ * _> + | 8 -> typedefof<_ * _ * _ * _ * _ * _ * _ * _> + | 9 -> typedefof<_ * _ * _ * _ * _ * _ * _ * _ * _> + | length -> failwithf "Unsupported length: can't use %d columns as identity" length + static member Of(column : Column) = + { Type = column.Blueprint.Value.Output + ColumnInfoFields = fun builder name -> + builder.DefineField("_m_key_" + name, typeof, FieldAttributes.Private) |> box + ProcessColumns = fun subMap infoFields -> + let infoField = infoFields |> Unchecked.unbox : FieldInfo + cil { + yield ldloc subMap // this, col map + yield ldstr column.Name + yield call2 ColumnMap.ColumnMethod + yield stfld infoField + } + ImpartToNext = fun infoFields -> + let infoField = infoFields |> Unchecked.unbox : FieldInfo + cil { + yield ldarg 0 + yield ldfld infoField + yield stfld infoField + } + Read = fun keyLocal skip infoFields -> + let converter = KeyColumns.GetPrimitiveConverter(column) + let infoField = infoFields |> Unchecked.unbox : FieldInfo + cil { + yield ldarg 1 // row + yield ldarg 0 // row, this + yield ldfld infoField // row, colinfo + yield ldfld (typeof.GetField("Index")) // row, index + yield callvirt2 (typeof.GetMethod("IsNull")) // isnull + yield brtrue skip + yield ldarg 1 // row + yield ldarg 0 // row, this + yield ldfld infoField // row, colinfo + yield generalize2 converter // id + yield stloc keyLocal + } + } + static member Of(columns : Column IReadOnlyList) = + if columns.Count < 1 then failwith "Collections of types without identity are not supported" + if columns.Count = 1 then KeyColumns.Of(columns.[0]) else + let types = [| for column in columns -> column.Output |] + let tupleType = KeyColumns.TupleTypeDef(columns.Count).MakeGenericType(types) + let ctor = tupleType.GetConstructor(types) + { Type = tupleType + ColumnInfoFields = fun builder name -> + [| for column in columns -> + builder.DefineField + ("_m_key_" + name + "_" + column.Name, typeof, FieldAttributes.Private) + |] |> box + ProcessColumns = fun subMap infoFields -> + let infoFields = infoFields |> Unchecked.unbox : FieldInfo array + cil { + for column, infoField in Seq.zip columns infoFields do + yield dup + yield ldloc subMap // this, col map + yield ldstr column.Name + yield call2 ColumnMap.ColumnMethod + yield stfld infoField + yield pop + } + ImpartToNext = fun infoFields -> + let infoFields = infoFields |> Unchecked.unbox : FieldInfo array + cil { + for infoField in infoFields do + yield dup + yield ldarg 0 + yield ldfld infoField + yield stfld infoField + yield pop + } + Read = fun keyLocal skip infoFields -> + let infoFields = infoFields |> Unchecked.unbox : FieldInfo array + cil { + let locals = new ResizeArray<_>() + for column, infoField in Seq.zip columns infoFields do + let! local = deflocal column.Output + locals.Add(local) + let converter = KeyColumns.GetPrimitiveConverter(column) + yield ldarg 1 // row + yield ldarg 0 // row, this + yield ldfld infoField // row, colinfo + yield ldfld (typeof.GetField("Index")) // row, index + yield callvirt2 (typeof.GetMethod("IsNull")) // isnull + yield brtrue skip + yield ldarg 1 // row + yield ldarg 0 // row, this + yield ldfld infoField // row, colinfo + yield generalize2 converter // id + yield stloc local + for local in locals do + yield ldloc local + yield pretend + yield newobj'x ctor + yield stloc keyLocal + } + } + +type private ManyEntityColumnGenerator + ( builder + , column : Column option + , element : ElementBlueprint + , conversion : ConversionMethod + ) = + inherit EntityReaderColumnGenerator(builder) + let composite = + match element.Shape with + | Composite c -> c + | Primitive _ -> failwith "Collections of primitives are not supported" + let keyColumns = KeyColumns.Of(composite.Identity) + let elemTy = element.Output + let staticTemplate = Generation.readerTemplateGeneric.MakeGenericType(elemTy) + let entTemplate = typedefof<_ EntityReaderTemplate>.MakeGenericType(elemTy) + let elemReaderTy = typedefof<_ EntityReader>.MakeGenericType(elemTy) + let dictTy = typedefof>.MakeGenericType(keyColumns.Type, elemReaderTy) + let requiresSelf = composite.ReferencesQueryParent + let mutable entDict = null + let mutable refReader = null + let mutable keyInfo = null + override __.DefineConstructor() = + let name = defaultArg (column |> Option.map (fun c -> c.Name)) "self" + keyInfo <- keyColumns.ColumnInfoFields builder name + entDict <- builder.DefineField("_m_d_" + name, dictTy, FieldAttributes.Private) + refReader <- builder.DefineField("_m_r_" + name, elemReaderTy, FieldAttributes.Private) + cil { + yield ldarg 0 + yield newobj0 (dictTy.GetConstructor(Type.EmptyTypes)) + yield stfld entDict + } + override __.DefineProcessColumns() = + cil { + let! skip = deflabel + yield ldarg 1 // col map + match column with + | Some column -> + yield ldstr column.Name + yield call2 ColumnMap.SubMapMethod + | None -> () + let! sub = tmplocal typeof + yield dup + yield stloc sub // col map + yield brfalse's skip + yield dup // this + yield keyColumns.ProcessColumns sub keyInfo + yield cil { + yield dup // this + yield call0 (staticTemplate.GetMethod("Template")) // this, template + yield callvirt1 (entTemplate.GetMethod("CreateReader")) // this, reader + yield dup // this, reader, reader + yield ldloc sub // this, reader, reader, submap + yield callvirt2'void Generation.processColumnsMethod // this, reader + yield stfld refReader // _ + } + yield mark skip + } + override __.DefineImpartKnowledgeToNext() = + cil { + yield ldarg 1 + yield castclass builder + yield keyColumns.ImpartToNext keyInfo + + let! nread = deflabel + let! exit = deflabel + yield dup + yield ldfld refReader + yield brfalse's nread + yield cil { + yield ldarg 1 // that + yield ldarg 0 // that, this + yield ldfld refReader // that, oldReader + yield call0 (staticTemplate.GetMethod("Template")) // that, oldReader, template + yield callvirt1 (entTemplate.GetMethod("CreateReader")) // that, oldReader, newReader + let! newReader = deflocal elemReaderTy + yield dup + yield stloc newReader + // that, oldReader, newReader + yield callvirt2'void (elemReaderTy.GetMethod("ImpartKnowledgeToNext")) + // that + yield ldloc newReader + yield stfld refReader + yield br's exit + } + yield mark nread + yield cil { + yield ldarg 1 + yield ldnull + yield stfld refReader + } + yield mark exit + } + override __.DefineRead(_) = + cil { + let! skip = deflabel + yield dup + yield ldfld refReader + yield brfalse skip + yield cil { + let! keyLocal = tmplocal keyColumns.Type + yield keyColumns.Read keyLocal skip keyInfo + + let! entReader = tmplocal elemReaderTy + yield dup + yield ldfld entDict + yield ldloc keyLocal + yield ldloca entReader + yield call3 (dictTy.GetMethod("TryGetValue")) + let! readRow = deflabel + yield brtrue's readRow + + yield dup + yield ldfld entDict + yield ldloc keyLocal + yield call0 (staticTemplate.GetMethod("Template")) + yield callvirt1 (entTemplate.GetMethod("CreateReader")) + yield dup + yield stloc entReader + yield call3'void (dictTy.GetMethod("Add", [| keyColumns.Type; elemReaderTy |])) + yield dup + yield ldfld refReader + yield ldloc entReader + yield callvirt2'void (elemReaderTy.GetMethod("ImpartKnowledgeToNext")) + + yield mark readRow + yield ldloc entReader + yield ldarg 1 // row + yield callvirt2'void Generation.readMethod + } + yield mark skip + } + override __.RequiresSelfReferenceToPush = requiresSelf + override __.DefinePush(self) = + cil { + let! ncase = deflabel + yield ldarg 0 + yield ldfld entDict + yield call1 (dictTy.GetProperty("Values").GetGetMethod()) + match column with + | None -> () + | Some col -> + match col.ReverseRelationship.Value with + | None -> () + | Some rev -> + let setReverse = + typedefof<_ ManyColumnGeneratorCode>.MakeGenericType(elemTy).GetMethod("SetReverse") + yield dup + yield ldc'i4 rev.ColumnId + yield ldloc self + yield call3'void setReverse + yield generalize conversion + } diff --git a/Rezoom.SQL.Mapping/Migrations.fs b/Rezoom.SQL.Mapping/Migrations.fs new file mode 100644 index 0000000..b8c7ae5 --- /dev/null +++ b/Rezoom.SQL.Mapping/Migrations.fs @@ -0,0 +1,193 @@ +module Rezoom.SQL.Mapping.Migrations +open System +open System.Collections.Generic +open FSharp.Quotations + +type MigrationFileName = + { MajorVersion : int + ParentName : string option + Name : string + } + override this.ToString() = + match this.ParentName with + | None -> sprintf "V%d.%s" this.MajorVersion this.Name + | Some parent -> sprintf "V%d.%s-%s" this.MajorVersion parent this.Name + +type Migration<'src> = + { MajorVersion : int + Name : string + Source : 'src + } + member this.FileName = "V" + string this.MajorVersion + "." + this.Name + +let private quotationizeMigration (migration : string Migration) = + <@@ { MajorVersion = %%Expr.Value(migration.MajorVersion) + Name = %%Expr.Value(migration.Name) + Source = %%Expr.Value(migration.Source) + } : string Migration @@> + +type MigrationTree<'src> = + { Node : 'src Migration + Children : 'src MigrationTree IReadOnlyList + } + member this.Map(f) = + { Node = + { MajorVersion = this.Node.MajorVersion + Name = this.Node.Name + Source = f this.Node.Source + } + Children = this.Children |> Seq.map (fun t -> t.Map(f)) |> ResizeArray + } + member this.Migrations() = + seq { + yield this.Node + for child in this.Children do + yield! child.Migrations() + } + +let rec quotationizeMigrationTree (tree : string MigrationTree) = + let children = + Expr.NewArray(typeof, + [ for child in tree.Children -> + quotationizeMigrationTree child + ]) + let children = Expr.Coerce(children, typeof) + <@@ { Node = %%quotationizeMigration tree.Node + Children = %%children + } : string MigrationTree @@> + +let foldMigrations + (folder : bool -> 'acc -> 's1 Migration -> 's2 * 'acc) + (acc : 'acc) + (migrationTrees : 's1 MigrationTree seq) = + let mutable acc = acc + let rec mapFold root tree = + let s2, acc2 = folder root acc tree.Node + acc <- acc2 + { Node = + { MajorVersion = tree.Node.MajorVersion + Name = tree.Node.Name + Source = s2 + } + Children = tree.Children |> Seq.map (mapFold false) |> ResizeArray + } + let trees = [ for tree in migrationTrees -> mapFold true tree ] + trees, acc + +type private MigrationTreeBuilderNode<'src> = + { mutable Source : 'src option + Name : string + Children : 'src MigrationTreeBuilderNode ResizeArray + } + +type private MigrationTreeBuilder<'src>(majorVersionNumber) = + let migrations = Dictionary() + let mutable root = None + member __.ToTree() = + match root with + | None -> + failwithf "No root migration for V%d" majorVersionNumber + | Some (root, rootName) -> + let rec toTree (node : 'src MigrationTreeBuilderNode) = + { Node = + { MajorVersion = majorVersionNumber + Name = node.Name + Source = + match node.Source with + | None -> + failwithf "No source for migration V%d.%s" + majorVersionNumber node.Name + | Some src -> src + } + Children = + node.Children |> Seq.map toTree |> ResizeArray + } + toTree root + member __.Add(name : MigrationFileName, source : 'src) = + let succ, self = migrations.TryGetValue(name.Name) + let self = + if succ then + if Option.isSome self.Source then + failwithf "Multiple sources given for migration %O" name + self.Source <- Some source + self + else + let newNode = + { Source = Some source + Name = name.Name + Children = ResizeArray() + } + migrations.[name.Name] <- newNode + newNode + match name.ParentName with + | None -> + match root with + | Some (node, rootName) -> + failwithf "Multiple root migrations given (%O, %O)" rootName name + | None -> + root <- Some (self, name) + | Some parentName -> + let succ, parent = migrations.TryGetValue(parentName) + if succ then + parent.Children.Add(self) + else + let parent = + { Source = None + Name = name.Name + Children = ResizeArray([|self|]) + } + migrations.[parentName] <- parent + +type MigrationTreeListBuilder<'src>() = + let majorVersions = Dictionary() + member __.Add(name : MigrationFileName, source : 'src) = + let succ, found = majorVersions.TryGetValue(name.MajorVersion) + let found = + if succ then found else + let builder = MigrationTreeBuilder(name.MajorVersion) + majorVersions.[name.MajorVersion] <- builder + builder + found.Add(name, source) + member __.ToTrees() = + majorVersions + |> Seq.sortBy (fun v -> v.Key) + |> Seq.map (fun v -> v.Value.ToTree()) + |> ResizeArray + +type IMigrationBackend = + abstract member Initialize : unit -> unit + abstract member GetMigrationsRun : unit -> (int * string) seq + abstract member RunMigration : string Migration -> unit + +type MigrationConfig = + { AllowMigrationsFromOlderMajorVersions : bool + LogMigrationRan : string Migration -> unit + } + +let runMigrations config (backend : IMigrationBackend) (migrationTrees : string MigrationTree seq) = + backend.Initialize() + let already = HashSet(backend.GetMigrationsRun()) + let currentMajorVersion = + already + |> Seq.map fst + |> Seq.sortByDescending id + |> Seq.tryHead + let currentMajorVersion = + match currentMajorVersion with + | Some version -> version + | None -> Int32.MinValue + for migrationTree in migrationTrees do + for migration in migrationTree.Migrations() do + let pair = migration.MajorVersion, migration.Name + if not <| already.Contains(pair) then + if migration.MajorVersion < currentMajorVersion + && not config.AllowMigrationsFromOlderMajorVersions then + failwith <| + sprintf "Can't run migration V%d.%s because database has a newer major version (V%d)" + migration.MajorVersion migration.Name + currentMajorVersion + else + backend.RunMigration(migration) + config.LogMigrationRan migration + ignore <| already.Add(pair) // actually we don't need this but ok + diff --git a/Rezoom.SQL.Mapping/PrimitiveColumnGenerator.fs b/Rezoom.SQL.Mapping/PrimitiveColumnGenerator.fs new file mode 100644 index 0000000..32545ac --- /dev/null +++ b/Rezoom.SQL.Mapping/PrimitiveColumnGenerator.fs @@ -0,0 +1,66 @@ +namespace Rezoom.SQL.Mapping.CodeGeneration +open Rezoom.SQL.Mapping +open LicenseToCIL +open LicenseToCIL.Stack +open LicenseToCIL.Ops +open System +open System.Collections.Generic +open System.Reflection +open System.Reflection.Emit + +type private PrimitiveColumnGenerator(builder, column, primitive : Primitive) = + inherit EntityReaderColumnGenerator(builder) + let output = column.Blueprint.Value.Output + let mutable colValue = null + let mutable colInfo = null + let mutable found = null + override __.DefineConstructor() = + found <- builder.DefineField("_p_f_" + column.Name, typeof, FieldAttributes.Private) + colInfo <- builder.DefineField("_p_i_" + column.Name, typeof, FieldAttributes.Private) + colValue <- builder.DefineField("_p_" + column.Name, output, FieldAttributes.Private) + zero + override __.DefineProcessColumns() = + cil { + yield dup + yield ldarg 1 // column map + yield ldstr column.Name + yield call2 ColumnMap.ColumnMethod + yield stfld colInfo + } + override __.DefineImpartKnowledgeToNext() = + cil { + yield ldarg 1 + yield castclass builder + yield ldarg 0 + yield ldfld colInfo + yield stfld colInfo + } + override __.DefineRead(skipOnes) = + cil { + let! skip = deflabel + yield dup + yield ldfld found + yield brtrue skipOnes + yield dup + yield ldfld colInfo + yield ldfld ColumnInfo.IndexField + yield ldc'i4 0 + yield blt's skip + yield cil { + yield ldarg 1 // row + yield ldarg 0 // this + yield ldfld colInfo // row, index + yield generalize2 primitive.Converter + yield stfld colValue + yield ldarg 0 + yield dup + yield ldc'i4 1 + yield stfld found + } + yield mark skip + } + override __.DefinePush(_) = + cil { + yield ldarg 0 // this + yield ldfld colValue + } \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/PrimitiveConverters.fs b/Rezoom.SQL.Mapping/PrimitiveConverters.fs new file mode 100644 index 0000000..f230b89 --- /dev/null +++ b/Rezoom.SQL.Mapping/PrimitiveConverters.fs @@ -0,0 +1,307 @@ +module Rezoom.SQL.Mapping.CodeGeneration.PrimitiveConverters +open Rezoom.SQL.Mapping +open LicenseToCIL +open LicenseToCIL.Ops +open System +open System.Collections.Generic +open System.Reflection +open System.Reflection.Emit + +let inline private toNumeric + (row : Row) + (col : ColumnInfo) + fromObj + fromString + fromByte + fromInt16 + fromInt32 + fromInt64 + fromSByte + fromUInt16 + fromUInt32 + fromUInt64 + fromSingle + fromDouble + fromDecimal = + match col.Type with + | ColumnType.Object -> row.GetObject(col.Index) |> fromObj + | ColumnType.String -> row.GetString(col.Index) |> fromString + | ColumnType.Byte -> row.GetByte(col.Index) |> fromByte + | ColumnType.Int16 -> row.GetInt16(col.Index) |> fromInt16 + | ColumnType.Int32 -> row.GetInt32(col.Index) |> fromInt32 + | ColumnType.Int64 -> row.GetInt64(col.Index) |> fromInt64 + | ColumnType.SByte -> row.GetSByte(col.Index) |> fromSByte + | ColumnType.UInt16 -> row.GetUInt16(col.Index) |> fromUInt16 + | ColumnType.UInt32 -> row.GetUInt32(col.Index) |> fromUInt32 + | ColumnType.UInt64 -> row.GetUInt64(col.Index) |> fromUInt64 + | ColumnType.Single -> row.GetSingle(col.Index) |> fromSingle + | ColumnType.Double -> row.GetDouble(col.Index) |> fromDouble + | ColumnType.Decimal -> row.GetDecimal(col.Index) |> fromDecimal + | x -> failwithf "Invalid column type %A for numeric" x + +type Converters = + static member ToObject(row : Row, col : ColumnInfo) = row.GetObject(col.Index) + static member ToString(row : Row, col : ColumnInfo) = + match col.Type with + | ColumnType.String -> row.GetString(col.Index) + | _ -> + match row.GetObject(col.Index) with + | null -> null + | o -> Convert.ToString(o) + static member ToByteArray(row : Row, col : ColumnInfo) = + row.GetObject(col.Index) + |> Unchecked.unbox : byte array + static member ToByte(row : Row, col : ColumnInfo) : byte = + toNumeric row col + Convert.ToByte + byte byte byte byte + byte byte byte byte + byte byte byte byte + static member ToInt16(row : Row, col : ColumnInfo) : int16 = + toNumeric row col + Convert.ToInt16 + int16 int16 int16 int16 + int16 int16 int16 int16 + int16 int16 int16 int16 + static member ToInt32(row : Row, col : ColumnInfo) : int32 = + toNumeric row col + Convert.ToInt32 + int32 int32 int32 int32 + int32 int32 int32 int32 + int32 int32 int32 int32 + static member ToInt64(row : Row, col : ColumnInfo) : int64 = + toNumeric row col + Convert.ToInt64 + int64 int64 int64 int64 + int64 int64 int64 int64 + int64 int64 int64 int64 + static member ToSByte(row : Row, col : ColumnInfo) : sbyte = + toNumeric row col + Convert.ToSByte + sbyte sbyte sbyte sbyte + sbyte sbyte sbyte sbyte + sbyte sbyte sbyte sbyte + static member ToUInt16(row : Row, col : ColumnInfo) : uint16 = + toNumeric row col + Convert.ToUInt16 + uint16 uint16 uint16 uint16 + uint16 uint16 uint16 uint16 + uint16 uint16 uint16 uint16 + static member ToUInt32(row : Row, col : ColumnInfo) : uint32 = + toNumeric row col + Convert.ToUInt32 + uint32 uint32 uint32 uint32 + uint32 uint32 uint32 uint32 + uint32 uint32 uint32 uint32 + static member ToUInt64(row : Row, col : ColumnInfo) : uint64 = + toNumeric row col + Convert.ToUInt64 + uint64 uint64 uint64 uint64 + uint64 uint64 uint64 uint64 + uint64 uint64 uint64 uint64 + static member ToSingle(row : Row, col : ColumnInfo) : single = + toNumeric row col + Convert.ToSingle + single single single single + single single single single + single single single single + static member ToDouble(row : Row, col : ColumnInfo) : double = + toNumeric row col + Convert.ToDouble + double double double double + double double double double + double double double double + static member ToDecimal(row : Row, col : ColumnInfo) : decimal = + toNumeric row col + Convert.ToDecimal + decimal decimal decimal decimal + decimal decimal decimal decimal + decimal decimal decimal decimal + static member ToDateTime(row : Row, col : ColumnInfo) : DateTime = + match col.Type with + | ColumnType.DateTime -> row.GetDateTime(col.Index) + | ColumnType.Object -> Convert.ToDateTime(row.GetObject(col.Index)) + | x -> failwithf "Invalid column type %A for DateTime" x + +let private convertersByType = + let methods = typeof.GetMethods() + methods + |> Seq.filter + (fun m -> + let parTypes = m.GetParameters() |> Array.map (fun p -> p.ParameterType) + parTypes = [|typeof; typeof|]) + |> Seq.map + (fun m -> m.ReturnType, m) + |> dict + +let private columnIndexField = typeof.GetField("Index") +let private columnTypeField = typeof.GetField("Type") +let private rowIsNullMethod = typeof.GetMethod("IsNull") +let private rowGetStringMethod = typeof.GetMethod("GetString") +let private stringTrimMethod = typeof.GetMethod("Trim", Type.EmptyTypes) + +let private storeInstructions= + [ + typeof, stind'i1 + typeof, stind'i1 + typeof, stind'i2 + typeof, stind'i2 + typeof, stind'i4 + typeof, stind'i4 + typeof, stind'i8 + typeof, stind'i8 + ] |> dict + +let private enumTryParser (delTy) (enumTy : Type) = + let underlying = enumTy.GetEnumUnderlyingType() + let loadValue = + if obj.ReferenceEquals(underlying, typeof) then fun o -> ldc'i8 (Unchecked.unbox o) + elif obj.ReferenceEquals(underlying, typeof) then fun o -> ldc'i8 (int64 (Unchecked.unbox o : uint64)) + elif obj.ReferenceEquals(underlying, typeof) then fun o -> ldc'i4 (int (Unchecked.unbox o : uint32)) + else fun (o : obj) -> ldc'i4 (Convert.ToInt32(o)) + let storeValue = storeInstructions.[underlying] + let names = Enum.GetNames(enumTy) + let values = Enum.GetValues(enumTy) + let pairs = + seq { + for i = 0 to names.Length - 1 do + yield names.[i], values.GetValue(i) + } + let dynamicMethod = + DynamicMethod + ( "TryParse" + enumTy.Name + , typeof + , [| typeof; enumTy.MakeByRefType() |] + , typeof + ) + (cil { + yield ldarg 0 + yield call1 stringTrimMethod + yield StringSwitch.insensitive + [| for name, value in pairs -> + name, + cil { + yield ldarg 1 + yield loadValue value + yield storeValue + yield ldc'i4 1 + yield ret + } + |] zero + yield ldc'i4 0 + yield ret + }) null (IL(dynamicMethod.GetILGenerator())) |> ignore + dynamicMethod.CreateDelegate(delTy) + +type EnumTryParserDelegate<'enum> = delegate of string * 'enum byref -> bool + +type EnumTryParser<'enum>() = + static let parser = + enumTryParser typeof> typeof<'enum> + |> Unchecked.unbox : EnumTryParserDelegate<'enum> + static member TryParse(str : string, enum : 'enum byref) = + parser.Invoke(str, &enum) + +let rec converter (ty : Type) : RowConversionMethod option = + let succ, meth = convertersByType.TryGetValue(ty) + if succ then + Some (Ops.call2 meth) + elif ty.IsEnum then + match converter (ty.GetEnumUnderlyingType()) with + | None -> None + | Some converter -> + cil { + let! colInfo = tmplocal typeof + let! parsed = tmplocal ty + let! skipParse = deflabel + let! exit = deflabel + yield dup // row, col, col + yield stloc colInfo // row, col + yield ldfld columnTypeField // row, type + yield ldc'i4 (int ColumnType.String) // row, type, string + yield bne'un's skipParse // row + yield dup // row, row + yield ldloc colInfo // row, row, col + yield ldfld columnIndexField // row, row, i + yield callvirt2 rowGetStringMethod // row, string + yield ldloca parsed // row, string, &parsed + yield call2 <| typedefof<_ EnumTryParser>.MakeGenericType(ty).GetMethod("TryParse") // row, succ + yield brfalse's skipParse // row + yield pop + yield ldloc parsed + yield br's exit + yield mark skipParse + yield ldloc colInfo + yield converter + yield mark exit + } |> Some + else genericConverter ty + +and genericConverter (ty : Type) : RowConversionMethod option = + if ty.IsConstructedGenericType then + let def = ty.GetGenericTypeDefinition() + if def = typedefof<_ Nullable> then + match ty.GetGenericArguments() with + | [| nTy |] -> + match converter nTy with + | None -> None + | Some innerConverter -> + cil { + let! colInfo = tmplocal typeof + let! ncase = deflabel + let! exit = deflabel + yield stloc colInfo // row + yield dup // row, row + yield ldloc colInfo // row, row, col + yield ldfld columnIndexField // row, row, index + yield Ops.callvirt2 rowIsNullMethod // row, isnull + yield brtrue's ncase + yield cil { + yield ldloc colInfo + yield innerConverter + yield newobj1 (ty.GetConstructor([| nTy |])) + yield br's exit + } + yield mark ncase + yield cil { + yield pop + let! empty = tmplocal ty + yield ldloca empty + yield initobj ty + yield ldloc empty + } + yield mark exit + } |> Some + | _ -> failwith "Cannot function in world where Nullable doesn't have one type argument." + elif def = typedefof<_ option> then + match ty.GetGenericArguments() with + | [| nTy |] -> + match converter nTy with + | None -> None + | Some innerConverter -> + cil { + let! colInfo = tmplocal typeof + let! ncase = deflabel + let! exit = deflabel + yield stloc colInfo // row + yield dup // row, row + yield ldloc colInfo // row, row, col + yield ldfld columnIndexField // row, row, index + yield Ops.callvirt2 rowIsNullMethod // row, isnull + yield brtrue's ncase + yield cil { + yield ldloc colInfo + yield innerConverter + yield newobj1 (ty.GetConstructor([| nTy |])) + yield br's exit + } + yield mark ncase + yield cil { + yield pop + yield ldnull // None + } + yield mark exit + } |> Some + | _ -> failwith "Cannot function in world where FSharpOption doesn't have one type argument." + else None + else None \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/Rezoom.SQL.Mapping.fsproj b/Rezoom.SQL.Mapping/Rezoom.SQL.Mapping.fsproj new file mode 100644 index 0000000..8bebb43 --- /dev/null +++ b/Rezoom.SQL.Mapping/Rezoom.SQL.Mapping.fsproj @@ -0,0 +1,104 @@ + + + + + Debug + AnyCPU + 2.0 + 6b6a06c5-157a-4fe3-8b4c-2a1ae6a15333 + Library + Rezoom.SQL.Mapping + Rezoom.SQL.Mapping + v4.6 + 4.4.0.0 + true + Rezoom.SQL.Mapping + + + true + full + false + false + bin\Debug\ + DEBUG;TRACE + 3 + bin\Debug\Rezoom.SQL.Mapping.XML + + + pdbonly + true + true + bin\Release\ + TRACE + 3 + bin\Release\Rezoom.SQL.Mapping.XML + + + 11 + + + + + $(MSBuildExtensionsPath32)\..\Microsoft SDKs\F#\3.0\Framework\v4.0\Microsoft.FSharp.Targets + + + + + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\FSharp\Microsoft.FSharp.Targets + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ..\packages\LicenseToCIL.0.2.2\lib\net46\LicenseToCIL.dll + True + + + + True + + + + + + + + Rezoom + {d98acbeb-a039-4340-a7c5-6ed2b677268b} + True + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/Row.fs b/Rezoom.SQL.Mapping/Row.fs new file mode 100644 index 0000000..0e67dcd --- /dev/null +++ b/Rezoom.SQL.Mapping/Row.fs @@ -0,0 +1,40 @@ +namespace Rezoom.SQL.Mapping +open System +open System.Data + +[] +type Row() = + abstract member IsNull : int16 -> bool + abstract member GetObject : int16 -> obj + abstract member GetString : int16 -> string + abstract member GetByte : int16 -> byte + abstract member GetInt16 : int16 -> int16 + abstract member GetInt32 : int16 -> int32 + abstract member GetInt64 : int16 -> int64 + abstract member GetSByte : int16 -> sbyte + abstract member GetUInt16 : int16 -> uint16 + abstract member GetUInt32 : int16 -> uint32 + abstract member GetUInt64 : int16 -> uint64 + abstract member GetSingle : int16 -> single + abstract member GetDouble : int16 -> double + abstract member GetDecimal : int16 -> decimal + abstract member GetDateTime : int16 -> DateTime + +type ObjectRow([] row : obj array) = + inherit Row() + override __.IsNull(i) = isNull (row.[int i]) + override __.GetObject(i) = row.[int i] + override __.GetString(i) = row.[int i] |> Unchecked.unbox + override __.GetByte(i) = row.[int i] |> Unchecked.unbox + override __.GetInt16(i) = row.[int i] |> Unchecked.unbox + override __.GetInt32(i) = row.[int i] |> Unchecked.unbox + override __.GetInt64(i) = row.[int i] |> Unchecked.unbox + override __.GetSByte(i) = row.[int i] |> Unchecked.unbox + override __.GetUInt16(i) = row.[int i] |> Unchecked.unbox + override __.GetUInt32(i) = row.[int i] |> Unchecked.unbox + override __.GetUInt64(i) = row.[int i] |> Unchecked.unbox + override __.GetSingle(i) = row.[int i] |> Unchecked.unbox + override __.GetDouble(i) = row.[int i] |> Unchecked.unbox + override __.GetDecimal(i) = row.[int i] |> Unchecked.unbox + override __.GetDateTime(i) = row.[int i] |> Unchecked.unbox + diff --git a/Rezoom.SQL.Mapping/StaticEntityReaderTemplate.fs b/Rezoom.SQL.Mapping/StaticEntityReaderTemplate.fs new file mode 100644 index 0000000..8ba14c2 --- /dev/null +++ b/Rezoom.SQL.Mapping/StaticEntityReaderTemplate.fs @@ -0,0 +1,301 @@ +namespace Rezoom.SQL.Mapping.CodeGeneration +open Rezoom.SQL.Mapping +open LicenseToCIL +open LicenseToCIL.Stack +open LicenseToCIL.Ops +open System +open System.Collections.Generic +open System.Reflection +open System.Reflection.Emit + +type private EntityReaderBuilder = + { + Ctor : E S * IL + ProcessColumns : E S * IL + ImpartKnowledge : E S * IL + Read : E S * IL + SetReverse : E S * IL + ToEntity : E S * IL + } + +type private StaticEntityReaderTemplate = + static member ColumnGenerator(builder, column) = + match column.Blueprint.Value.Cardinality with + | One { Shape = Primitive p } -> + PrimitiveColumnGenerator(builder, column, p) :> EntityReaderColumnGenerator + | One { Shape = Composite c } -> + CompositeColumnGenerator(builder, column, c) :> EntityReaderColumnGenerator + | Many (element, conversion) -> + match element.Shape with + | Composite c when c.Identity.Count > 0 -> + ManyEntityColumnGenerator(builder, Some column, element, conversion) :> EntityReaderColumnGenerator + | _ -> + ManyColumnGenerator(builder, Some column, element, conversion) :> EntityReaderColumnGenerator + + static member ImplementPrimitive(builder : TypeBuilder, ty : Type, primitive : Primitive, readerBuilder) = + let info = builder.DefineField("_i", typeof, FieldAttributes.Private) + let value = builder.DefineField("_v", ty, FieldAttributes.Private) + readerBuilder.Ctor ||> ret'void |> ignore + readerBuilder.ProcessColumns ||> + cil { + yield ldarg 0 + yield ldarg 1 + yield call1 ColumnMap.PrimaryColumnMethod + yield stfld info + yield ret'void + } |> ignore + readerBuilder.ImpartKnowledge ||> + cil { + yield ldarg 1 + yield castclass builder + yield ldarg 0 + yield ldfld info + yield stfld info + yield ret'void + } |> ignore + readerBuilder.Read ||> + cil { + yield ldarg 0 + yield ldarg 1 + yield ldarg 0 + yield ldfld info + yield generalize2 primitive.Converter + yield stfld value + yield ret'void + } |> ignore + readerBuilder.SetReverse ||> ret'void |> ignore + readerBuilder.ToEntity ||> + cil { + yield ldarg 0 + yield ldfld value + yield ret + } |> ignore + + static member ImplementMany(builder : TypeBuilder, element : ElementBlueprint, conversion, readerBuilder) = + let generator = + match element.Shape with + | Composite c when c.Identity.Count > 0 -> + ManyEntityColumnGenerator(builder, None, element, conversion) :> EntityReaderColumnGenerator + | _ -> + ManyColumnGenerator(builder, None, element, conversion) :> EntityReaderColumnGenerator + readerBuilder.Ctor ||> + cil { + yield ldarg 0 + yield generator.DefineConstructor() + yield pop + yield ret'void + } |> ignore + readerBuilder.ProcessColumns ||> + cil { + yield ldarg 0 + yield generator.DefineProcessColumns() + yield pop + yield ret'void + } |> ignore + readerBuilder.ImpartKnowledge ||> + cil { + yield ldarg 0 + yield generator.DefineImpartKnowledgeToNext() + yield pop + yield ret'void + } |> ignore + readerBuilder.Read ||> + cil { + let! lbl = deflabel + yield ldarg 0 + yield generator.DefineRead(lbl) + yield mark lbl + yield pop + yield ret'void + } |> ignore + readerBuilder.SetReverse ||> + cil { + yield ldarg 0 + yield generator.DefineSetReverse() + yield pop + yield ret'void + } |> ignore + readerBuilder.ToEntity ||> + cil { + let! self = deflocal builder + yield generator.DefinePush(self) + yield ret + } |> ignore + + static member ImplementComposite(builder, composite : Composite, readerBuilder) = + let columns = + [| for column in composite.Columns.Values -> + column, StaticEntityReaderTemplate.ColumnGenerator(builder, column) + |] + readerBuilder.Ctor ||> + cil { + yield ldarg 0 + for _, column in columns do + yield column.DefineConstructor() + yield pop + yield ret'void + } |> ignore + readerBuilder.ProcessColumns ||> + cil { + yield ldarg 0 + for _, column in columns do + yield column.DefineProcessColumns() + yield pop + yield ret'void + } |> ignore + readerBuilder.ImpartKnowledge ||> + cil { + yield ldarg 0 + for _, column in columns do + yield column.DefineImpartKnowledgeToNext() + yield pop + yield ret'void + } |> ignore + readerBuilder.Read ||> + cil { + let! skipOnes = deflabel + let! skipAll = deflabel + yield ldarg 0 + let ones, others = columns |> Array.partition (fun (b, _) -> b.Blueprint.Value.IsOne()) + for _, column in ones do + yield column.DefineRead(skipOnes) + yield mark skipOnes + for _, column in others do + yield column.DefineRead(skipAll) + yield mark skipAll + yield pop + yield ret'void + } |> ignore + readerBuilder.SetReverse ||> + cil { + yield ldarg 0 + for _, column in columns do + yield column.DefineSetReverse() + yield pop + yield ret'void + } |> ignore + let constructorColumns = + seq { + for blue, column in columns do + match blue.Setter with + | SetConstructorParameter paramInfo -> + yield paramInfo.Position, column + | _ -> () + } |> Seq.sortBy fst |> Seq.map snd |> Seq.toArray + readerBuilder.ToEntity ||> + cil { + let! self = deflocal builder + if constructorColumns |> Array.exists (fun c -> c.RequiresSelfReferenceToPush) then + let uninit = + typeof.GetMethod("GetUninitializedObject") + yield ldtoken composite.Output + yield call1 (typeof.GetMethod("GetTypeFromHandle")) + yield call1 uninit + yield castclass composite.Output + yield dup + yield stloc self + yield dup + for column in constructorColumns do + yield column.DefinePush(self) + yield pretend + yield (fun st il -> + il.Generator.Emit(OpCodes.Call, composite.Constructor) + null) + else + for column in constructorColumns do + yield column.DefinePush(self) + yield pretend + yield newobj'x composite.Constructor + if composite.ReferencesQueryParent then + yield dup + yield stloc self + for blue, column in columns do + match blue.Setter with + | SetField field -> + yield dup + yield column.DefinePush(self) + yield stfld field + | SetProperty prop -> + yield dup + yield column.DefinePush(self) + let meth = prop.GetSetMethod() + yield (if meth.IsVirtual then callvirt2'void else call2'void) meth + | _ -> () + yield ret + } |> ignore + static member ImplementReader(blueprint : Blueprint, builder : TypeBuilder) = + let readerTy = typedefof<_ EntityReader>.MakeGenericType(blueprint.Output) + let methodAttrs = MethodAttributes.Public ||| MethodAttributes.Virtual + let readerBuilder = + { + Ctor = + Stack.empty, IL(builder + .DefineConstructor(MethodAttributes.Public, CallingConventions.HasThis, Type.EmptyTypes) + .GetILGenerator()) + ImpartKnowledge = + Stack.empty, IL(builder + .DefineMethod("ImpartKnowledgeToNext", methodAttrs, typeof, [| readerTy |]) + .GetILGenerator()) + ProcessColumns = + Stack.empty, IL(builder + .DefineMethod("ProcessColumns", methodAttrs, typeof, [| typeof |]) + .GetILGenerator()) + Read = Stack.empty, IL(builder + .DefineMethod("Read", methodAttrs, typeof, [| typeof |]).GetILGenerator()) + SetReverse = + Stack.empty, IL(builder + .DefineMethod("SetReverse", methodAttrs, typeof, [| typeof; typeof |]) + .GetILGenerator()) + ToEntity = Stack.empty, IL(builder + .DefineMethod("ToEntity", methodAttrs, blueprint.Output, Type.EmptyTypes).GetILGenerator()) + } + match blueprint.Cardinality with + | One { Shape = Primitive primitive } -> + StaticEntityReaderTemplate.ImplementPrimitive(builder, blueprint.Output, primitive, readerBuilder) + | One { Shape = Composite composite } -> + StaticEntityReaderTemplate.ImplementComposite(builder, composite, readerBuilder) + | Many (element, conversion) -> + StaticEntityReaderTemplate.ImplementMany(builder, element, conversion, readerBuilder) + builder.CreateType() + +type ReaderTemplate<'ent>() = + static let entType = typeof<'ent> + static let template = + let moduleBuilder = + let assembly = AssemblyName("Readers." + entType.Name + "." + Guid.NewGuid().ToString("N")) + let appDomain = Threading.Thread.GetDomain() + let assemblyBuilder = appDomain.DefineDynamicAssembly(assembly, AssemblyBuilderAccess.Run) + assemblyBuilder.DefineDynamicModule(assembly.Name) + let readerBaseType = typedefof<_ EntityReader>.MakeGenericType(entType) + let readerType = + let builder = + moduleBuilder.DefineType + ( entType.Name + "Reader" + , TypeAttributes.Public ||| TypeAttributes.AutoClass ||| TypeAttributes.AnsiClass + , readerBaseType + ) + StaticEntityReaderTemplate.ImplementReader(Blueprint.ofType entType, builder) + let templateType = + let builder = + moduleBuilder.DefineType + ( entType.Name + "Template" + , TypeAttributes.Public ||| TypeAttributes.AutoClass ||| TypeAttributes.AnsiClass + , typedefof<_ EntityReaderTemplate>.MakeGenericType(entType) + ) + ignore <| builder.DefineDefaultConstructor(MethodAttributes.Public) + let meth = + builder.DefineMethod + ( "CreateReader" + , MethodAttributes.Public ||| MethodAttributes.Virtual + , readerBaseType + , Type.EmptyTypes + ) + (Stack.empty, IL(meth.GetILGenerator())) ||> + cil { + yield newobj0 (readerType.GetConstructor(Type.EmptyTypes)) + yield ret + } |> ignore + builder.CreateType() + Activator.CreateInstance(templateType) + |> Unchecked.unbox : 'ent EntityReaderTemplate + static member Template() = template \ No newline at end of file diff --git a/Rezoom.SQL.Mapping/packages.config b/Rezoom.SQL.Mapping/packages.config new file mode 100644 index 0000000..1beb59d --- /dev/null +++ b/Rezoom.SQL.Mapping/packages.config @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test.sln b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test.sln new file mode 100644 index 0000000..9cd2641 --- /dev/null +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test.sln @@ -0,0 +1,22 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 14 +VisualStudioVersion = 14.0.25420.1 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.SQL.Provider.Test", "Rezoom.SQL.Provider.Test\Rezoom.SQL.Provider.Test.fsproj", "{0C6AF2D8-42BC-46F2-A9D6-BBEED08E6965}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {0C6AF2D8-42BC-46F2-A9D6-BBEED08E6965}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0C6AF2D8-42BC-46F2-A9D6-BBEED08E6965}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0C6AF2D8-42BC-46F2-A9D6-BBEED08E6965}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0C6AF2D8-42BC-46F2-A9D6-BBEED08E6965}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/App.config b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/App.config new file mode 100644 index 0000000..453ff05 --- /dev/null +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/App.config @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/AssemblyInfo.fs b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/AssemblyInfo.fs new file mode 100644 index 0000000..a6fbfd0 --- /dev/null +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/AssemblyInfo.fs @@ -0,0 +1,41 @@ +namespace Rezoom.SQL.Provider.Test.AssemblyInfo + +open System.Reflection +open System.Runtime.CompilerServices +open System.Runtime.InteropServices + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[] +[] +[] +[] +[] +[] +[] +[] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [] +[] +[] + +do + () \ No newline at end of file diff --git a/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/Program.fs b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/Program.fs new file mode 100644 index 0000000..0f34e54 --- /dev/null +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/Program.fs @@ -0,0 +1,70 @@ +open System +open System.IO +open System.Data.SQLite +open Rezoom +open Rezoom.Execution +open Rezoom.SQL.Provider +open Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.Migrations + +type DataModel = SQLModel + +type GetChildTodos = SQL<""" + select * from ActiveTodos where ParentId is @parentId +"""> + +type MakeChildTodo = SQL<""" + insert into Todos(ParentId, Heading, Paragraph, DeactivatedUtc) + values (null, @heading, @paragraph, null); + select last_insert_rowid() as id; +"""> + +let cmdPlan (cmd : string) = + plan { + match cmd with + | "ls" -> + let! children = GetChildTodos.Command(None).ExecutePlan() + printfn "%d children" children.Count + for child in children do + printfn "%s: %s" child.Heading (defaultArg child.Paragraph "None") + | "mk" -> + let! result = MakeChildTodo.Command("test heading", Some "test para").ExecutePlan() + printfn "Created TODO %d" result.[0].id + | "throw" -> + failwith "unhandled exn" + | _ -> + printfn "Unrecognized command ``%s``" cmd + } + +let buildPlan (cmd : string) = + plan { + let cmds = cmd.Split(';') + for cmd in batch cmds do + do! cmdPlan (cmd.Trim()) + } + +let migrate() = + let dbname = "test.db" + if not <| File.Exists(dbname) then + SQLiteConnection.CreateFile(dbname) + use conn = new SQLiteConnection("data source=" + dbname) + conn.Open() + let config = + { AllowMigrationsFromOlderMajorVersions = false + LogMigrationRan = fun _ -> () + } + DataModel.Migrate(config, conn) + +[] +let main argv = + migrate() + let execute plan = Execution.execute ExecutionConfig.Default plan + while true do + let cmd = Console.ReadLine() + try + let plan = buildPlan cmd + (execute plan).Wait() + with + | exn -> + printfn "Plan failed with exn: %O" exn + 0 // return an integer exit code diff --git a/Rezoom.ADO.Test/Rezoom.ADO.Test.fsproj b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test.fsproj similarity index 62% rename from Rezoom.ADO.Test/Rezoom.ADO.Test.fsproj rename to Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test.fsproj index 21eadcd..ce51125 100644 --- a/Rezoom.ADO.Test/Rezoom.ADO.Test.fsproj +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test.fsproj @@ -1,18 +1,18 @@ - + Debug AnyCPU 2.0 - 39b23d3a-43d2-442f-9b83-b68ece642fad - Library - Rezoom.ADO.Test - Rezoom.ADO.Test + 0c6af2d8-42bc-46f2-a9d6-bbeed08e6965 + Exe + Rezoom.SQL.Provider.Test + Rezoom.SQL.Provider.Test v4.6 - 4.4.0.0 true - Rezoom.ADO.Test + 4.4.0.0 + Rezoom.SQL.Provider.Test @@ -24,7 +24,9 @@ bin\Debug\ DEBUG;TRACE 3 - bin\Debug\Rezoom.ADO.Test.XML + AnyCPU + bin\Debug\Rezoom.SQL.Provider.Test.XML + true pdbonly @@ -33,7 +35,9 @@ bin\Release\ TRACE 3 - bin\Release\Rezoom.ADO.Test.XML + AnyCPU + bin\Release\Rezoom.SQL.Provider.Test.XML + true 11 @@ -51,50 +55,66 @@ - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - + + + + - + + ..\packages\EntityFramework.6.0.0\lib\net45\EntityFramework.dll + True + + + ..\packages\EntityFramework.6.0.0\lib\net45\EntityFramework.SqlServer.dll + True + + + ..\..\Rezoom.SQL.Provider\bin\Debug\FParsec.dll + + + ..\..\Rezoom.SQL.Provider\bin\Debug\FParsec-Pipes.dll + + + ..\..\Rezoom.SQL.Provider\bin\Debug\LicenseToCIL.dll + True + + ..\..\Rezoom.SQL.Provider\bin\Debug\Rezoom.dll + + + ..\..\Rezoom.SQL.Provider\bin\Debug\Rezoom.SQL.Compiler.dll + + + ..\..\Rezoom.SQL.Provider\bin\Debug\Rezoom.SQL.Mapping.dll + + + ..\..\Rezoom.SQL.Provider\bin\Debug\Rezoom.SQL.Provider.dll + + - ..\packages\System.Data.SQLite.Core.1.0.101.0\lib\net46\System.Data.SQLite.dll + ..\packages\System.Data.SQLite.Core.1.0.104.0\lib\net46\System.Data.SQLite.dll True - - Rezoom.ADO - {13bb08a8-8135-4630-beab-1f35d660b52b} - True - - - Rezoom - {d98acbeb-a039-4340-a7c5-6ed2b677268b} - True - - - Rezoom.Execution - {9db721d3-da97-4be3-b60b-9b7a682e803e} - True - + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + - + \ No newline at end of file diff --git a/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/V1.initial.sql b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/V1.initial.sql new file mode 100644 index 0000000..aaffa38 --- /dev/null +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/V1.initial.sql @@ -0,0 +1,10 @@ +create table ToDos + ( Id int primary key autoincrement + , ParentId int null references ToDos(Id) + , Heading string(256) + , Paragraph string(512) null + , DeactivatedUtc string(64) null + ); + +create view ActiveToDos as + select * from ToDos where DeactivatedUtc is null; \ No newline at end of file diff --git a/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/packages.config b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/packages.config new file mode 100644 index 0000000..2f0b4a2 --- /dev/null +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/rzsql.json b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/rzsql.json new file mode 100644 index 0000000..5cd1080 --- /dev/null +++ b/Rezoom.SQL.Provider.Test/Rezoom.SQL.Provider.Test/rzsql.json @@ -0,0 +1,4 @@ +{ + "backend": "sqlite", + "optionals": "f#" +} diff --git a/Rezoom.SQL.Provider/AssemblyInfo.fs b/Rezoom.SQL.Provider/AssemblyInfo.fs new file mode 100644 index 0000000..45443bb --- /dev/null +++ b/Rezoom.SQL.Provider/AssemblyInfo.fs @@ -0,0 +1,41 @@ +namespace Rezoom.SQL.Provider.AssemblyInfo + +open System.Reflection +open System.Runtime.CompilerServices +open System.Runtime.InteropServices + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[] +[] +[] +[] +[] +[] +[] +[] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [] +[] +[] + +do + () \ No newline at end of file diff --git a/Rezoom.SQL.Provider/CompileTimeColumnMap.fs b/Rezoom.SQL.Provider/CompileTimeColumnMap.fs new file mode 100644 index 0000000..5ed147b --- /dev/null +++ b/Rezoom.SQL.Provider/CompileTimeColumnMap.fs @@ -0,0 +1,37 @@ +namespace Rezoom.SQL.Provider +open System +open System.Collections.Generic +open Rezoom.SQL.Compiler + +/// Same mapping as Rezoom.SQL.Mapping.ColumnMap, but carries more metadata about the columns +/// known from Rezoom.SQL. +type private CompileTimeColumnMap() = + let columns = Dictionary(StringComparer.OrdinalIgnoreCase) + let subMaps = Dictionary(StringComparer.OrdinalIgnoreCase) + member private this.GetOrCreateSubMap(name) = + let succ, sub = subMaps.TryGetValue(name) + if succ then sub else + let sub = CompileTimeColumnMap() + subMaps.[name] <- sub + sub + member private this.SetColumn(name, info) = + columns.[name] <- info + // TODO: use inline functions to have a single implementation for this load logic. + // It's gross duplicating it between ColumnMap and CompileTimeColumnMap. + member private this.Load(columns : ColumnType ColumnExprInfo IReadOnlyList) = + for i = 0 to columns.Count - 1 do + let mutable current = this + let column = columns.[i] + let path = column.ColumnName.Value.Split('.', '$') + if path.Length > 1 then + current <- this + for j = 0 to path.Length - 2 do + current <- current.GetOrCreateSubMap(path.[j]) + current.SetColumn(Array.last path, (int16 i, column)) + member this.HasSubMaps = subMaps.Count > 0 + member this.SubMaps = subMaps :> _ seq + member this.Columns = columns :> _ seq + static member Parse(columns) = + let map = CompileTimeColumnMap() + map.Load(columns) + map diff --git a/Rezoom.SQL.Provider/ProvidedTypes-LICENSE.md b/Rezoom.SQL.Provider/ProvidedTypes-LICENSE.md new file mode 100644 index 0000000..b7620f2 --- /dev/null +++ b/Rezoom.SQL.Provider/ProvidedTypes-LICENSE.md @@ -0,0 +1,202 @@ +Copyright 2011-2012, Tomas Petricek (http://tomasp.net) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +------------------------------------------------------------ + +Apache License, Version 2.0 +=========================== + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +### TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + +**1. Definitions.** + + - "License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. + + - "Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + + - "Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + + - "You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + + - "Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. + + - "Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. + + - "Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + + - "Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + + - "Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." + + - "Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +**2. Grant of Copyright License.** +Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +**3. Grant of Patent License.** +Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +**4. Redistribution.** +You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + + - You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + - You must cause any modified files to carry prominent notices + stating that You changed the files; and + + - You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + - If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +**5. Submission of Contributions.** +Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +**6. Trademarks.** +This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +**7. Disclaimer of Warranty.** +Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +**8. Limitation of Liability.** +In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +**9. Accepting Warranty or Additional Liability.** +While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. diff --git a/Rezoom.SQL.Provider/ProvidedTypes.fs b/Rezoom.SQL.Provider/ProvidedTypes.fs new file mode 100644 index 0000000..96525e3 --- /dev/null +++ b/Rezoom.SQL.Provider/ProvidedTypes.fs @@ -0,0 +1,3032 @@ +// Copyright (c) Microsoft Corporation 2005-2012. +// This sample code is provided "as is" without warranty of any kind. +// We disclaim all warranties, either express or implied, including the +// warranties of merchantability and fitness for a particular purpose. + +// This file contains a set of helper types and methods for providing types in an implementation +// of ITypeProvider. + +// This code has been modified and is appropriate for use in conjunction with the F# 3.0-4.0 releases + +namespace ProviderImplementation.ProvidedTypes + +open System +open System.Text +open System.IO +open System.Reflection +open System.Linq.Expressions +open System.Collections.Generic +open Microsoft.FSharp.Quotations +open Microsoft.FSharp.Quotations.Patterns +open Microsoft.FSharp.Quotations.DerivedPatterns +open Microsoft.FSharp.Core.CompilerServices + +#if NO_GENERATIVE +#else +open System.Reflection.Emit +#endif + +//-------------------------------------------------------------------------------- +// UncheckedQuotations + +// The FSharp.Core 2.0 - 4.0 (4.0.0.0 - 4.4.0.0) quotations implementation is overly strict in that it doesn't allow +// generation of quotations for cross-targeted FSharp.Core. Below we define a series of Unchecked methods +// implemented via reflection hacks to allow creation of various nodes when using a cross-targets FSharp.Core and +// mscorlib.dll. +// +// - Most importantly, these cross-targeted quotations can be provided to the F# compiler by a type provider. +// They are generally produced via the AssemblyReplacer.fs component through a process of rewriting design-time quotations that +// are not cross-targeted. +// +// - However, these quotation values are a bit fragile. Using existing FSharp.Core.Quotations.Patterns +// active patterns on these quotation nodes will generally work correctly. But using ExprShape.RebuildShapeCombination +// on these new nodes will not succed, nor will operations that build new quotations such as Expr.Call. +// Instead, use the replacement provided in this module. +// +// - Likewise, some operations in these quotation values like "expr.Type" may be a bit fragile, possibly returning non cross-targeted types in +// the result. However those operations are not used by the F# compiler. +[] +module internal UncheckedQuotations = + + let qTy = typeof.Assembly.GetType("Microsoft.FSharp.Quotations.ExprConstInfo") + assert (qTy <> null) + let pTy = typeof.Assembly.GetType("Microsoft.FSharp.Quotations.PatternsModule") + assert (pTy<> null) + + // These are handles to the internal functions that create quotation nodes of different sizes. Although internal, + // these function names have been stable since F# 2.0. + let mkFE0 = pTy.GetMethod("mkFE0", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (mkFE0 <> null) + let mkFE1 = pTy.GetMethod("mkFE1", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (mkFE1 <> null) + let mkFE2 = pTy.GetMethod("mkFE2", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (mkFE2 <> null) + let mkFEN = pTy.GetMethod("mkFEN", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (mkFEN <> null) + + // These are handles to the internal tags attached to quotation nodes of different sizes. Although internal, + // these function names have been stable since F# 2.0. + let newDelegateOp = qTy.GetMethod("NewNewDelegateOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (newDelegateOp <> null) + let instanceCallOp = qTy.GetMethod("NewInstanceMethodCallOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (instanceCallOp <> null) + let staticCallOp = qTy.GetMethod("NewStaticMethodCallOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (staticCallOp <> null) + let newObjectOp = qTy.GetMethod("NewNewObjectOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (newObjectOp <> null) + let newArrayOp = qTy.GetMethod("NewNewArrayOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (newArrayOp <> null) + let appOp = qTy.GetMethod("get_AppOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (appOp <> null) + let instancePropGetOp = qTy.GetMethod("NewInstancePropGetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (instancePropGetOp <> null) + let staticPropGetOp = qTy.GetMethod("NewStaticPropGetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (staticPropGetOp <> null) + let instancePropSetOp = qTy.GetMethod("NewInstancePropSetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (instancePropSetOp <> null) + let staticPropSetOp = qTy.GetMethod("NewStaticPropSetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (staticPropSetOp <> null) + let instanceFieldGetOp = qTy.GetMethod("NewInstanceFieldGetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (instanceFieldGetOp <> null) + let staticFieldGetOp = qTy.GetMethod("NewStaticFieldGetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (staticFieldGetOp <> null) + let instanceFieldSetOp = qTy.GetMethod("NewInstanceFieldSetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (instanceFieldSetOp <> null) + let staticFieldSetOp = qTy.GetMethod("NewStaticFieldSetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (staticFieldSetOp <> null) + let tupleGetOp = qTy.GetMethod("NewTupleGetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (tupleGetOp <> null) + let letOp = qTy.GetMethod("get_LetOp", BindingFlags.Static ||| BindingFlags.Public ||| BindingFlags.NonPublic) + assert (letOp <> null) + + type Microsoft.FSharp.Quotations.Expr with + + static member NewDelegateUnchecked (ty: Type, vs: Var list, body: Expr) = + let e = List.foldBack (fun v acc -> Expr.Lambda(v,acc)) vs body + let op = newDelegateOp.Invoke(null, [| box ty |]) + mkFE1.Invoke(null, [| box op; box e |]) :?> Expr + + static member NewObjectUnchecked (cinfo: ConstructorInfo, args : Expr list) = + let op = newObjectOp.Invoke(null, [| box cinfo |]) + mkFEN.Invoke(null, [| box op; box args |]) :?> Expr + + static member NewArrayUnchecked (elementType: Type, elements : Expr list) = + let op = newArrayOp.Invoke(null, [| box elementType |]) + mkFEN.Invoke(null, [| box op; box elements |]) :?> Expr + + static member CallUnchecked (minfo: MethodInfo, args : Expr list) = + let op = staticCallOp.Invoke(null, [| box minfo |]) + mkFEN.Invoke(null, [| box op; box args |]) :?> Expr + + static member CallUnchecked (obj: Expr, minfo: MethodInfo, args : Expr list) = + let op = instanceCallOp.Invoke(null, [| box minfo |]) + mkFEN.Invoke(null, [| box op; box (obj::args) |]) :?> Expr + + static member ApplicationUnchecked (f: Expr, x: Expr) = + let op = appOp.Invoke(null, [| |]) + mkFE2.Invoke(null, [| box op; box f; box x |]) :?> Expr + + static member PropertyGetUnchecked (pinfo: PropertyInfo, args : Expr list) = + let op = staticPropGetOp.Invoke(null, [| box pinfo |]) + mkFEN.Invoke(null, [| box op; box args |]) :?> Expr + + static member PropertyGetUnchecked (obj: Expr, pinfo: PropertyInfo, ?args : Expr list) = + let args = defaultArg args [] + let op = instancePropGetOp.Invoke(null, [| box pinfo |]) + mkFEN.Invoke(null, [| box op; box (obj::args) |]) :?> Expr + + static member PropertySetUnchecked (pinfo: PropertyInfo, value: Expr, ?args : Expr list) = + let args = defaultArg args [] + let op = staticPropSetOp.Invoke(null, [| box pinfo |]) + mkFEN.Invoke(null, [| box op; box (args@[value]) |]) :?> Expr + + static member PropertySetUnchecked (obj: Expr, pinfo: PropertyInfo, value: Expr, args : Expr list) = + let op = instancePropSetOp.Invoke(null, [| box pinfo |]) + mkFEN.Invoke(null, [| box op; box (obj::(args@[value])) |]) :?> Expr + + static member FieldGetUnchecked (pinfo: FieldInfo) = + let op = staticFieldGetOp.Invoke(null, [| box pinfo |]) + mkFE0.Invoke(null, [| box op; |]) :?> Expr + + static member FieldGetUnchecked (obj: Expr, pinfo: FieldInfo) = + let op = instanceFieldGetOp.Invoke(null, [| box pinfo |]) + mkFE1.Invoke(null, [| box op; box obj |]) :?> Expr + + static member FieldSetUnchecked (pinfo: FieldInfo, value: Expr) = + let op = staticFieldSetOp.Invoke(null, [| box pinfo |]) + mkFE1.Invoke(null, [| box op; box value |]) :?> Expr + + static member FieldSetUnchecked (obj: Expr, pinfo: FieldInfo, value: Expr) = + let op = instanceFieldSetOp.Invoke(null, [| box pinfo |]) + mkFE2.Invoke(null, [| box op; box obj; box value |]) :?> Expr + + static member TupleGetUnchecked (e: Expr, n:int) = + let op = tupleGetOp.Invoke(null, [| box e.Type; box n |]) + mkFE1.Invoke(null, [| box op; box e |]) :?> Expr + + static member LetUnchecked (v:Var, e: Expr, body:Expr) = + let lam = Expr.Lambda(v,body) + let op = letOp.Invoke(null, [| |]) + mkFE2.Invoke(null, [| box op; box e; box lam |]) :?> Expr + + type Shape = Shape of (Expr list -> Expr) + + let (|ShapeCombinationUnchecked|ShapeVarUnchecked|ShapeLambdaUnchecked|) e = + match e with + | NewObject (cinfo, args) -> + ShapeCombinationUnchecked (Shape (function args -> Expr.NewObjectUnchecked (cinfo, args)), args) + | NewArray (ty, args) -> + ShapeCombinationUnchecked (Shape (function args -> Expr.NewArrayUnchecked (ty, args)), args) + | NewDelegate (t, vars, expr) -> + ShapeCombinationUnchecked (Shape (function [expr] -> Expr.NewDelegateUnchecked (t, vars, expr) | _ -> invalidArg "expr" "invalid shape"), [expr]) + | TupleGet (expr, n) -> + ShapeCombinationUnchecked (Shape (function [expr] -> Expr.TupleGetUnchecked (expr, n) | _ -> invalidArg "expr" "invalid shape"), [expr]) + | Application (f, x) -> + ShapeCombinationUnchecked (Shape (function [f; x] -> Expr.ApplicationUnchecked (f, x) | _ -> invalidArg "expr" "invalid shape"), [f; x]) + | Call (objOpt, minfo, args) -> + match objOpt with + | None -> ShapeCombinationUnchecked (Shape (function args -> Expr.CallUnchecked (minfo, args)), args) + | Some obj -> ShapeCombinationUnchecked (Shape (function (obj::args) -> Expr.CallUnchecked (obj, minfo, args) | _ -> invalidArg "expr" "invalid shape"), obj::args) + | PropertyGet (objOpt, pinfo, args) -> + match objOpt with + | None -> ShapeCombinationUnchecked (Shape (function args -> Expr.PropertyGetUnchecked (pinfo, args)), args) + | Some obj -> ShapeCombinationUnchecked (Shape (function (obj::args) -> Expr.PropertyGetUnchecked (obj, pinfo, args) | _ -> invalidArg "expr" "invalid shape"), obj::args) + | PropertySet (objOpt, pinfo, args, value) -> + match objOpt with + | None -> ShapeCombinationUnchecked (Shape (function (value::args) -> Expr.PropertySetUnchecked (pinfo, value, args) | _ -> invalidArg "expr" "invalid shape"), value::args) + | Some obj -> ShapeCombinationUnchecked (Shape (function (obj::value::args) -> Expr.PropertySetUnchecked (obj, pinfo, value, args) | _ -> invalidArg "expr" "invalid shape"), obj::value::args) + | FieldGet (objOpt, pinfo) -> + match objOpt with + | None -> ShapeCombinationUnchecked (Shape (function _ -> Expr.FieldGetUnchecked (pinfo)), []) + | Some obj -> ShapeCombinationUnchecked (Shape (function [obj] -> Expr.FieldGetUnchecked (obj, pinfo) | _ -> invalidArg "expr" "invalid shape"), [obj]) + | FieldSet (objOpt, pinfo, value) -> + match objOpt with + | None -> ShapeCombinationUnchecked (Shape (function [value] -> Expr.FieldSetUnchecked (pinfo, value) | _ -> invalidArg "expr" "invalid shape"), [value]) + | Some obj -> ShapeCombinationUnchecked (Shape (function [obj;value] -> Expr.FieldSetUnchecked (obj, pinfo, value) | _ -> invalidArg "expr" "invalid shape"), [obj; value]) + | Let (var, value, body) -> + ShapeCombinationUnchecked (Shape (function [value;Lambda(var, body)] -> Expr.LetUnchecked(var, value, body) | _ -> invalidArg "expr" "invalid shape"), [value; Expr.Lambda(var, body)]) + | TupleGet (expr, i) -> + ShapeCombinationUnchecked (Shape (function [expr] -> Expr.TupleGetUnchecked (expr, i) | _ -> invalidArg "expr" "invalid shape"), [expr]) + | ExprShape.ShapeCombination (comb,args) -> + ShapeCombinationUnchecked (Shape (fun args -> ExprShape.RebuildShapeCombination(comb, args)), args) + | ExprShape.ShapeVar v -> ShapeVarUnchecked v + | ExprShape.ShapeLambda (v, e) -> ShapeLambdaUnchecked (v,e) + + let RebuildShapeCombinationUnchecked (Shape comb,args) = comb args + +//-------------------------------------------------------------------------------- +// The quotation simplifier +// +// This is invoked for each quotation specified by the type provider, just before it is +// handed to the F# compiler, allowing a broader range of +// quotations to be accepted. Specifically accept: +// +// - NewTuple nodes (for generative type providers) +// - TupleGet nodes (for generative type providers) +// - array and list values as constants +// - PropertyGet and PropertySet nodes +// - Application, NewUnionCase, NewRecord, UnionCaseTest nodes +// - Let nodes (defining "byref" values) +// - LetRecursive nodes +// +// Additionally, a set of code optimizations are applied to generated code: +// - inlineRightPipe +// - optimizeCurriedApplications +// - inlineValueBindings + +type QuotationSimplifier(isGenerated: bool) = + + let rec transExpr q = + match q with + // convert NewTuple to the call to the constructor of the Tuple type (only for generated types) + | NewTuple(items) when isGenerated -> + let rec mkCtor args ty = + let ctor, restTyOpt = Reflection.FSharpValue.PreComputeTupleConstructorInfo ty + match restTyOpt with + | None -> Expr.NewObject(ctor, List.map transExpr args) + | Some restTy -> + let curr = [for a in Seq.take 7 args -> transExpr a] + let rest = List.ofSeq (Seq.skip 7 args) + Expr.NewObject(ctor, curr @ [mkCtor rest restTy]) + let tys = [| for e in items -> e.Type |] + let tupleTy = Reflection.FSharpType.MakeTupleType tys + transExpr (mkCtor items tupleTy) + // convert TupleGet to the chain of PropertyGet calls (only for generated types) + | TupleGet(e, i) when isGenerated -> + let rec mkGet ty i (e : Expr) = + let pi, restOpt = Reflection.FSharpValue.PreComputeTuplePropertyInfo(ty, i) + let propGet = Expr.PropertyGet(e, pi) + match restOpt with + | None -> propGet + | Some (restTy, restI) -> mkGet restTy restI propGet + transExpr (mkGet e.Type i (transExpr e)) + | Value(value, ty) -> + if value <> null then + let tyOfValue = value.GetType() + transValue(value, tyOfValue, ty) + else q + // Eliminate F# property gets to method calls + | PropertyGet(obj,propInfo,args) -> + match obj with + | None -> transExpr (Expr.CallUnchecked(propInfo.GetGetMethod(),args)) + | Some o -> transExpr (Expr.CallUnchecked(transExpr o,propInfo.GetGetMethod(),args)) + // Eliminate F# property sets to method calls + | PropertySet(obj,propInfo,args,v) -> + match obj with + | None -> transExpr (Expr.CallUnchecked(propInfo.GetSetMethod(),args@[v])) + | Some o -> transExpr (Expr.CallUnchecked(transExpr o,propInfo.GetSetMethod(),args@[v])) + // Eliminate F# function applications to FSharpFunc<_,_>.Invoke calls + | Application(f,e) -> + transExpr (Expr.CallUnchecked(transExpr f, f.Type.GetMethod "Invoke", [ e ]) ) + | NewUnionCase(ci, es) -> + transExpr (Expr.CallUnchecked(Reflection.FSharpValue.PreComputeUnionConstructorInfo ci, es) ) + | NewRecord(ci, es) -> + transExpr (Expr.NewObjectUnchecked(Reflection.FSharpValue.PreComputeRecordConstructorInfo ci, es) ) + | UnionCaseTest(e,uc) -> + let tagInfo = Reflection.FSharpValue.PreComputeUnionTagMemberInfo uc.DeclaringType + let tagExpr = + match tagInfo with + | :? PropertyInfo as tagProp -> + transExpr (Expr.PropertyGet(e,tagProp) ) + | :? MethodInfo as tagMeth -> + if tagMeth.IsStatic then transExpr (Expr.Call(tagMeth, [e])) + else transExpr (Expr.Call(e,tagMeth,[])) + | _ -> failwith "unreachable: unexpected result from PreComputeUnionTagMemberInfo" + let tagNumber = uc.Tag + transExpr <@@ (%%(tagExpr) : int) = tagNumber @@> + + // Explicitly handle weird byref variables in lets (used to populate out parameters), since the generic handlers can't deal with byrefs. + // + // The binding must have leaves that are themselves variables (due to the limited support for byrefs in expressions) + // therefore, we can perform inlining to translate this to a form that can be compiled + | Let(v,vexpr,bexpr) when v.Type.IsByRef -> transLetOfByref v vexpr bexpr + + // Eliminate recursive let bindings (which are unsupported by the type provider API) to regular let bindings + | LetRecursive(bindings, expr) -> transLetRec bindings expr + + // Handle the generic cases + | ShapeLambdaUnchecked(v,body) -> Expr.Lambda(v, transExpr body) + | ShapeCombinationUnchecked(comb,args) -> RebuildShapeCombinationUnchecked(comb,List.map transExpr args) + | ShapeVarUnchecked _ -> q + + and transLetRec bindings expr = + // This uses a "lets and sets" approach, converting something like + // let rec even = function + // | 0 -> true + // | n -> odd (n-1) + // and odd = function + // | 0 -> false + // | n -> even (n-1) + // X + // to something like + // let even = ref Unchecked.defaultof<_> + // let odd = ref Unchecked.defaultof<_> + // even := function + // | 0 -> true + // | n -> !odd (n-1) + // odd := function + // | 0 -> false + // | n -> !even (n-1) + // X' + // where X' is X but with occurrences of even/odd substituted by !even and !odd (since now even and odd are references) + // Translation relies on typedefof<_ ref> - does this affect ability to target different runtime and design time environments? + let vars = List.map fst bindings + let vars' = vars |> List.map (fun v -> Quotations.Var(v.Name, typedefof<_ ref>.MakeGenericType(v.Type))) + + // "init t" generates the equivalent of <@ ref Unchecked.defaultof @> + let init (t:Type) = + let r = match <@ ref 1 @> with Call(None, r, [_]) -> r | _ -> failwith "Extracting MethodInfo from <@ 1 @> failed" + let d = match <@ Unchecked.defaultof<_> @> with Call(None, d, []) -> d | _ -> failwith "Extracting MethodInfo from <@ Unchecked.defaultof<_> @> failed" + Expr.Call(r.GetGenericMethodDefinition().MakeGenericMethod(t), [Expr.Call(d.GetGenericMethodDefinition().MakeGenericMethod(t),[])]) + + // deref v generates the equivalent of <@ !v @> + // (so v's type must be ref) + let deref (v:Quotations.Var) = + let m = match <@ !(ref 1) @> with Call(None, m, [_]) -> m | _ -> failwith "Extracting MethodInfo from <@ !(ref 1) @> failed" + let tyArgs = v.Type.GetGenericArguments() + Expr.Call(m.GetGenericMethodDefinition().MakeGenericMethod(tyArgs), [Expr.Var v]) + + // substitution mapping a variable v to the expression <@ !v' @> using the corresponding new variable v' of ref type + let subst = + let map = + vars' + |> List.map deref + |> List.zip vars + |> Map.ofList + fun v -> Map.tryFind v map + + let expr' = expr.Substitute(subst) + + // maps variables to new variables + let varDict = List.zip vars vars' |> dict + + // given an old variable v and an expression e, returns a quotation like <@ v' := e @> using the corresponding new variable v' of ref type + let setRef (v:Quotations.Var) e = + let m = match <@ (ref 1) := 2 @> with Call(None, m, [_;_]) -> m | _ -> failwith "Extracting MethodInfo from <@ (ref 1) := 2 @> failed" + Expr.Call(m.GetGenericMethodDefinition().MakeGenericMethod(v.Type), [Expr.Var varDict.[v]; e]) + + // Something like + // <@ + // v1 := e1' + // v2 := e2' + // ... + // expr' + // @> + // Note that we must substitute our new variable dereferences into the bound expressions + let body = + bindings + |> List.fold (fun b (v,e) -> Expr.Sequential(setRef v (e.Substitute subst), b)) expr' + + // Something like + // let v1 = ref Unchecked.defaultof + // let v2 = ref Unchecked.defaultof + // ... + // body + vars + |> List.fold (fun b v -> Expr.LetUnchecked(varDict.[v], init v.Type, b)) body + |> transExpr + + + and transLetOfByref v vexpr bexpr = + match vexpr with + | Sequential(e',vexpr') -> + (* let v = (e'; vexpr') in bexpr => e'; let v = vexpr' in bexpr *) + Expr.Sequential(e', transLetOfByref v vexpr' bexpr) + |> transExpr + | IfThenElse(c,b1,b2) -> + (* let v = if c then b1 else b2 in bexpr => if c then let v = b1 in bexpr else let v = b2 in bexpr *) + // + // Note, this duplicates "bexpr" + Expr.IfThenElse(c, transLetOfByref v b1 bexpr, transLetOfByref v b2 bexpr) + |> transExpr + | Var _ -> + (* let v = v1 in bexpr => bexpr[v/v1] *) + bexpr.Substitute(fun v' -> if v = v' then Some vexpr else None) + |> transExpr + | _ -> + failwith (sprintf "Unexpected byref binding: %A = %A" v vexpr) + + and transValueArray (o : Array, ty : Type) = + let elemTy = ty.GetElementType() + let converter = getValueConverterForType elemTy + let elements = [ for el in o -> converter el ] + Expr.NewArrayUnchecked(elemTy, elements) + + and transValueList(o, ty : Type, nil, cons) = + let converter = getValueConverterForType (ty.GetGenericArguments().[0]) + o + |> Seq.cast + |> List.ofSeq + |> fun l -> List.foldBack(fun o s -> Expr.NewUnionCase(cons, [ converter(o); s ])) l (Expr.NewUnionCase(nil, [])) + |> transExpr + + and getValueConverterForType (ty : Type) = + if ty.IsArray then + fun (v : obj) -> transValueArray(v :?> Array, ty) + elif ty.IsGenericType && ty.GetGenericTypeDefinition() = typedefof<_ list> then + let nil, cons = + let cases = Reflection.FSharpType.GetUnionCases(ty) + let a = cases.[0] + let b = cases.[1] + if a.Name = "Empty" then a,b + else b,a + + fun v -> transValueList (v :?> System.Collections.IEnumerable, ty, nil, cons) + else + fun v -> Expr.Value(v, ty) + + and transValue (v : obj, tyOfValue : Type, expectedTy : Type) = + let converter = getValueConverterForType tyOfValue + let r = converter v + if tyOfValue <> expectedTy then Expr.Coerce(r, expectedTy) + else r + +#if NO_GENERATIVE +#else + // TODO: this works over FSharp.Core 4.4.0.0 types. These types need to be retargeted to the target runtime. + let getFastFuncType (args : list) resultType = + let types = + [| for arg in args -> arg.Type + yield resultType |] + let fastFuncTy = + match List.length args with + | 2 -> typedefof>.MakeGenericType(types) + | 3 -> typedefof>.MakeGenericType(types) + | 4 -> typedefof>.MakeGenericType(types) + | 5 -> typedefof>.MakeGenericType(types) + | _ -> invalidArg "args" "incorrect number of arguments" + fastFuncTy.GetMethod("Adapt") + + let (===) a b = LanguagePrimitives.PhysicalEquality a b + + let traverse f = + let rec fallback e = + match e with + | Let(v, value, body) -> + let fixedValue = f fallback value + let fixedBody = f fallback body + if fixedValue === value && fixedBody === body then + e + else + Expr.Let(v, fixedValue, fixedBody) + | ShapeVarUnchecked _ -> e + | ShapeLambdaUnchecked(v, body) -> + let fixedBody = f fallback body + if fixedBody === body then + e + else + Expr.Lambda(v, fixedBody) + | ShapeCombinationUnchecked(shape, exprs) -> + let exprs1 = List.map (f fallback) exprs + if List.forall2 (===) exprs exprs1 then + e + else + RebuildShapeCombinationUnchecked(shape, exprs1) + fun e -> f fallback e + + let RightPipe = <@@ (|>) @@> + let inlineRightPipe expr = + let rec loop expr = traverse loopCore expr + and loopCore fallback orig = + match orig with + | SpecificCall RightPipe (None, _, [operand; applicable]) -> + let fixedOperand = loop operand + match loop applicable with + | Lambda(arg, body) -> + let v = Quotations.Var("__temp", operand.Type) + let ev = Expr.Var v + + let fixedBody = loop body + Expr.Let(v, fixedOperand, fixedBody.Substitute(fun v1 -> if v1 = arg then Some ev else None)) + | fixedApplicable -> Expr.Application(fixedApplicable, fixedOperand) + | x -> fallback x + loop expr + + let inlineValueBindings e = + let map = Dictionary(HashIdentity.Reference) + let rec loop expr = traverse loopCore expr + and loopCore fallback orig = + match orig with + | Let(id, (Value(_) as v), body) when not id.IsMutable -> + map.[id] <- v + let fixedBody = loop body + map.Remove(id) |> ignore + fixedBody + | ShapeVarUnchecked v -> + match map.TryGetValue v with + | true, e -> e + | _ -> orig + | x -> fallback x + loop e + + + let optimizeCurriedApplications expr = + let rec loop expr = traverse loopCore expr + and loopCore fallback orig = + match orig with + | Application(e, arg) -> + let e1 = tryPeelApplications e [loop arg] + if e1 === e then + orig + else + e1 + | x -> fallback x + and tryPeelApplications orig args = + let n = List.length args + match orig with + | Application(e, arg) -> + let e1 = tryPeelApplications e ((loop arg)::args) + if e1 === e then + orig + else + e1 + | Let(id, applicable, (Lambda(_) as body)) when n > 0 -> + let numberOfApplication = countPeelableApplications body id 0 + if numberOfApplication = 0 then orig + elif n = 1 then Expr.Application(applicable, List.head args) + elif n <= 5 then + let resultType = + applicable.Type + |> Seq.unfold (fun t -> + if not t.IsGenericType then None else + let args = t.GetGenericArguments() + if args.Length <> 2 then None else + Some (args.[1], args.[1]) + ) + |> Seq.toArray + |> (fun arr -> arr.[n - 1]) + + let adaptMethod = getFastFuncType args resultType + let adapted = Expr.Call(adaptMethod, [loop applicable]) + let invoke = adapted.Type.GetMethod("Invoke", [| for arg in args -> arg.Type |]) + Expr.Call(adapted, invoke, args) + else + (applicable, args) ||> List.fold (fun e a -> Expr.Application(e, a)) + | _ -> + orig + and countPeelableApplications expr v n = + match expr with + // v - applicable entity obtained on the prev step + // \arg -> let v1 = (f arg) in rest ==> f + | Lambda(arg, Let(v1, Application(Var f, Var arg1), rest)) when v = f && arg = arg1 -> countPeelableApplications rest v1 (n + 1) + // \arg -> (f arg) ==> f + | Lambda(arg, Application(Var f, Var arg1)) when v = f && arg = arg1 -> n + | _ -> n + loop expr +#endif + + member __.TranslateExpression q = transExpr q + + member __.TranslateQuotationToCode qexprf (paramNames: string[]) (argExprs: Expr[]) = + // Use the real variable names instead of indices, to improve output of Debug.fs + // Add let bindings for arguments to ensure that arguments will be evaluated + let vars = argExprs |> Array.mapi (fun i e -> Quotations.Var(paramNames.[i], e.Type)) + let expr = qexprf ([for v in vars -> Expr.Var v]) + + let pairs = Array.zip argExprs vars + let expr = Array.foldBack (fun (arg, var) e -> Expr.LetUnchecked(var, arg, e)) pairs expr +#if NO_GENERATIVE +#else + let expr = + if isGenerated then + let e1 = inlineRightPipe expr + let e2 = optimizeCurriedApplications e1 + let e3 = inlineValueBindings e2 + e3 + else + expr +#endif + + transExpr expr + +//------------------------------------------------------------------------------------------------- +// Generate IL code from quotations + + +#if NO_GENERATIVE +#else + +type internal ExpectedStackState = + | Empty = 1 + | Address = 2 + | Value = 3 + +type CodeGenerator(assemblyMainModule: ModuleBuilder, uniqueLambdaTypeName, + implicitCtorArgsAsFields: FieldBuilder list, + transType: Type -> Type, + transField: FieldInfo -> FieldInfo, + transMethod: MethodInfo -> MethodInfo, + transCtor: ConstructorInfo -> ConstructorInfo, + isLiteralEnumField: FieldInfo -> bool, + ilg: ILGenerator, locals:Dictionary, parameterVars) = + + let TypeBuilderInstantiationType = + let runningOnMono = try System.Type.GetType("Mono.Runtime") <> null with e -> false + let typeName = if runningOnMono then "System.Reflection.MonoGenericClass" else "System.Reflection.Emit.TypeBuilderInstantiation" + typeof.Assembly.GetType(typeName) + + // TODO: this works over FSharp.Core 4.4.0.0 types and methods. These types need to be retargeted to the target runtime. + + let GetTypeFromHandleMethod() = typeof.GetMethod("GetTypeFromHandle") + let LanguagePrimitivesType() = typedefof>.Assembly.GetType("Microsoft.FSharp.Core.LanguagePrimitives") + let ParseInt32Method() = LanguagePrimitivesType().GetMethod "ParseInt32" + let DecimalConstructor() = typeof.GetConstructor([| typeof; typeof; typeof; typeof; typeof |]) + let DateTimeConstructor() = typeof.GetConstructor([| typeof; typeof |]) + let DateTimeOffsetConstructor() = typeof.GetConstructor([| typeof; typeof |]) + let TimeSpanConstructor() = typeof.GetConstructor([|typeof|]) + + let isEmpty s = (s = ExpectedStackState.Empty) + let isAddress s = (s = ExpectedStackState.Address) + let rec emitLambda(callSiteIlg : ILGenerator, v : Quotations.Var, body : Expr, freeVars : seq, locals : Dictionary<_, LocalBuilder>, parameters) = + let lambda = assemblyMainModule.DefineType(uniqueLambdaTypeName(), TypeAttributes.Class) + let baseType = typedefof>.MakeGenericType(v.Type, body.Type) + lambda.SetParent(baseType) + let ctor = lambda.DefineDefaultConstructor(MethodAttributes.Public) + let decl = baseType.GetMethod "Invoke" + let paramTypes = [| for p in decl.GetParameters() -> p.ParameterType |] + let invoke = lambda.DefineMethod("Invoke", MethodAttributes.Virtual ||| MethodAttributes.Final ||| MethodAttributes.Public, decl.ReturnType, paramTypes) + lambda.DefineMethodOverride(invoke, decl) + + // promote free vars to fields + let fields = ResizeArray() + for v in freeVars do + let f = lambda.DefineField(v.Name, v.Type, FieldAttributes.Assembly) + fields.Add(v, f) + + let lambdaLocals = Dictionary() + + let ilg = invoke.GetILGenerator() + for (v, f) in fields do + let l = ilg.DeclareLocal(v.Type) + ilg.Emit(OpCodes.Ldarg_0) + ilg.Emit(OpCodes.Ldfld, f) + ilg.Emit(OpCodes.Stloc, l) + lambdaLocals.[v] <- l + + let expectedState = if (invoke.ReturnType = typeof) then ExpectedStackState.Empty else ExpectedStackState.Value + let lambadParamVars = [| Quotations.Var("this", lambda); v|] + let codeGen = CodeGenerator(assemblyMainModule, uniqueLambdaTypeName, implicitCtorArgsAsFields, transType, transField, transMethod, transCtor, isLiteralEnumField, ilg, lambdaLocals, lambadParamVars) + codeGen.EmitExpr (expectedState, body) + ilg.Emit(OpCodes.Ret) + + lambda.CreateType() |> ignore + + callSiteIlg.Emit(OpCodes.Newobj, ctor) + for (v, f) in fields do + callSiteIlg.Emit(OpCodes.Dup) + match locals.TryGetValue v with + | true, loc -> + callSiteIlg.Emit(OpCodes.Ldloc, loc) + | false, _ -> + let index = parameters |> Array.findIndex ((=) v) + callSiteIlg.Emit(OpCodes.Ldarg, index) + callSiteIlg.Emit(OpCodes.Stfld, f) + + and emitExpr expectedState expr = + let pop () = ilg.Emit(OpCodes.Pop) + let popIfEmptyExpected s = if isEmpty s then pop() + let emitConvIfNecessary t1 = + if t1 = typeof then + ilg.Emit(OpCodes.Conv_I2) + elif t1 = typeof then + ilg.Emit(OpCodes.Conv_U2) + elif t1 = typeof then + ilg.Emit(OpCodes.Conv_I1) + elif t1 = typeof then + ilg.Emit(OpCodes.Conv_U1) + + /// emits given expression to corresponding IL + match expr with + | ForIntegerRangeLoop(loopVar, first, last, body) -> + // for(loopVar = first..last) body + let lb = + match locals.TryGetValue loopVar with + | true, lb -> lb + | false, _ -> + let lb = ilg.DeclareLocal(transType loopVar.Type) + locals.Add(loopVar, lb) + lb + + // loopVar = first + emitExpr ExpectedStackState.Value first + ilg.Emit(OpCodes.Stloc, lb) + + let before = ilg.DefineLabel() + let after = ilg.DefineLabel() + + ilg.MarkLabel before + ilg.Emit(OpCodes.Ldloc, lb) + + emitExpr ExpectedStackState.Value last + ilg.Emit(OpCodes.Bgt, after) + + emitExpr ExpectedStackState.Empty body + + // loopVar++ + ilg.Emit(OpCodes.Ldloc, lb) + ilg.Emit(OpCodes.Ldc_I4_1) + ilg.Emit(OpCodes.Add) + ilg.Emit(OpCodes.Stloc, lb) + + ilg.Emit(OpCodes.Br, before) + ilg.MarkLabel(after) + + | NewArray(elementTy, elements) -> + ilg.Emit(OpCodes.Ldc_I4, List.length elements) + ilg.Emit(OpCodes.Newarr, transType elementTy) + + elements + |> List.iteri (fun i el -> + ilg.Emit(OpCodes.Dup) + ilg.Emit(OpCodes.Ldc_I4, i) + emitExpr ExpectedStackState.Value el + ilg.Emit(OpCodes.Stelem, transType elementTy)) + + popIfEmptyExpected expectedState + + | WhileLoop(cond, body) -> + let before = ilg.DefineLabel() + let after = ilg.DefineLabel() + + ilg.MarkLabel before + emitExpr ExpectedStackState.Value cond + ilg.Emit(OpCodes.Brfalse, after) + emitExpr ExpectedStackState.Empty body + ilg.Emit(OpCodes.Br, before) + + ilg.MarkLabel after + + | Var v -> + if isEmpty expectedState then () else + + // Try to interpret this as a method parameter + let methIdx = parameterVars |> Array.tryFindIndex (fun p -> p = v) + match methIdx with + | Some idx -> + ilg.Emit((if isAddress expectedState then OpCodes.Ldarga else OpCodes.Ldarg), idx) + | None -> + + // Try to interpret this as an implicit field in a class + let implicitCtorArgFieldOpt = implicitCtorArgsAsFields |> List.tryFind (fun f -> f.Name = v.Name) + match implicitCtorArgFieldOpt with + | Some ctorArgField -> + ilg.Emit(OpCodes.Ldarg_0) + ilg.Emit(OpCodes.Ldfld, ctorArgField) + | None -> + + // Try to interpret this as a local + match locals.TryGetValue v with + | true, localBuilder -> + ilg.Emit((if isAddress expectedState then OpCodes.Ldloca else OpCodes.Ldloc), localBuilder.LocalIndex) + | false, _ -> + failwith "unknown parameter/field" + + | Coerce (arg,ty) -> + // castClass may lead to observable side-effects - InvalidCastException + emitExpr ExpectedStackState.Value arg + let argTy = transType arg.Type + let targetTy = transType ty + if argTy.IsValueType && not targetTy.IsValueType then + ilg.Emit(OpCodes.Box, argTy) + elif not argTy.IsValueType && targetTy.IsValueType then + ilg.Emit(OpCodes.Unbox_Any, targetTy) + // emit castclass if + // - targettype is not obj (assume this is always possible for ref types) + // AND + // - HACK: targettype is TypeBuilderInstantiationType + // (its implementation of IsAssignableFrom raises NotSupportedException so it will be safer to always emit castclass) + // OR + // - not (argTy :> targetTy) + elif targetTy <> typeof && (TypeBuilderInstantiationType.Equals(targetTy.GetType()) || not (targetTy.IsAssignableFrom(argTy))) then + ilg.Emit(OpCodes.Castclass, targetTy) + + popIfEmptyExpected expectedState + + | SpecificCall <@ (-) @>(None, [t1; t2; _], [a1; a2]) -> + assert(t1 = t2) + emitExpr ExpectedStackState.Value a1 + emitExpr ExpectedStackState.Value a2 + if t1 = typeof then + ilg.Emit(OpCodes.Call, typeof.GetMethod "op_Subtraction") + else + ilg.Emit(OpCodes.Sub) + emitConvIfNecessary t1 + + popIfEmptyExpected expectedState + + | SpecificCall <@ (/) @> (None, [t1; t2; _], [a1; a2]) -> + assert (t1 = t2) + emitExpr ExpectedStackState.Value a1 + emitExpr ExpectedStackState.Value a2 + if t1 = typeof then + ilg.Emit(OpCodes.Call, typeof.GetMethod "op_Division") + else + match Type.GetTypeCode t1 with + | TypeCode.UInt32 + | TypeCode.UInt64 + | TypeCode.UInt16 + | TypeCode.Byte + | _ when t1 = typeof -> ilg.Emit (OpCodes.Div_Un) + | _ -> ilg.Emit(OpCodes.Div) + + emitConvIfNecessary t1 + + popIfEmptyExpected expectedState + + | SpecificCall <@ int @>(None, [sourceTy], [v]) -> + emitExpr ExpectedStackState.Value v + match Type.GetTypeCode(sourceTy) with + | TypeCode.String -> + ilg.Emit(OpCodes.Call, ParseInt32Method()) + | TypeCode.Single + | TypeCode.Double + | TypeCode.Int64 + | TypeCode.UInt64 + | TypeCode.UInt16 + | TypeCode.Char + | TypeCode.Byte + | _ when sourceTy = typeof || sourceTy = typeof -> + ilg.Emit(OpCodes.Conv_I4) + | TypeCode.Int32 + | TypeCode.UInt32 + | TypeCode.Int16 + | TypeCode.SByte -> () // no op + | _ -> failwith "TODO: search for op_Explicit on sourceTy" + + | SpecificCall <@ LanguagePrimitives.IntrinsicFunctions.GetArray @> (None, [ty], [arr; index]) -> + // observable side-effect - IndexOutOfRangeException + emitExpr ExpectedStackState.Value arr + emitExpr ExpectedStackState.Value index + if isAddress expectedState then + ilg.Emit(OpCodes.Readonly) + ilg.Emit(OpCodes.Ldelema, transType ty) + else + ilg.Emit(OpCodes.Ldelem, transType ty) + + popIfEmptyExpected expectedState + + | SpecificCall <@ LanguagePrimitives.IntrinsicFunctions.GetArray2D @> (None, _ty, arr::indices) + | SpecificCall <@ LanguagePrimitives.IntrinsicFunctions.GetArray3D @> (None, _ty, arr::indices) + | SpecificCall <@ LanguagePrimitives.IntrinsicFunctions.GetArray4D @> (None, _ty, arr::indices) -> + + let meth = + let name = if isAddress expectedState then "Address" else "Get" + arr.Type.GetMethod(name) + + // observable side-effect - IndexOutOfRangeException + emitExpr ExpectedStackState.Value arr + for index in indices do + emitExpr ExpectedStackState.Value index + + if isAddress expectedState then + ilg.Emit(OpCodes.Readonly) + + ilg.Emit(OpCodes.Call, meth) + + popIfEmptyExpected expectedState + + + | FieldGet (None,field) when isLiteralEnumField field -> + if expectedState <> ExpectedStackState.Empty then + emitExpr expectedState (Expr.Value(field.GetRawConstantValue(), field.FieldType.GetEnumUnderlyingType())) + + | FieldGet (objOpt,field) -> + objOpt |> Option.iter (fun e -> + let s = if e.Type.IsValueType then ExpectedStackState.Address else ExpectedStackState.Value + emitExpr s e) + let field = transField field + if field.IsStatic then + ilg.Emit(OpCodes.Ldsfld, field) + else + ilg.Emit(OpCodes.Ldfld, field) + + | FieldSet (objOpt,field,v) -> + objOpt |> Option.iter (fun e -> + let s = if e.Type.IsValueType then ExpectedStackState.Address else ExpectedStackState.Value + emitExpr s e) + emitExpr ExpectedStackState.Value v + let field = transField field + if field.IsStatic then + ilg.Emit(OpCodes.Stsfld, field) + else + ilg.Emit(OpCodes.Stfld, field) + + | Call (objOpt,meth,args) -> + objOpt |> Option.iter (fun e -> + let s = if e.Type.IsValueType then ExpectedStackState.Address else ExpectedStackState.Value + emitExpr s e) + + for pe in args do + emitExpr ExpectedStackState.Value pe + + // Handle the case where this is a generic method instantiated at a type being compiled + let mappedMeth = + if meth.IsGenericMethod then + let args = meth.GetGenericArguments() |> Array.map transType + let gmd = meth.GetGenericMethodDefinition() |> transMethod + gmd.GetGenericMethodDefinition().MakeGenericMethod args + elif meth.DeclaringType.IsGenericType then + let gdty = transType (meth.DeclaringType.GetGenericTypeDefinition()) + let gdtym = gdty.GetMethods() |> Seq.find (fun x -> x.Name = meth.Name) + assert (gdtym <> null) // ?? will never happen - if method is not found - KeyNotFoundException will be raised + let dtym = + match transType meth.DeclaringType with + | :? TypeBuilder as dty -> TypeBuilder.GetMethod(dty, gdtym) + | dty -> MethodBase.GetMethodFromHandle(meth.MethodHandle, dty.TypeHandle) :?> _ + + assert (dtym <> null) + dtym + else + transMethod meth + match objOpt with + | Some obj when mappedMeth.IsAbstract || mappedMeth.IsVirtual -> + if obj.Type.IsValueType then ilg.Emit(OpCodes.Constrained, transType obj.Type) + ilg.Emit(OpCodes.Callvirt, mappedMeth) + | _ -> + ilg.Emit(OpCodes.Call, mappedMeth) + + let returnTypeIsVoid = mappedMeth.ReturnType = typeof + match returnTypeIsVoid, (isEmpty expectedState) with + | false, true -> + // method produced something, but we don't need it + pop() + | true, false when expr.Type = typeof -> + // if we need result and method produce void and result should be unit - push null as unit value on stack + ilg.Emit(OpCodes.Ldnull) + | _ -> () + + | NewObject (ctor,args) -> + for pe in args do + emitExpr ExpectedStackState.Value pe + let meth = transCtor ctor + ilg.Emit(OpCodes.Newobj, meth) + + popIfEmptyExpected expectedState + + | Value (obj, _ty) -> + let rec emitC (v:obj) = + match v with + | :? string as x -> ilg.Emit(OpCodes.Ldstr, x) + | :? int8 as x -> ilg.Emit(OpCodes.Ldc_I4, int32 x) + | :? uint8 as x -> ilg.Emit(OpCodes.Ldc_I4, int32 (int8 x)) + | :? int16 as x -> ilg.Emit(OpCodes.Ldc_I4, int32 x) + | :? uint16 as x -> ilg.Emit(OpCodes.Ldc_I4, int32 (int16 x)) + | :? int32 as x -> ilg.Emit(OpCodes.Ldc_I4, x) + | :? uint32 as x -> ilg.Emit(OpCodes.Ldc_I4, int32 x) + | :? int64 as x -> ilg.Emit(OpCodes.Ldc_I8, x) + | :? uint64 as x -> ilg.Emit(OpCodes.Ldc_I8, int64 x) + | :? char as x -> ilg.Emit(OpCodes.Ldc_I4, int32 x) + | :? bool as x -> ilg.Emit(OpCodes.Ldc_I4, if x then 1 else 0) + | :? float32 as x -> ilg.Emit(OpCodes.Ldc_R4, x) + | :? float as x -> ilg.Emit(OpCodes.Ldc_R8, x) +#if FX_NO_GET_ENUM_UNDERLYING_TYPE +#else + | :? System.Enum as x when x.GetType().GetEnumUnderlyingType() = typeof -> ilg.Emit(OpCodes.Ldc_I4, unbox v) +#endif + | :? Type as ty -> + ilg.Emit(OpCodes.Ldtoken, transType ty) + ilg.Emit(OpCodes.Call, GetTypeFromHandleMethod()) + | :? decimal as x -> + let bits = System.Decimal.GetBits x + ilg.Emit(OpCodes.Ldc_I4, bits.[0]) + ilg.Emit(OpCodes.Ldc_I4, bits.[1]) + ilg.Emit(OpCodes.Ldc_I4, bits.[2]) + do + let sign = (bits.[3] &&& 0x80000000) <> 0 + ilg.Emit(if sign then OpCodes.Ldc_I4_1 else OpCodes.Ldc_I4_0) + do + let scale = byte ((bits.[3] >>> 16) &&& 0x7F) + ilg.Emit(OpCodes.Ldc_I4_S, scale) + ilg.Emit(OpCodes.Newobj, DecimalConstructor()) + | :? DateTime as x -> + ilg.Emit(OpCodes.Ldc_I8, x.Ticks) + ilg.Emit(OpCodes.Ldc_I4, int x.Kind) + ilg.Emit(OpCodes.Newobj, DateTimeConstructor()) + | :? DateTimeOffset as x -> + ilg.Emit(OpCodes.Ldc_I8, x.Ticks) + ilg.Emit(OpCodes.Ldc_I8, x.Offset.Ticks) + ilg.Emit(OpCodes.Newobj, TimeSpanConstructor()) + ilg.Emit(OpCodes.Newobj, DateTimeOffsetConstructor()) + | null -> ilg.Emit(OpCodes.Ldnull) + | _ -> failwithf "unknown constant '%A' in generated method" v + if isEmpty expectedState then () + else emitC obj + + | Let(v,e,b) -> + let lb = ilg.DeclareLocal (transType v.Type) + locals.Add (v, lb) + emitExpr ExpectedStackState.Value e + ilg.Emit(OpCodes.Stloc, lb.LocalIndex) + emitExpr expectedState b + + | Sequential(e1, e2) -> + emitExpr ExpectedStackState.Empty e1 + emitExpr expectedState e2 + + | IfThenElse(cond, ifTrue, ifFalse) -> + let ifFalseLabel = ilg.DefineLabel() + let endLabel = ilg.DefineLabel() + + emitExpr ExpectedStackState.Value cond + + ilg.Emit(OpCodes.Brfalse, ifFalseLabel) + + emitExpr expectedState ifTrue + ilg.Emit(OpCodes.Br, endLabel) + + ilg.MarkLabel(ifFalseLabel) + emitExpr expectedState ifFalse + + ilg.Emit(OpCodes.Nop) + ilg.MarkLabel(endLabel) + + | TryWith(body, _filterVar, _filterBody, catchVar, catchBody) -> + + let stres, ldres = + if isEmpty expectedState then ignore, ignore + else + let local = ilg.DeclareLocal (transType body.Type) + let stres = fun () -> ilg.Emit(OpCodes.Stloc, local) + let ldres = fun () -> ilg.Emit(OpCodes.Ldloc, local) + stres, ldres + + let exceptionVar = ilg.DeclareLocal(transType catchVar.Type) + locals.Add(catchVar, exceptionVar) + + let _exnBlock = ilg.BeginExceptionBlock() + + emitExpr expectedState body + stres() + + ilg.BeginCatchBlock(transType catchVar.Type) + ilg.Emit(OpCodes.Stloc, exceptionVar) + emitExpr expectedState catchBody + stres() + ilg.EndExceptionBlock() + + ldres() + + | VarSet(v,e) -> + emitExpr ExpectedStackState.Value e + match locals.TryGetValue v with + | true, localBuilder -> + ilg.Emit(OpCodes.Stloc, localBuilder.LocalIndex) + | false, _ -> + failwith "unknown parameter/field in assignment. Only assignments to locals are currently supported by TypeProviderEmit" + | Lambda(v, body) -> + emitLambda(ilg, v, body, expr.GetFreeVars(), locals, parameterVars) + popIfEmptyExpected expectedState + | n -> + failwith (sprintf "unknown expression '%A' in generated method" n) + + member __.EmitExpr (expectedState, expr) = emitExpr expectedState expr + +#endif + +[] +module internal Misc = + + + let nonNull str x = if x=null then failwith ("Null in " + str) else x + + let notRequired opname item = + let msg = sprintf "The operation '%s' on item '%s' should not be called on provided type, member or parameter" opname item + System.Diagnostics.Debug.Assert (false, msg) + raise (System.NotSupportedException msg) + + let mkParamArrayCustomAttributeData() = +#if FX_NO_CUSTOMATTRIBUTEDATA + { new IProvidedCustomAttributeData with +#else + { new CustomAttributeData() with +#endif + member __.Constructor = typeof.GetConstructors().[0] + member __.ConstructorArguments = upcast [| |] + member __.NamedArguments = upcast [| |] } + +#if FX_NO_CUSTOMATTRIBUTEDATA + let CustomAttributeTypedArgument(ty,v) = + { new IProvidedCustomAttributeTypedArgument with + member x.ArgumentType = ty + member x.Value = v } + let CustomAttributeNamedArgument(memb,arg:IProvidedCustomAttributeTypedArgument) = + { new IProvidedCustomAttributeNamedArgument with + member x.MemberInfo = memb + member x.ArgumentType = arg.ArgumentType + member x.TypedValue = arg } + type CustomAttributeData = Microsoft.FSharp.Core.CompilerServices.IProvidedCustomAttributeData +#endif + + let mkEditorHideMethodsCustomAttributeData() = +#if FX_NO_CUSTOMATTRIBUTEDATA + { new IProvidedCustomAttributeData with +#else + { new CustomAttributeData() with +#endif + member __.Constructor = typeof.GetConstructors().[0] + member __.ConstructorArguments = upcast [| |] + member __.NamedArguments = upcast [| |] } + + let mkAllowNullLiteralCustomAttributeData value = +#if FX_NO_CUSTOMATTRIBUTEDATA + { new IProvidedCustomAttributeData with +#else + { new CustomAttributeData() with +#endif + member __.Constructor = typeof.GetConstructors().[0] + member __.ConstructorArguments = upcast [| CustomAttributeTypedArgument(typeof, value) |] + member __.NamedArguments = upcast [| |] } + + /// This makes an xml doc attribute w.r.t. an amortized computation of an xml doc string. + /// It is important that the text of the xml doc only get forced when poking on the ConstructorArguments + /// for the CustomAttributeData object. + let mkXmlDocCustomAttributeDataLazy(lazyText: Lazy) = +#if FX_NO_CUSTOMATTRIBUTEDATA + { new IProvidedCustomAttributeData with +#else + { new CustomAttributeData() with +#endif + member __.Constructor = typeof.GetConstructors().[0] + member __.ConstructorArguments = upcast [| CustomAttributeTypedArgument(typeof, lazyText.Force()) |] + member __.NamedArguments = upcast [| |] } + + let mkXmlDocCustomAttributeData(s:string) = mkXmlDocCustomAttributeDataLazy (lazy s) + + let mkDefinitionLocationAttributeCustomAttributeData(line:int,column:int,filePath:string) = +#if FX_NO_CUSTOMATTRIBUTEDATA + { new IProvidedCustomAttributeData with +#else + { new CustomAttributeData() with +#endif + member __.Constructor = typeof.GetConstructors().[0] + member __.ConstructorArguments = upcast [| |] + member __.NamedArguments = + upcast [| CustomAttributeNamedArgument(typeof.GetProperty("FilePath"), CustomAttributeTypedArgument(typeof, filePath)); + CustomAttributeNamedArgument(typeof.GetProperty("Line"), CustomAttributeTypedArgument(typeof, line)) ; + CustomAttributeNamedArgument(typeof.GetProperty("Column"), CustomAttributeTypedArgument(typeof, column)) + |] } + let mkObsoleteAttributeCustomAttributeData(message:string, isError: bool) = +#if FX_NO_CUSTOMATTRIBUTEDATA + { new IProvidedCustomAttributeData with +#else + { new CustomAttributeData() with +#endif + member __.Constructor = typeof.GetConstructors() |> Array.find (fun x -> x.GetParameters().Length = 2) + member __.ConstructorArguments = upcast [|CustomAttributeTypedArgument(typeof, message) ; CustomAttributeTypedArgument(typeof, isError) |] + member __.NamedArguments = upcast [| |] } + + let mkReflectedDefinitionCustomAttributeData() = +#if FX_NO_CUSTOMATTRIBUTEDATA + { new IProvidedCustomAttributeData with +#else + { new CustomAttributeData() with +#endif + member __.Constructor = typeof.GetConstructors().[0] + member __.ConstructorArguments = upcast [| |] + member __.NamedArguments = upcast [| |] } + + type CustomAttributesImpl() = + let customAttributes = ResizeArray() + let mutable hideObjectMethods = false + let mutable nonNullable = false + let mutable obsoleteMessage = None + let mutable xmlDocDelayed = None + let mutable xmlDocAlwaysRecomputed = None + let mutable hasParamArray = false + let mutable hasReflectedDefinition = false + + // XML doc text that we only compute once, if any. This must _not_ be forced until the ConstructorArguments + // property of the custom attribute is foced. + let xmlDocDelayedText = + lazy + (match xmlDocDelayed with None -> assert false; "" | Some f -> f()) + + // Custom atttributes that we only compute once + let customAttributesOnce = + lazy + [| if hideObjectMethods then yield mkEditorHideMethodsCustomAttributeData() + if nonNullable then yield mkAllowNullLiteralCustomAttributeData false + match xmlDocDelayed with None -> () | Some _ -> customAttributes.Add(mkXmlDocCustomAttributeDataLazy xmlDocDelayedText) + match obsoleteMessage with None -> () | Some s -> customAttributes.Add(mkObsoleteAttributeCustomAttributeData s) + if hasParamArray then yield mkParamArrayCustomAttributeData() + if hasReflectedDefinition then yield mkReflectedDefinitionCustomAttributeData() + yield! customAttributes |] + + member __.AddDefinitionLocation(line:int,column:int,filePath:string) = customAttributes.Add(mkDefinitionLocationAttributeCustomAttributeData(line, column, filePath)) + member __.AddObsolete(message : string, isError) = obsoleteMessage <- Some (message,isError) + member __.HasParamArray with get() = hasParamArray and set(v) = hasParamArray <- v + member __.HasReflectedDefinition with get() = hasReflectedDefinition and set(v) = hasReflectedDefinition <- v + member __.AddXmlDocComputed xmlDocFunction = xmlDocAlwaysRecomputed <- Some xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = xmlDocDelayed <- Some xmlDocFunction + member __.AddXmlDoc xmlDoc = xmlDocDelayed <- Some (fun () -> xmlDoc) + member __.HideObjectMethods with set v = hideObjectMethods <- v + member __.NonNullable with set v = nonNullable <- v + member __.AddCustomAttribute(attribute) = customAttributes.Add(attribute) + member __.GetCustomAttributesData() = + [| yield! customAttributesOnce.Force() + match xmlDocAlwaysRecomputed with None -> () | Some f -> customAttributes.Add(mkXmlDocCustomAttributeData (f())) |] + :> IList<_> + + + let adjustTypeAttributes attributes isNested = + let visibilityAttributes = + match attributes &&& TypeAttributes.VisibilityMask with + | TypeAttributes.Public when isNested -> TypeAttributes.NestedPublic + | TypeAttributes.NotPublic when isNested -> TypeAttributes.NestedAssembly + | TypeAttributes.NestedPublic when not isNested -> TypeAttributes.Public + | TypeAttributes.NestedAssembly + | TypeAttributes.NestedPrivate + | TypeAttributes.NestedFamORAssem + | TypeAttributes.NestedFamily + | TypeAttributes.NestedFamANDAssem when not isNested -> TypeAttributes.NotPublic + | a -> a + (attributes &&& ~~~TypeAttributes.VisibilityMask) ||| visibilityAttributes + + + +type ProvidedStaticParameter(parameterName:string,parameterType:Type,?parameterDefaultValue:obj) = + inherit System.Reflection.ParameterInfo() + + let customAttributesImpl = CustomAttributesImpl() + + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + + override __.RawDefaultValue = defaultArg parameterDefaultValue null + override __.Attributes = if parameterDefaultValue.IsNone then enum 0 else ParameterAttributes.Optional + override __.Position = 0 + override __.ParameterType = parameterType + override __.Name = parameterName + + override __.GetCustomAttributes(_inherit) = ignore(_inherit); notRequired "GetCustomAttributes" parameterName + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" parameterName + +type ProvidedParameter(name:string,parameterType:Type,?isOut:bool,?optionalValue:obj) = + inherit System.Reflection.ParameterInfo() + let customAttributesImpl = CustomAttributesImpl() + let isOut = defaultArg isOut false + member __.IsParamArray with get() = customAttributesImpl.HasParamArray and set(v) = customAttributesImpl.HasParamArray <- v + member __.IsReflectedDefinition with get() = customAttributesImpl.HasReflectedDefinition and set(v) = customAttributesImpl.HasReflectedDefinition <- v + override __.Name = name + override __.ParameterType = parameterType + override __.Attributes = (base.Attributes ||| (if isOut then ParameterAttributes.Out else enum 0) + ||| (match optionalValue with None -> enum 0 | Some _ -> ParameterAttributes.Optional ||| ParameterAttributes.HasDefault)) + override __.RawDefaultValue = defaultArg optionalValue null + member __.HasDefaultParameterValue = Option.isSome optionalValue + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + +type ProvidedConstructor(parameters : ProvidedParameter list) = + inherit ConstructorInfo() + let parameters = parameters |> List.map (fun p -> p :> ParameterInfo) + let mutable baseCall = None + + let mutable declaringType = null : System.Type + let mutable invokeCode = None : option Expr> + let mutable isImplicitCtor = false + let mutable ctorAttributes = MethodAttributes.Public ||| MethodAttributes.RTSpecialName + let nameText () = sprintf "constructor for %s" (if declaringType=null then "" else declaringType.FullName) + let isStatic() = ctorAttributes.HasFlag(MethodAttributes.Static) + + let customAttributesImpl = CustomAttributesImpl() + member __.IsTypeInitializer + with get() = isStatic() && ctorAttributes.HasFlag(MethodAttributes.Private) + and set(v) = + let typeInitializerAttributes = MethodAttributes.Static ||| MethodAttributes.Private + ctorAttributes <- if v then ctorAttributes ||| typeInitializerAttributes else ctorAttributes &&& ~~~typeInitializerAttributes + + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + member __.AddObsoleteAttribute (message,?isError) = customAttributesImpl.AddObsolete (message,defaultArg isError false) + member __.AddDefinitionLocation(line,column,filePath) = customAttributesImpl.AddDefinitionLocation(line, column, filePath) + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + + member __.DeclaringTypeImpl + with set x = + if declaringType<>null then failwith (sprintf "ProvidedConstructor: declaringType already set on '%s'" (nameText())); + declaringType <- x + + member __.InvokeCode + with set (q:Expr list -> Expr) = + match invokeCode with + | None -> invokeCode <- Some q + | Some _ -> failwith (sprintf "ProvidedConstructor: code already given for '%s'" (nameText())) + + member __.BaseConstructorCall + with set (d:Expr list -> (ConstructorInfo * Expr list)) = + match baseCall with + | None -> baseCall <- Some d + | Some _ -> failwith (sprintf "ProvidedConstructor: base call already given for '%s'" (nameText())) + + member __.GetInvokeCodeInternal isGenerated = + match invokeCode with + | Some f -> + // FSharp.Data change: use the real variable names instead of indices, to improve output of Debug.fs + let paramNames = + parameters + |> List.map (fun p -> p.Name) + |> List.append (if not isGenerated || isStatic() then [] else ["this"]) + |> Array.ofList + QuotationSimplifier(isGenerated).TranslateQuotationToCode f paramNames + | None -> failwith (sprintf "ProvidedConstructor: no invoker for '%s'" (nameText())) + + member __.GetBaseConstructorCallInternal isGenerated = + match baseCall with + | Some f -> Some(fun ctorArgs -> let c,baseCtorArgExprs = f ctorArgs in c, List.map (QuotationSimplifier(isGenerated).TranslateExpression) baseCtorArgExprs) + | None -> None + + member __.IsImplicitCtor with get() = isImplicitCtor and set v = isImplicitCtor <- v + + // Implement overloads + override __.GetParameters() = parameters |> List.toArray + override __.Attributes = ctorAttributes + override __.Name = if isStatic() then ".cctor" else ".ctor" + override __.DeclaringType = declaringType |> nonNull "ProvidedConstructor.DeclaringType" + override __.IsDefined(_attributeType, _inherit) = true + + override __.Invoke(_invokeAttr, _binder, _parameters, _culture) = notRequired "Invoke" (nameText()) + override __.Invoke(_obj, _invokeAttr, _binder, _parameters, _culture) = notRequired "Invoke" (nameText()) + override __.ReflectedType = notRequired "ReflectedType" (nameText()) + override __.GetMethodImplementationFlags() = notRequired "GetMethodImplementationFlags" (nameText()) + override __.MethodHandle = notRequired "MethodHandle" (nameText()) + override __.GetCustomAttributes(_inherit) = notRequired "GetCustomAttributes" (nameText()) + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" (nameText()) + +type ProvidedMethod(methodName: string, parameters: ProvidedParameter list, returnType: Type) = + inherit System.Reflection.MethodInfo() + let argParams = parameters |> List.map (fun p -> p :> ParameterInfo) + + // State + let mutable declaringType : Type = null + let mutable methodAttrs = MethodAttributes.Public + let mutable invokeCode = None : option Expr> + let mutable staticParams = [ ] + let mutable staticParamsApply = None + let isStatic() = methodAttrs.HasFlag(MethodAttributes.Static) + let customAttributesImpl = CustomAttributesImpl() + + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + member __.AddObsoleteAttribute (message,?isError) = customAttributesImpl.AddObsolete (message,defaultArg isError false) + member __.AddDefinitionLocation(line,column,filePath) = customAttributesImpl.AddDefinitionLocation(line, column, filePath) + member __.AddCustomAttribute(attribute) = customAttributesImpl.AddCustomAttribute(attribute) + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + + member __.SetMethodAttrs m = methodAttrs <- m + member __.AddMethodAttrs m = methodAttrs <- methodAttrs ||| m + member __.DeclaringTypeImpl with set x = declaringType <- x // check: not set twice + member __.IsStaticMethod + with get() = isStatic() + and set x = if x then methodAttrs <- methodAttrs ||| MethodAttributes.Static + else methodAttrs <- methodAttrs &&& (~~~ MethodAttributes.Static) + + member __.InvokeCode + with set (q:Expr list -> Expr) = + match invokeCode with + | None -> invokeCode <- Some q + | Some _ -> failwith (sprintf "ProvidedConstructor: code already given for %s on type %s" methodName (if declaringType=null then "" else declaringType.FullName)) + + + /// Abstract a type to a parametric-type. Requires "formal parameters" and "instantiation function". + member __.DefineStaticParameters(staticParameters : list, apply : (string -> obj[] -> ProvidedMethod)) = + staticParams <- staticParameters + staticParamsApply <- Some apply + + /// Get ParameterInfo[] for the parametric type parameters (//s GetGenericParameters) + member __.GetStaticParameters() = [| for p in staticParams -> p :> ParameterInfo |] + + /// Instantiate parametrics type + member __.ApplyStaticArguments(mangledName:string, args:obj[]) = + if staticParams.Length>0 then + if staticParams.Length <> args.Length then + failwith (sprintf "ProvidedTypeDefinition: expecting %d static parameters but given %d for method %s" staticParams.Length args.Length methodName) + match staticParamsApply with + | None -> failwith "ProvidedTypeDefinition: DefineStaticParameters was not called" + | Some f -> f mangledName args + else + failwith (sprintf "ProvidedTypeDefinition: static parameters supplied but not expected for method %s" methodName) + + member __.GetInvokeCodeInternal isGenerated = + match invokeCode with + | Some f -> + // FSharp.Data change: use the real variable names instead of indices, to improve output of Debug.fs + let paramNames = + parameters + |> List.map (fun p -> p.Name) + |> List.append (if isStatic() then [] else ["this"]) + |> Array.ofList + QuotationSimplifier(isGenerated).TranslateQuotationToCode f paramNames + | None -> failwith (sprintf "ProvidedMethod: no invoker for %s on type %s" methodName (if declaringType=null then "" else declaringType.FullName)) + + // Implement overloads + override __.GetParameters() = argParams |> Array.ofList + override __.Attributes = methodAttrs + override __.Name = methodName + override __.DeclaringType = declaringType |> nonNull "ProvidedMethod.DeclaringType" + override __.IsDefined(_attributeType, _inherit) : bool = true + override __.MemberType = MemberTypes.Method + override __.CallingConvention = + let cc = CallingConventions.Standard + let cc = if not (isStatic()) then cc ||| CallingConventions.HasThis else cc + cc + override __.ReturnType = returnType + override __.ReturnParameter = null // REVIEW: Give it a name and type? + override __.ToString() = "Method " + methodName + + // These don't have to return fully accurate results - they are used + // by the F# Quotations library function SpecificCall as a pre-optimization + // when comparing methods + override __.MetadataToken = hash declaringType + hash methodName + override __.MethodHandle = RuntimeMethodHandle() + + override __.ReturnTypeCustomAttributes = notRequired "ReturnTypeCustomAttributes" methodName + override __.GetBaseDefinition() = notRequired "GetBaseDefinition" methodName + override __.GetMethodImplementationFlags() = notRequired "GetMethodImplementationFlags" methodName + override __.Invoke(_obj, _invokeAttr, _binder, _parameters, _culture) = notRequired "Invoke" methodName + override __.ReflectedType = notRequired "ReflectedType" methodName + override __.GetCustomAttributes(_inherit) = notRequired "GetCustomAttributes" methodName + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" methodName + + +type ProvidedProperty(propertyName: string, propertyType: Type, ?parameters: ProvidedParameter list) = + inherit System.Reflection.PropertyInfo() + // State + + let parameters = defaultArg parameters [] + let mutable declaringType = null + let mutable isStatic = false + let mutable getterCode = None : option Expr> + let mutable setterCode = None : option Expr> + + let hasGetter() = getterCode.IsSome + let hasSetter() = setterCode.IsSome + + // Delay construction - to pick up the latest isStatic + let markSpecialName (m:ProvidedMethod) = m.AddMethodAttrs(MethodAttributes.SpecialName); m + let getter = lazy (ProvidedMethod("get_" + propertyName,parameters,propertyType,IsStaticMethod=isStatic,DeclaringTypeImpl=declaringType,InvokeCode=getterCode.Value) |> markSpecialName) + let setter = lazy (ProvidedMethod("set_" + propertyName,parameters @ [ProvidedParameter("value",propertyType)],typeof,IsStaticMethod=isStatic,DeclaringTypeImpl=declaringType,InvokeCode=setterCode.Value) |> markSpecialName) + + let customAttributesImpl = CustomAttributesImpl() + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + member __.AddObsoleteAttribute (message,?isError) = customAttributesImpl.AddObsolete (message,defaultArg isError false) + member __.AddDefinitionLocation(line,column,filePath) = customAttributesImpl.AddDefinitionLocation(line, column, filePath) + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() + member __.AddCustomAttribute attribute = customAttributesImpl.AddCustomAttribute attribute +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + + member __.DeclaringTypeImpl with set x = declaringType <- x // check: not set twice + + member __.IsStatic + with get() = isStatic + and set x = isStatic <- x + + member __.GetterCode + with set (q:Expr list -> Expr) = + if not getter.IsValueCreated then getterCode <- Some q else failwith "ProvidedProperty: getter MethodInfo has already been created" + + member __.SetterCode + with set (q:Expr list -> Expr) = + if not (setter.IsValueCreated) then setterCode <- Some q else failwith "ProvidedProperty: setter MethodInfo has already been created" + + // Implement overloads + override __.PropertyType = propertyType + override __.SetValue(_obj, _value, _invokeAttr, _binder, _index, _culture) = notRequired "SetValue" propertyName + override __.GetAccessors _nonPublic = notRequired "nonPublic" propertyName + override __.GetGetMethod _nonPublic = if hasGetter() then getter.Force() :> MethodInfo else null + override __.GetSetMethod _nonPublic = if hasSetter() then setter.Force() :> MethodInfo else null + override __.GetIndexParameters() = [| for p in parameters -> upcast p |] + override __.Attributes = PropertyAttributes.None + override __.CanRead = hasGetter() + override __.CanWrite = hasSetter() + override __.GetValue(_obj, _invokeAttr, _binder, _index, _culture) : obj = notRequired "GetValue" propertyName + override __.Name = propertyName + override __.DeclaringType = declaringType |> nonNull "ProvidedProperty.DeclaringType" + override __.MemberType : MemberTypes = MemberTypes.Property + + override __.ReflectedType = notRequired "ReflectedType" propertyName + override __.GetCustomAttributes(_inherit) = notRequired "GetCustomAttributes" propertyName + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" propertyName + override __.IsDefined(_attributeType, _inherit) = notRequired "IsDefined" propertyName + +type ProvidedEvent(eventName:string,eventHandlerType:Type) = + inherit System.Reflection.EventInfo() + // State + + let mutable declaringType = null + let mutable isStatic = false + let mutable adderCode = None : option Expr> + let mutable removerCode = None : option Expr> + + // Delay construction - to pick up the latest isStatic + let markSpecialName (m:ProvidedMethod) = m.AddMethodAttrs(MethodAttributes.SpecialName); m + let adder = lazy (ProvidedMethod("add_" + eventName, [ProvidedParameter("handler", eventHandlerType)],typeof,IsStaticMethod=isStatic,DeclaringTypeImpl=declaringType,InvokeCode=adderCode.Value) |> markSpecialName) + let remover = lazy (ProvidedMethod("remove_" + eventName, [ProvidedParameter("handler", eventHandlerType)],typeof,IsStaticMethod=isStatic,DeclaringTypeImpl=declaringType,InvokeCode=removerCode.Value) |> markSpecialName) + + let customAttributesImpl = CustomAttributesImpl() + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + member __.AddDefinitionLocation(line,column,filePath) = customAttributesImpl.AddDefinitionLocation(line, column, filePath) + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + + member __.DeclaringTypeImpl with set x = declaringType <- x // check: not set twice + member __.IsStatic + with get() = isStatic + and set x = isStatic <- x + + member __.AdderCode + with get() = adderCode.Value + and set f = + if not adder.IsValueCreated then adderCode <- Some f else failwith "ProvidedEvent: Add MethodInfo has already been created" + + member __.RemoverCode + with get() = removerCode.Value + and set f = + if not (remover.IsValueCreated) then removerCode <- Some f else failwith "ProvidedEvent: Remove MethodInfo has already been created" + + // Implement overloads + override __.EventHandlerType = eventHandlerType + override __.GetAddMethod _nonPublic = adder.Force() :> MethodInfo + override __.GetRemoveMethod _nonPublic = remover.Force() :> MethodInfo + override __.Attributes = EventAttributes.None + override __.Name = eventName + override __.DeclaringType = declaringType |> nonNull "ProvidedEvent.DeclaringType" + override __.MemberType : MemberTypes = MemberTypes.Event + + override __.GetRaiseMethod _nonPublic = notRequired "GetRaiseMethod" eventName + override __.ReflectedType = notRequired "ReflectedType" eventName + override __.GetCustomAttributes(_inherit) = notRequired "GetCustomAttributes" eventName + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" eventName + override __.IsDefined(_attributeType, _inherit) = notRequired "IsDefined" eventName + +type ProvidedLiteralField(fieldName:string,fieldType:Type,literalValue:obj) = + inherit System.Reflection.FieldInfo() + // State + + let mutable declaringType = null + + let customAttributesImpl = CustomAttributesImpl() + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + member __.AddObsoleteAttribute (message,?isError) = customAttributesImpl.AddObsolete (message,defaultArg isError false) + member __.AddDefinitionLocation(line,column,filePath) = customAttributesImpl.AddDefinitionLocation(line, column, filePath) + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + + member __.DeclaringTypeImpl with set x = declaringType <- x // check: not set twice + + + // Implement overloads + override __.FieldType = fieldType + override __.GetRawConstantValue() = literalValue + override __.Attributes = FieldAttributes.Static ||| FieldAttributes.Literal ||| FieldAttributes.Public + override __.Name = fieldName + override __.DeclaringType = declaringType |> nonNull "ProvidedLiteralField.DeclaringType" + override __.MemberType : MemberTypes = MemberTypes.Field + + override __.ReflectedType = notRequired "ReflectedType" fieldName + override __.GetCustomAttributes(_inherit) = notRequired "GetCustomAttributes" fieldName + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" fieldName + override __.IsDefined(_attributeType, _inherit) = notRequired "IsDefined" fieldName + + override __.SetValue(_obj, _value, _invokeAttr, _binder, _culture) = notRequired "SetValue" fieldName + override __.GetValue(_obj) : obj = notRequired "GetValue" fieldName + override __.FieldHandle = notRequired "FieldHandle" fieldName + +type ProvidedField(fieldName:string,fieldType:Type) = + inherit System.Reflection.FieldInfo() + // State + + let mutable declaringType = null + + let customAttributesImpl = CustomAttributesImpl() + let mutable fieldAttrs = FieldAttributes.Private + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + member __.AddObsoleteAttribute (message,?isError) = customAttributesImpl.AddObsolete (message,defaultArg isError false) + member __.AddDefinitionLocation(line,column,filePath) = customAttributesImpl.AddDefinitionLocation(line, column, filePath) + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + + member __.DeclaringTypeImpl with set x = declaringType <- x // check: not set twice + + member __.SetFieldAttributes attrs = fieldAttrs <- attrs + // Implement overloads + override __.FieldType = fieldType + override __.GetRawConstantValue() = null + override __.Attributes = fieldAttrs + override __.Name = fieldName + override __.DeclaringType = declaringType |> nonNull "ProvidedField.DeclaringType" + override __.MemberType : MemberTypes = MemberTypes.Field + + override __.ReflectedType = notRequired "ReflectedType" fieldName + override __.GetCustomAttributes(_inherit) = notRequired "GetCustomAttributes" fieldName + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" fieldName + override __.IsDefined(_attributeType, _inherit) = notRequired "IsDefined" fieldName + + override __.SetValue(_obj, _value, _invokeAttr, _binder, _culture) = notRequired "SetValue" fieldName + override __.GetValue(_obj) : obj = notRequired "GetValue" fieldName + override __.FieldHandle = notRequired "FieldHandle" fieldName + +/// Represents the type constructor in a provided symbol type. +[] +type ProvidedSymbolKind = + | SDArray + | Array of int + | Pointer + | ByRef + | Generic of System.Type + | FSharpTypeAbbreviation of (System.Reflection.Assembly * string * string[]) + + +/// Represents an array or other symbolic type involving a provided type as the argument. +/// See the type provider spec for the methods that must be implemented. +/// Note that the type provider specification does not require us to implement pointer-equality for provided types. +type ProvidedSymbolType(kind: ProvidedSymbolKind, args: Type list, convToTgt: Type -> Type) = + inherit Type() + + let rec isEquivalentTo (thisTy: Type) (otherTy: Type) = + match thisTy, otherTy with + | (:? ProvidedSymbolType as thisTy), (:? ProvidedSymbolType as thatTy) -> (thisTy.Kind,thisTy.Args) = (thatTy.Kind, thatTy.Args) + | (:? ProvidedSymbolType as thisTy), otherTy | otherTy, (:? ProvidedSymbolType as thisTy) -> + match thisTy.Kind, thisTy.Args with + | ProvidedSymbolKind.SDArray, [ty] | ProvidedSymbolKind.Array _, [ty] when otherTy.IsArray-> ty.Equals(otherTy.GetElementType()) + | ProvidedSymbolKind.ByRef, [ty] when otherTy.IsByRef -> ty.Equals(otherTy.GetElementType()) + | ProvidedSymbolKind.Pointer, [ty] when otherTy.IsPointer -> ty.Equals(otherTy.GetElementType()) + | ProvidedSymbolKind.Generic baseTy, args -> otherTy.IsGenericType && isEquivalentTo baseTy (otherTy.GetGenericTypeDefinition()) && Seq.forall2 isEquivalentTo args (otherTy.GetGenericArguments()) + | _ -> false + | a, b -> a.Equals b + + let nameText() = + match kind,args with + | ProvidedSymbolKind.SDArray,[arg] -> arg.Name + "[]" + | ProvidedSymbolKind.Array _,[arg] -> arg.Name + "[*]" + | ProvidedSymbolKind.Pointer,[arg] -> arg.Name + "*" + | ProvidedSymbolKind.ByRef,[arg] -> arg.Name + "&" + | ProvidedSymbolKind.Generic gty, args -> gty.Name + (sprintf "%A" args) + | ProvidedSymbolKind.FSharpTypeAbbreviation (_,_,path),_ -> path.[path.Length-1] + | _ -> failwith "unreachable" + + /// Substitute types for type variables. + static member convType (parameters: Type list) (ty:Type) = + if ty = null then null + elif ty.IsGenericType then + let args = Array.map (ProvidedSymbolType.convType parameters) (ty.GetGenericArguments()) + ty.GetGenericTypeDefinition().MakeGenericType(args) + elif ty.HasElementType then + let ety = ProvidedSymbolType.convType parameters (ty.GetElementType()) + if ty.IsArray then + let rank = ty.GetArrayRank() + if rank = 1 then ety.MakeArrayType() + else ety.MakeArrayType(rank) + elif ty.IsPointer then ety.MakePointerType() + elif ty.IsByRef then ety.MakeByRefType() + else ty + elif ty.IsGenericParameter then + if ty.GenericParameterPosition <= parameters.Length - 1 then + parameters.[ty.GenericParameterPosition] + else + ty + else ty + + override __.FullName = + match kind,args with + | ProvidedSymbolKind.SDArray,[arg] -> arg.FullName + "[]" + | ProvidedSymbolKind.Array _,[arg] -> arg.FullName + "[*]" + | ProvidedSymbolKind.Pointer,[arg] -> arg.FullName + "*" + | ProvidedSymbolKind.ByRef,[arg] -> arg.FullName + "&" + | ProvidedSymbolKind.Generic gty, args -> gty.FullName + "[" + (args |> List.map (fun arg -> arg.ToString()) |> String.concat ",") + "]" + | ProvidedSymbolKind.FSharpTypeAbbreviation (_,nsp,path),args -> String.concat "." (Array.append [| nsp |] path) + (match args with [] -> "" | _ -> args.ToString()) + | _ -> failwith "unreachable" + + /// Although not strictly required by the type provider specification, this is required when doing basic operations like FullName on + /// .NET symbolic types made from this type, e.g. when building Nullable.FullName + override __.DeclaringType = + match kind,args with + | ProvidedSymbolKind.SDArray,[arg] -> arg + | ProvidedSymbolKind.Array _,[arg] -> arg + | ProvidedSymbolKind.Pointer,[arg] -> arg + | ProvidedSymbolKind.ByRef,[arg] -> arg + | ProvidedSymbolKind.Generic gty,_ -> gty + | ProvidedSymbolKind.FSharpTypeAbbreviation _,_ -> null + | _ -> failwith "unreachable" + + override __.IsAssignableFrom(otherTy) = + match kind with + | Generic gtd -> + if otherTy.IsGenericType then + let otherGtd = otherTy.GetGenericTypeDefinition() + let otherArgs = otherTy.GetGenericArguments() + let yes = gtd.Equals(otherGtd) && Seq.forall2 isEquivalentTo args otherArgs + yes + else + base.IsAssignableFrom(otherTy) + | _ -> base.IsAssignableFrom(otherTy) + + override __.Name = nameText() + + override __.BaseType = + match kind with + | ProvidedSymbolKind.SDArray -> convToTgt typeof + | ProvidedSymbolKind.Array _ -> convToTgt typeof + | ProvidedSymbolKind.Pointer -> convToTgt typeof + | ProvidedSymbolKind.ByRef -> convToTgt typeof + | ProvidedSymbolKind.Generic gty -> + if gty.BaseType = null then null else + ProvidedSymbolType.convType args gty.BaseType + | ProvidedSymbolKind.FSharpTypeAbbreviation _ -> convToTgt typeof + + override __.GetArrayRank() = (match kind with ProvidedSymbolKind.Array n -> n | ProvidedSymbolKind.SDArray -> 1 | _ -> invalidOp "non-array type") + override __.IsValueTypeImpl() = (match kind with ProvidedSymbolKind.Generic gtd -> gtd.IsValueType | _ -> false) + override __.IsArrayImpl() = (match kind with ProvidedSymbolKind.Array _ | ProvidedSymbolKind.SDArray -> true | _ -> false) + override __.IsByRefImpl() = (match kind with ProvidedSymbolKind.ByRef _ -> true | _ -> false) + override __.IsPointerImpl() = (match kind with ProvidedSymbolKind.Pointer _ -> true | _ -> false) + override __.IsPrimitiveImpl() = false + override __.IsGenericType = (match kind with ProvidedSymbolKind.Generic _ -> true | _ -> false) + override __.GetGenericArguments() = (match kind with ProvidedSymbolKind.Generic _ -> args |> List.toArray | _ -> invalidOp "non-generic type") + override __.GetGenericTypeDefinition() = (match kind with ProvidedSymbolKind.Generic e -> e | _ -> invalidOp "non-generic type") + override __.IsCOMObjectImpl() = false + override __.HasElementTypeImpl() = (match kind with ProvidedSymbolKind.Generic _ -> false | _ -> true) + override __.GetElementType() = (match kind,args with (ProvidedSymbolKind.Array _ | ProvidedSymbolKind.SDArray | ProvidedSymbolKind.ByRef | ProvidedSymbolKind.Pointer),[e] -> e | _ -> invalidOp "not an array, pointer or byref type") + override this.ToString() = this.FullName + + override __.Assembly = + match kind with + | ProvidedSymbolKind.FSharpTypeAbbreviation (assembly,_nsp,_path) -> assembly + | ProvidedSymbolKind.Generic gty -> gty.Assembly + | _ -> notRequired "Assembly" (nameText()) + + override __.Namespace = + match kind with + | ProvidedSymbolKind.FSharpTypeAbbreviation (_assembly,nsp,_path) -> nsp + | _ -> notRequired "Namespace" (nameText()) + + override __.GetHashCode() = + match kind,args with + | ProvidedSymbolKind.SDArray,[arg] -> 10 + hash arg + | ProvidedSymbolKind.Array _,[arg] -> 163 + hash arg + | ProvidedSymbolKind.Pointer,[arg] -> 283 + hash arg + | ProvidedSymbolKind.ByRef,[arg] -> 43904 + hash arg + | ProvidedSymbolKind.Generic gty,_ -> 9797 + hash gty + List.sumBy hash args + | ProvidedSymbolKind.FSharpTypeAbbreviation _,_ -> 3092 + | _ -> failwith "unreachable" + + override __.Equals(other: obj) = + match other with + | :? ProvidedSymbolType as otherTy -> (kind, args) = (otherTy.Kind, otherTy.Args) + | _ -> false + + member __.Kind = kind + member __.Args = args + + member __.IsFSharpTypeAbbreviation = match kind with FSharpTypeAbbreviation _ -> true | _ -> false + // For example, int + member __.IsFSharpUnitAnnotated = match kind with ProvidedSymbolKind.Generic gtd -> not gtd.IsGenericTypeDefinition | _ -> false + + override __.Module : Module = notRequired "Module" (nameText()) + override __.GetConstructors _bindingAttr = notRequired "GetConstructors" (nameText()) + override __.GetMethodImpl(_name, _bindingAttr, _binderBinder, _callConvention, _types, _modifiers) = + match kind with + | Generic gtd -> + let ty = gtd.GetGenericTypeDefinition().MakeGenericType(Array.ofList args) + ty.GetMethod(_name, _bindingAttr) + | _ -> notRequired "GetMethodImpl" (nameText()) + override __.GetMembers _bindingAttr = notRequired "GetMembers" (nameText()) + override __.GetMethods _bindingAttr = notRequired "GetMethods" (nameText()) + override __.GetField(_name, _bindingAttr) = notRequired "GetField" (nameText()) + override __.GetFields _bindingAttr = notRequired "GetFields" (nameText()) + override __.GetInterface(_name, _ignoreCase) = notRequired "GetInterface" (nameText()) + override __.GetInterfaces() = notRequired "GetInterfaces" (nameText()) + override __.GetEvent(_name, _bindingAttr) = notRequired "GetEvent" (nameText()) + override __.GetEvents _bindingAttr = notRequired "GetEvents" (nameText()) + override __.GetProperties _bindingAttr = notRequired "GetProperties" (nameText()) + override __.GetPropertyImpl(_name, _bindingAttr, _binder, _returnType, _types, _modifiers) = notRequired "GetPropertyImpl" (nameText()) + override __.GetNestedTypes _bindingAttr = notRequired "GetNestedTypes" (nameText()) + override __.GetNestedType(_name, _bindingAttr) = notRequired "GetNestedType" (nameText()) + override __.GetAttributeFlagsImpl() = notRequired "GetAttributeFlagsImpl" (nameText()) + override this.UnderlyingSystemType = + match kind with + | ProvidedSymbolKind.SDArray + | ProvidedSymbolKind.Array _ + | ProvidedSymbolKind.Pointer + | ProvidedSymbolKind.FSharpTypeAbbreviation _ + | ProvidedSymbolKind.ByRef -> upcast this + | ProvidedSymbolKind.Generic gty -> gty.UnderlyingSystemType +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = ([| |] :> IList<_>) +#endif + override __.MemberType = notRequired "MemberType" (nameText()) + override __.GetMember(_name,_mt,_bindingAttr) = notRequired "GetMember" (nameText()) + override __.GUID = notRequired "GUID" (nameText()) + override __.InvokeMember(_name, _invokeAttr, _binder, _target, _args, _modifiers, _culture, _namedParameters) = notRequired "InvokeMember" (nameText()) + override __.AssemblyQualifiedName = notRequired "AssemblyQualifiedName" (nameText()) + override __.GetConstructorImpl(_bindingAttr, _binder, _callConvention, _types, _modifiers) = notRequired "GetConstructorImpl" (nameText()) + override __.GetCustomAttributes(_inherit) = [| |] + override __.GetCustomAttributes(_attributeType, _inherit) = [| |] + override __.IsDefined(_attributeType, _inherit) = false + // FSharp.Data addition: this was added to support arrays of arrays + override this.MakeArrayType() = ProvidedSymbolType(ProvidedSymbolKind.SDArray, [this], convToTgt) :> Type + override this.MakeArrayType arg = ProvidedSymbolType(ProvidedSymbolKind.Array arg, [this], convToTgt) :> Type + +type ProvidedSymbolMethod(genericMethodDefinition: MethodInfo, parameters: Type list) = + inherit System.Reflection.MethodInfo() + + let convParam (p:ParameterInfo) = + { new System.Reflection.ParameterInfo() with + override __.Name = p.Name + override __.ParameterType = ProvidedSymbolType.convType parameters p.ParameterType + override __.Attributes = p.Attributes + override __.RawDefaultValue = p.RawDefaultValue +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = p.GetCustomAttributesData() +#endif + } + + override this.IsGenericMethod = + (if this.DeclaringType.IsGenericType then this.DeclaringType.GetGenericArguments().Length else 0) < parameters.Length + + override this.GetGenericArguments() = + Seq.skip (if this.DeclaringType.IsGenericType then this.DeclaringType.GetGenericArguments().Length else 0) parameters |> Seq.toArray + + override __.GetGenericMethodDefinition() = genericMethodDefinition + + override __.DeclaringType = ProvidedSymbolType.convType parameters genericMethodDefinition.DeclaringType + override __.ToString() = "Method " + genericMethodDefinition.Name + override __.Name = genericMethodDefinition.Name + override __.MetadataToken = genericMethodDefinition.MetadataToken + override __.Attributes = genericMethodDefinition.Attributes + override __.CallingConvention = genericMethodDefinition.CallingConvention + override __.MemberType = genericMethodDefinition.MemberType + + override __.IsDefined(_attributeType, _inherit) : bool = notRequired "IsDefined" genericMethodDefinition.Name + override __.ReturnType = ProvidedSymbolType.convType parameters genericMethodDefinition.ReturnType + override __.GetParameters() = genericMethodDefinition.GetParameters() |> Array.map convParam + override __.ReturnParameter = genericMethodDefinition.ReturnParameter |> convParam + override __.ReturnTypeCustomAttributes = notRequired "ReturnTypeCustomAttributes" genericMethodDefinition.Name + override __.GetBaseDefinition() = notRequired "GetBaseDefinition" genericMethodDefinition.Name + override __.GetMethodImplementationFlags() = notRequired "GetMethodImplementationFlags" genericMethodDefinition.Name + override __.MethodHandle = notRequired "MethodHandle" genericMethodDefinition.Name + override __.Invoke(_obj, _invokeAttr, _binder, _parameters, _culture) = notRequired "Invoke" genericMethodDefinition.Name + override __.ReflectedType = notRequired "ReflectedType" genericMethodDefinition.Name + override __.GetCustomAttributes(_inherit) = notRequired "GetCustomAttributes" genericMethodDefinition.Name + override __.GetCustomAttributes(_attributeType, _inherit) = notRequired "GetCustomAttributes" genericMethodDefinition.Name + + + +type ProvidedTypeBuilder() = + static member MakeGenericType(genericTypeDefinition, genericArguments) = ProvidedSymbolType(Generic genericTypeDefinition, genericArguments, id) :> Type + static member MakeGenericMethod(genericMethodDefinition, genericArguments) = ProvidedSymbolMethod(genericMethodDefinition, genericArguments) :> MethodInfo + +type ZProvidedTypeBuilder(convToTgt: Type -> Type) = + member __.MakeGenericType(genericTypeDefinition, genericArguments) = ProvidedSymbolType(Generic genericTypeDefinition, genericArguments, convToTgt) :> Type + member __.MakeGenericMethod(genericMethodDefinition, genericArguments) = ProvidedSymbolMethod(genericMethodDefinition, genericArguments) :> MethodInfo + +[] +type ProvidedMeasureBuilder() = + + // TODO: this shouldn't be hardcoded, but without creating a dependency on FSharp.Compiler.Service + // there seems to be no way to check if a type abbreviation exists + let unitNamesTypeAbbreviations = + [ "meter"; "hertz"; "newton"; "pascal"; "joule"; "watt"; "coulomb"; + "volt"; "farad"; "ohm"; "siemens"; "weber"; "tesla"; "henry" + "lumen"; "lux"; "becquerel"; "gray"; "sievert"; "katal" ] + |> Set.ofList + + let unitSymbolsTypeAbbreviations = + [ "m"; "kg"; "s"; "A"; "K"; "mol"; "cd"; "Hz"; "N"; "Pa"; "J"; "W"; "C" + "V"; "F"; "S"; "Wb"; "T"; "lm"; "lx"; "Bq"; "Gy"; "Sv"; "kat"; "H" ] + |> Set.ofList + + static let theBuilder = ProvidedMeasureBuilder() + static member Default = theBuilder + member __.One = typeof + member __.Product (m1,m2) = typedefof>.MakeGenericType [| m1;m2 |] + member __.Inverse m = typedefof>.MakeGenericType [| m |] + member b.Ratio (m1, m2) = b.Product(m1, b.Inverse m2) + member b.Square m = b.Product(m, m) + + // FSharp.Data change: if the unit is not a valid type, instead + // of assuming it's a type abbreviation, which may not be the case and cause a + // problem later on, check the list of valid abbreviations + member __.SI (m:string) = + let mLowerCase = m.ToLowerInvariant() + let abbreviation = + if unitNamesTypeAbbreviations.Contains mLowerCase then + Some ("Microsoft.FSharp.Data.UnitSystems.SI.UnitNames", mLowerCase) + elif unitSymbolsTypeAbbreviations.Contains m then + Some ("Microsoft.FSharp.Data.UnitSystems.SI.UnitSymbols", m) + else + None + match abbreviation with + | Some (ns, unitName) -> + ProvidedSymbolType(ProvidedSymbolKind.FSharpTypeAbbreviation(typeof.Assembly,ns,[| unitName |]), [], id) :> Type + | None -> + typedefof>.Assembly.GetType("Microsoft.FSharp.Data.UnitSystems.SI.UnitNames." + mLowerCase) + + member __.AnnotateType (basicType, annotation) = ProvidedSymbolType(Generic basicType, annotation, id) :> Type + + + +[] +type TypeContainer = + | Namespace of Assembly * string // namespace + | Type of System.Type + | TypeToBeDecided + +#if NO_GENERATIVE +#else +module GlobalProvidedAssemblyElementsTable = + let theTable = Dictionary>() +#endif + +type ProvidedTypeDefinition(container:TypeContainer, className : string, baseType : Type option, convToTgt) as this = + inherit Type() + + do match container, !ProvidedTypeDefinition.Logger with + | TypeContainer.Namespace _, Some logger -> logger (sprintf "Creating ProvidedTypeDefinition %s [%d]" className (System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode this)) + | _ -> () + + // state + let mutable attributes = + TypeAttributes.Public ||| + TypeAttributes.Class ||| + TypeAttributes.Sealed ||| + enum (int32 TypeProviderTypeAttributes.IsErased) + + + let mutable enumUnderlyingType = None + let mutable baseType = lazy baseType + let mutable membersKnown = ResizeArray() + let mutable membersQueue = ResizeArray<(unit -> list)>() + let mutable staticParams = [ ] + let mutable staticParamsApply = None + let mutable container = container + let mutable interfaceImpls = ResizeArray() + let mutable interfaceImplsDelayed = ResizeArray list>() + let mutable methodOverrides = ResizeArray() + + // members API + let getMembers() = + if membersQueue.Count > 0 then + let elems = membersQueue |> Seq.toArray // take a copy in case more elements get added + membersQueue.Clear() + for f in elems do + for i in f() do + membersKnown.Add i + match i with + | :? ProvidedProperty as p -> + if p.CanRead then membersKnown.Add (p.GetGetMethod true) + if p.CanWrite then membersKnown.Add (p.GetSetMethod true) + | :? ProvidedEvent as e -> + membersKnown.Add (e.GetAddMethod true) + membersKnown.Add (e.GetRemoveMethod true) + | _ -> () + + membersKnown.ToArray() + + // members API + let getInterfaces() = + if interfaceImplsDelayed.Count > 0 then + let elems = interfaceImplsDelayed |> Seq.toArray // take a copy in case more elements get added + interfaceImplsDelayed.Clear() + for f in elems do + for i in f() do + interfaceImpls.Add i + + interfaceImpls.ToArray() + + let mutable theAssembly = + lazy + match container with + | TypeContainer.Namespace (theAssembly, rootNamespace) -> + if theAssembly = null then failwith "Null assemblies not allowed" + if rootNamespace<>null && rootNamespace.Length=0 then failwith "Use 'null' for global namespace" + theAssembly + | TypeContainer.Type superTy -> superTy.Assembly + | TypeContainer.TypeToBeDecided -> failwith (sprintf "type '%s' was not added as a member to a declaring type" className) + + let rootNamespace = + lazy + match container with + | TypeContainer.Namespace (_,rootNamespace) -> rootNamespace + | TypeContainer.Type enclosingTyp -> enclosingTyp.Namespace + | TypeContainer.TypeToBeDecided -> failwith (sprintf "type '%s' was not added as a member to a declaring type" className) + + let declaringType = + lazy + match container with + | TypeContainer.Namespace _ -> null + | TypeContainer.Type enclosingTyp -> enclosingTyp + | TypeContainer.TypeToBeDecided -> failwith (sprintf "type '%s' was not added as a member to a declaring type" className) + + let fullName = + lazy + match container with + | TypeContainer.Type declaringType -> declaringType.FullName + "+" + className + | TypeContainer.Namespace (_,namespaceName) -> + if namespaceName="" then failwith "use null for global namespace" + match namespaceName with + | null -> className + | _ -> namespaceName + "." + className + | TypeContainer.TypeToBeDecided -> failwith (sprintf "type '%s' was not added as a member to a declaring type" className) + + let patchUpAddedMemberInfo (this:Type) (m:MemberInfo) = + match m with + | :? ProvidedConstructor as c -> c.DeclaringTypeImpl <- this // patch up "declaring type" on provided MethodInfo + | :? ProvidedMethod as m -> m.DeclaringTypeImpl <- this // patch up "declaring type" on provided MethodInfo + | :? ProvidedProperty as p -> p.DeclaringTypeImpl <- this // patch up "declaring type" on provided MethodInfo + | :? ProvidedEvent as e -> e.DeclaringTypeImpl <- this // patch up "declaring type" on provided MethodInfo + | :? ProvidedTypeDefinition as t -> t.DeclaringTypeImpl <- this + | :? ProvidedLiteralField as l -> l.DeclaringTypeImpl <- this + | :? ProvidedField as l -> l.DeclaringTypeImpl <- this + | _ -> () + + let customAttributesImpl = CustomAttributesImpl() + + member __.AddXmlDocComputed xmlDocFunction = customAttributesImpl.AddXmlDocComputed xmlDocFunction + member __.AddXmlDocDelayed xmlDocFunction = customAttributesImpl.AddXmlDocDelayed xmlDocFunction + member __.AddXmlDoc xmlDoc = customAttributesImpl.AddXmlDoc xmlDoc + member __.AddObsoleteAttribute (message,?isError) = customAttributesImpl.AddObsolete (message,defaultArg isError false) + member __.AddDefinitionLocation(line,column,filePath) = customAttributesImpl.AddDefinitionLocation(line, column, filePath) + member __.HideObjectMethods with set v = customAttributesImpl.HideObjectMethods <- v + member __.NonNullable with set v = customAttributesImpl.NonNullable <- v + member __.GetCustomAttributesDataImpl() = customAttributesImpl.GetCustomAttributesData() + member __.AddCustomAttribute attribute = customAttributesImpl.AddCustomAttribute attribute +#if FX_NO_CUSTOMATTRIBUTEDATA +#else + override __.GetCustomAttributesData() = customAttributesImpl.GetCustomAttributesData() +#endif + + member __.ResetEnclosingType (ty) = + container <- TypeContainer.Type ty + new (assembly:Assembly,namespaceName,className,baseType) = new ProvidedTypeDefinition(TypeContainer.Namespace (assembly,namespaceName), className, baseType, id) + new (className:string,baseType) = new ProvidedTypeDefinition(TypeContainer.TypeToBeDecided, className, baseType, id) + + new (assembly:Assembly,namespaceName,className,baseType,convToTgt) = new ProvidedTypeDefinition(TypeContainer.Namespace (assembly,namespaceName), className, baseType, convToTgt) + new (className,baseType, convToTgt) = new ProvidedTypeDefinition(TypeContainer.TypeToBeDecided, className, baseType, convToTgt) + // state ops + + override __.UnderlyingSystemType = typeof + + member __.SetEnumUnderlyingType(ty) = enumUnderlyingType <- Some ty + + override __.GetEnumUnderlyingType() = + if this.IsEnum then + match enumUnderlyingType with + | None -> convToTgt typeof + | Some ty -> ty + else invalidOp "not enum type" + + member __.SetBaseType t = baseType <- lazy Some t + + member __.SetBaseTypeDelayed baseTypeFunction = baseType <- lazy (Some (baseTypeFunction())) + + member __.SetAttributes x = attributes <- x + + // Add MemberInfos + member __.AddMembersDelayed(membersFunction : unit -> list<#MemberInfo>) = + membersQueue.Add (fun () -> membersFunction() |> List.map (fun x -> patchUpAddedMemberInfo this x; x :> MemberInfo )) + + member __.AddMembers(memberInfos:list<#MemberInfo>) = (* strict *) + memberInfos |> List.iter (patchUpAddedMemberInfo this) // strict: patch up now + membersQueue.Add (fun () -> memberInfos |> List.map (fun x -> x :> MemberInfo)) + + member __.AddMember(memberInfo:MemberInfo) = + this.AddMembers [memberInfo] + + member __.AddMemberDelayed(memberFunction : unit -> #MemberInfo) = + this.AddMembersDelayed(fun () -> [memberFunction()]) + +#if NO_GENERATIVE +#else + member __.AddAssemblyTypesAsNestedTypesDelayed (assemblyf : unit -> System.Reflection.Assembly) = + let bucketByPath nodef tipf (items: (string list * 'Value) list) = + // Find all the items with an empty key list and call 'tipf' + let tips = + [ for (keylist,v) in items do + match keylist with + | [] -> yield tipf v + | _ -> () ] + + // Find all the items with a non-empty key list. Bucket them together by + // the first key. For each bucket, call 'nodef' on that head key and the bucket. + let nodes = + let buckets = new Dictionary<_,_>(10) + for (keylist,v) in items do + match keylist with + | [] -> () + | key::rest -> + buckets.[key] <- (rest,v) :: (if buckets.ContainsKey key then buckets.[key] else []); + + [ for (KeyValue(key,items)) in buckets -> nodef key items ] + + tips @ nodes + this.AddMembersDelayed (fun _ -> + let topTypes = [ for ty in assemblyf().GetTypes() do + if not ty.IsNested then + let namespaceParts = match ty.Namespace with null -> [] | s -> s.Split '.' |> Array.toList + yield namespaceParts, ty ] + let rec loop types = + types + |> bucketByPath + (fun namespaceComponent typesUnderNamespaceComponent -> + let t = ProvidedTypeDefinition(namespaceComponent, baseType = Some typeof) + t.AddMembers (loop typesUnderNamespaceComponent) + (t :> Type)) + (fun ty -> ty) + loop topTypes) +#endif + + /// Abstract a type to a parametric-type. Requires "formal parameters" and "instantiation function". + member __.DefineStaticParameters(staticParameters : list, apply : (string -> obj[] -> ProvidedTypeDefinition)) = + staticParams <- staticParameters + staticParamsApply <- Some apply + + /// Get ParameterInfo[] for the parametric type parameters (//s GetGenericParameters) + member __.GetStaticParameters() = [| for p in staticParams -> p :> ParameterInfo |] + + /// Instantiate parametrics type + member __.MakeParametricType(name:string,args:obj[]) = + if staticParams.Length>0 then + if staticParams.Length <> args.Length then + failwith (sprintf "ProvidedTypeDefinition: expecting %d static parameters but given %d for type %s" staticParams.Length args.Length (fullName.Force())) + match staticParamsApply with + | None -> failwith "ProvidedTypeDefinition: DefineStaticParameters was not called" + | Some f -> f name args + + else + failwith (sprintf "ProvidedTypeDefinition: static parameters supplied but not expected for %s" (fullName.Force())) + + member __.DeclaringTypeImpl + with set x = + match container with TypeContainer.TypeToBeDecided -> () | _ -> failwith (sprintf "container type for '%s' was already set to '%s'" this.FullName x.FullName); + container <- TypeContainer.Type x + + // Implement overloads + override __.Assembly = theAssembly.Force() + + member __.SetAssembly assembly = theAssembly <- lazy assembly + + member __.SetAssemblyLazy assembly = theAssembly <- assembly + + override __.FullName = fullName.Force() + + override __.Namespace = rootNamespace.Force() + + override __.BaseType = match baseType.Value with Some ty -> ty | None -> null + + // Constructors + override __.GetConstructors bindingAttr = + [| for m in this.GetMembers bindingAttr do + if m.MemberType = MemberTypes.Constructor then + yield (m :?> ConstructorInfo) |] + // Methods + override __.GetMethodImpl(name, bindingAttr, _binderBinder, _callConvention, _types, _modifiers) : MethodInfo = + let membersWithName = + [ for m in this.GetMembers(bindingAttr) do + if m.MemberType.HasFlag(MemberTypes.Method) && m.Name = name then + yield m ] + match membersWithName with + | [] -> null + | [meth] -> meth :?> MethodInfo + | _several -> failwith "GetMethodImpl. not support overloads" + + override __.GetMethods bindingAttr = + this.GetMembers bindingAttr + |> Array.filter (fun m -> m.MemberType.HasFlag(MemberTypes.Method)) + |> Array.map (fun m -> m :?> MethodInfo) + + // Fields + override __.GetField(name, bindingAttr) = + let fields = [| for m in this.GetMembers bindingAttr do + if m.MemberType.HasFlag(MemberTypes.Field) && (name = null || m.Name = name) then // REVIEW: name = null. Is that a valid query?! + yield m |] + if fields.Length > 0 then fields.[0] :?> FieldInfo else null + + override __.GetFields bindingAttr = + [| for m in this.GetMembers bindingAttr do if m.MemberType.HasFlag(MemberTypes.Field) then yield m :?> FieldInfo |] + + override __.GetInterface(_name, _ignoreCase) = notRequired "GetInterface" this.Name + + override __.GetInterfaces() = + [| yield! getInterfaces() |] + + member __.GetInterfaceImplementations() = + [| yield! getInterfaces() |] + + member __.AddInterfaceImplementation ityp = interfaceImpls.Add ityp + + member __.AddInterfaceImplementationsDelayed itypf = interfaceImplsDelayed.Add itypf + + member __.GetMethodOverrides() = + [| yield! methodOverrides |] + + member __.DefineMethodOverride (bodyMethInfo,declMethInfo) = methodOverrides.Add (bodyMethInfo, declMethInfo) + + // Events + override __.GetEvent(name, bindingAttr) = + let events = this.GetMembers bindingAttr + |> Array.filter(fun m -> m.MemberType.HasFlag(MemberTypes.Event) && (name = null || m.Name = name)) + if events.Length > 0 then events.[0] :?> EventInfo else null + + override __.GetEvents bindingAttr = + [| for m in this.GetMembers bindingAttr do if m.MemberType.HasFlag(MemberTypes.Event) then yield downcast m |] + + // Properties + override __.GetProperties bindingAttr = + [| for m in this.GetMembers bindingAttr do if m.MemberType.HasFlag(MemberTypes.Property) then yield downcast m |] + + override __.GetPropertyImpl(name, bindingAttr, binder, returnType, types, modifiers) = + if returnType <> null then failwith "Need to handle specified return type in GetPropertyImpl" + if types <> null then failwith "Need to handle specified parameter types in GetPropertyImpl" + if modifiers <> null then failwith "Need to handle specified modifiers in GetPropertyImpl" + if binder <> null then failwith "Need to handle binder in GetPropertyImpl" + let props = this.GetMembers bindingAttr |> Array.filter(fun m -> m.MemberType.HasFlag(MemberTypes.Property) && (name = null || m.Name = name)) // Review: nam = null, valid query!? + if props.Length > 0 then + props.[0] :?> PropertyInfo + else + null + // Nested Types + override __.MakeArrayType() = ProvidedSymbolType(ProvidedSymbolKind.SDArray, [this], convToTgt) :> Type + override __.MakeArrayType arg = ProvidedSymbolType(ProvidedSymbolKind.Array arg, [this], convToTgt) :> Type + override __.MakePointerType() = ProvidedSymbolType(ProvidedSymbolKind.Pointer, [this], convToTgt) :> Type + override __.MakeByRefType() = ProvidedSymbolType(ProvidedSymbolKind.ByRef, [this], convToTgt) :> Type + + // FSharp.Data addition: this method is used by Debug.fs and QuotationBuilder.fs + // Emulate the F# type provider type erasure mechanism to get the + // actual (erased) type. We erase ProvidedTypes to their base type + // and we erase array of provided type to array of base type. In the + // case of generics all the generic type arguments are also recursively + // replaced with the erased-to types + static member EraseType(t:Type) : Type = + match t with + | :? ProvidedTypeDefinition as ptd when ptd.IsErased -> ProvidedTypeDefinition.EraseType t.BaseType + | t when t.IsArray -> + let rank = t.GetArrayRank() + let et = ProvidedTypeDefinition.EraseType (t.GetElementType()) + if rank = 0 then et.MakeArrayType() else et.MakeArrayType(rank) + | :? ProvidedSymbolType as sym when sym.IsFSharpUnitAnnotated -> + t.UnderlyingSystemType + | t when t.IsGenericType && not t.IsGenericTypeDefinition -> + let genericTypeDefinition = t.GetGenericTypeDefinition() + let genericArguments = t.GetGenericArguments() |> Array.map ProvidedTypeDefinition.EraseType + genericTypeDefinition.MakeGenericType(genericArguments) + | t -> t + + static member Logger : (string -> unit) option ref = ref None + + // The binding attributes are always set to DeclaredOnly ||| Static ||| Instance ||| Public when GetMembers is called directly by the F# compiler + // However, it's possible for the framework to generate other sets of flags in some corner cases (e.g. via use of `enum` with a provided type as the target) + override __.GetMembers bindingAttr = + let mems = + getMembers() + |> Array.filter (fun mem -> + let isStatic, isPublic = + match mem with + | :? FieldInfo as f -> f.IsStatic, f.IsPublic + | :? MethodInfo as m -> m.IsStatic, m.IsPublic + | :? ConstructorInfo as c -> c.IsStatic, c.IsPublic + | :? PropertyInfo as p -> + let m = if p.CanRead then p.GetGetMethod() else p.GetSetMethod() + m.IsStatic, m.IsPublic + | :? EventInfo as e -> + let m = e.GetAddMethod() + m.IsStatic, m.IsPublic + | :? Type as ty -> + true, ty.IsNestedPublic + | _ -> failwith (sprintf "Member %O is of unexpected type" mem) + bindingAttr.HasFlag(if isStatic then BindingFlags.Static else BindingFlags.Instance) && + ( + (bindingAttr.HasFlag(BindingFlags.Public) && isPublic) || (bindingAttr.HasFlag(BindingFlags.NonPublic) && not isPublic) + )) + + if bindingAttr.HasFlag(BindingFlags.DeclaredOnly) || this.BaseType = null then mems + else + // FSharp.Data change: just using this.BaseType is not enough in the case of CsvProvider, + // because the base type is CsvRow, so we have to erase recursively to CsvRow + let baseMems = (ProvidedTypeDefinition.EraseType this.BaseType).GetMembers bindingAttr + Array.append mems baseMems + + override __.GetNestedTypes bindingAttr = + this.GetMembers bindingAttr + |> Array.filter(fun m -> + m.MemberType.HasFlag(MemberTypes.NestedType) || + // Allow 'fake' nested types that are actually real .NET types + m.MemberType.HasFlag(MemberTypes.TypeInfo)) |> Array.map(fun m -> m :?> Type) + + override __.GetMember(name,mt,_bindingAttr) = + let mt = + if mt &&& MemberTypes.NestedType = MemberTypes.NestedType then + mt ||| MemberTypes.TypeInfo + else + mt + getMembers() |> Array.filter(fun m->0<>(int(m.MemberType &&& mt)) && m.Name = name) + + override __.GetNestedType(name, bindingAttr) = + let nt = this.GetMember(name, MemberTypes.NestedType ||| MemberTypes.TypeInfo, bindingAttr) + match nt.Length with + | 0 -> null + | 1 -> downcast nt.[0] + | _ -> failwith (sprintf "There is more than one nested type called '%s' in type '%s'" name this.FullName) + + // Attributes, etc.. + override __.GetAttributeFlagsImpl() = adjustTypeAttributes attributes this.IsNested + override this.IsValueTypeImpl() = if this.BaseType <> null then this.BaseType = typeof || this.BaseType.IsValueType else false + override __.IsArrayImpl() = false + override __.IsByRefImpl() = false + override __.IsPointerImpl() = false + override __.IsPrimitiveImpl() = false + override __.IsCOMObjectImpl() = false + override __.HasElementTypeImpl() = false + override __.Name = className + override __.DeclaringType = declaringType.Force() + override __.MemberType = if this.IsNested then MemberTypes.NestedType else MemberTypes.TypeInfo + override __.GetHashCode() = rootNamespace.GetHashCode() ^^^ className.GetHashCode() + override __.Equals(that:obj) = + match that with + | null -> false + | :? ProvidedTypeDefinition as ti -> System.Object.ReferenceEquals(this,ti) + | _ -> false + + override __.GetGenericArguments() = [||] + override __.ToString() = this.Name + + + override __.Module : Module = notRequired "Module" this.Name + override __.GUID = Guid.Empty + override __.GetConstructorImpl(_bindingAttr, _binder, _callConvention, _types, _modifiers) = null + override __.GetCustomAttributes(_inherit) = [| |] + override __.GetCustomAttributes(_attributeType, _inherit) = [| |] + override __.IsDefined(_attributeType: Type, _inherit) = false + + override __.GetElementType() = notRequired "Module" this.Name + override __.InvokeMember(_name, _invokeAttr, _binder, _target, _args, _modifiers, _culture, _namedParameters) = notRequired "Module" this.Name + override __.AssemblyQualifiedName = notRequired "Module" this.Name + member __.IsErased + with get() = (attributes &&& enum (int32 TypeProviderTypeAttributes.IsErased)) <> enum 0 + and set v = + if v then attributes <- attributes ||| enum (int32 TypeProviderTypeAttributes.IsErased) + else attributes <- attributes &&& ~~~(enum (int32 TypeProviderTypeAttributes.IsErased)) + + member __.SuppressRelocation + with get() = (attributes &&& enum (int32 TypeProviderTypeAttributes.SuppressRelocate)) <> enum 0 + and set v = + if v then attributes <- attributes ||| enum (int32 TypeProviderTypeAttributes.SuppressRelocate) + else attributes <- attributes &&& ~~~(enum (int32 TypeProviderTypeAttributes.SuppressRelocate)) + + +#if NO_GENERATIVE +#else +//------------------------------------------------------------------------------------------------- +// The assembly compiler for generative type providers. + +type AssemblyGenerator(assemblyFileName) = + let assemblyShortName = Path.GetFileNameWithoutExtension assemblyFileName + let assemblyName = AssemblyName assemblyShortName +#if FX_NO_LOCAL_FILESYSTEM + let assembly = + System.AppDomain.CurrentDomain.DefineDynamicAssembly(name=assemblyName,access=AssemblyBuilderAccess.Run) + let assemblyMainModule = + assembly.DefineDynamicModule("MainModule") +#else + let assembly = + System.AppDomain.CurrentDomain.DefineDynamicAssembly(name=assemblyName,access=(AssemblyBuilderAccess.Save ||| AssemblyBuilderAccess.Run),dir=Path.GetDirectoryName assemblyFileName) + let assemblyMainModule = + assembly.DefineDynamicModule("MainModule", Path.GetFileName assemblyFileName) +#endif + let typeMap = Dictionary(HashIdentity.Reference) + let typeMapExtra = Dictionary(HashIdentity.Structural) + let uniqueLambdaTypeName() = + // lambda name should be unique across all types that all type provider might contribute in result assembly + sprintf "Lambda%O" (Guid.NewGuid()) + + member __.Assembly = assembly :> Assembly + + /// Emit the given provided type definitions into an assembly and adjust 'Assembly' property of all type definitions to return that + /// assembly. + member __.Generate(providedTypeDefinitions:(ProvidedTypeDefinition * string list option) list) = + let ALL = BindingFlags.Public ||| BindingFlags.NonPublic ||| BindingFlags.Static ||| BindingFlags.Instance + // phase 1 - set assembly fields and emit type definitions + begin + let rec typeMembers (tb:TypeBuilder) (td : ProvidedTypeDefinition) = + for ntd in td.GetNestedTypes(ALL) do + nestedType tb ntd + + and nestedType (tb:TypeBuilder) (ntd : Type) = + match ntd with + | :? ProvidedTypeDefinition as pntd -> + if pntd.IsErased then invalidOp ("The nested provided type "+pntd.Name+" is marked as erased and cannot be converted to a generated type. Set 'IsErased' to false on the ProvidedTypeDefinition") + // Adjust the attributes - we're codegen'ing this type as nested + let attributes = adjustTypeAttributes ntd.Attributes true + let ntb = tb.DefineNestedType(pntd.Name,attr=attributes) + pntd.SetAssembly null + typeMap.[pntd] <- ntb + typeMembers ntb pntd + | _ -> () + + for (pt,enclosingGeneratedTypeNames) in providedTypeDefinitions do + match enclosingGeneratedTypeNames with + | None -> + // Filter out the additional TypeProviderTypeAttributes flags + let attributes = pt.Attributes &&& ~~~(enum (int32 TypeProviderTypeAttributes.SuppressRelocate)) + &&& ~~~(enum (int32 TypeProviderTypeAttributes.IsErased)) + // Adjust the attributes - we're codegen'ing as non-nested + let attributes = adjustTypeAttributes attributes false + let tb = assemblyMainModule.DefineType(name=pt.FullName,attr=attributes) + pt.SetAssembly null + typeMap.[pt] <- tb + typeMembers tb pt + | Some ns -> + let otb,_ = + ((None,""),ns) ||> List.fold (fun (otb:TypeBuilder option,fullName) n -> + let fullName = if fullName = "" then n else fullName + "." + n + let priorType = if typeMapExtra.ContainsKey(fullName) then Some typeMapExtra.[fullName] else None + let tb = + match priorType with + | Some tbb -> tbb + | None -> + // OK, the implied nested type is not defined, define it now + let attributes = TypeAttributes.Public ||| TypeAttributes.Class ||| TypeAttributes.Sealed + // Filter out the additional TypeProviderTypeAttributes flags + let attributes = adjustTypeAttributes attributes otb.IsSome + let tb = + match otb with + | None -> assemblyMainModule.DefineType(name=n,attr=attributes) + | Some (otb:TypeBuilder) -> otb.DefineNestedType(name=n,attr=attributes) + typeMapExtra.[fullName] <- tb + tb + (Some tb, fullName)) + nestedType otb.Value pt + end + + let rec transType (ty:Type) = + match ty with + | :? ProvidedTypeDefinition as ptd -> + if typeMap.ContainsKey ptd then typeMap.[ptd] :> Type else ty + | _ -> + if ty.IsGenericType then ty.GetGenericTypeDefinition().MakeGenericType (Array.map transType (ty.GetGenericArguments())) + elif ty.HasElementType then + let ety = transType (ty.GetElementType()) + if ty.IsArray then + let rank = ty.GetArrayRank() + if rank = 1 then ety.MakeArrayType() + else ety.MakeArrayType rank + elif ty.IsPointer then ety.MakePointerType() + elif ty.IsByRef then ety.MakeByRefType() + else ty + else ty + + let ctorMap = Dictionary(HashIdentity.Reference) + let methMap = Dictionary(HashIdentity.Reference) + let fieldMap = Dictionary(HashIdentity.Reference) + let transCtor (f:ConstructorInfo) = match f with :? ProvidedConstructor as pc when ctorMap.ContainsKey pc -> ctorMap.[pc] :> ConstructorInfo | c -> c + let transField (f:FieldInfo) = match f with :? ProvidedField as pf when fieldMap.ContainsKey pf -> fieldMap.[pf] :> FieldInfo | f -> f + let transMeth (m:MethodInfo) = match m with :? ProvidedMethod as pm when methMap.ContainsKey pm -> methMap.[pm] :> MethodInfo | m -> m + let isLiteralEnumField (f:FieldInfo) = match f with :? ProvidedLiteralField as plf -> plf.DeclaringType.IsEnum | _ -> false + + let iterateTypes f = + let rec typeMembers (ptd : ProvidedTypeDefinition) = + let tb = typeMap.[ptd] + f tb (Some ptd) + for ntd in ptd.GetNestedTypes(ALL) do + nestedType ntd + + and nestedType (ntd : Type) = + match ntd with + | :? ProvidedTypeDefinition as pntd -> typeMembers pntd + | _ -> () + + for (pt,enclosingGeneratedTypeNames) in providedTypeDefinitions do + match enclosingGeneratedTypeNames with + | None -> + typeMembers pt + | Some ns -> + let _fullName = + ("",ns) ||> List.fold (fun fullName n -> + let fullName = if fullName = "" then n else fullName + "." + n + f typeMapExtra.[fullName] None + fullName) + nestedType pt + + + // phase 1b - emit base types + iterateTypes (fun tb ptd -> + match ptd with + | None -> () + | Some ptd -> + match ptd.BaseType with null -> () | bt -> tb.SetParent(transType bt)) + + let defineCustomAttrs f (cattrs: IList) = + for attr in cattrs do + let constructorArgs = [ for x in attr.ConstructorArguments -> x.Value ] + let namedProps,namedPropVals = [ for x in attr.NamedArguments do match x.MemberInfo with :? PropertyInfo as pi -> yield (pi, x.TypedValue.Value) | _ -> () ] |> List.unzip + let namedFields,namedFieldVals = [ for x in attr.NamedArguments do match x.MemberInfo with :? FieldInfo as pi -> yield (pi, x.TypedValue.Value) | _ -> () ] |> List.unzip + let cab = CustomAttributeBuilder(attr.Constructor, Array.ofList constructorArgs, Array.ofList namedProps, Array.ofList namedPropVals, Array.ofList namedFields, Array.ofList namedFieldVals) + f cab + + // phase 2 - emit member definitions + iterateTypes (fun tb ptd -> + match ptd with + | None -> () + | Some ptd -> + for cinfo in ptd.GetConstructors(ALL) do + match cinfo with + | :? ProvidedConstructor as pcinfo when not (ctorMap.ContainsKey pcinfo) -> + let cb = + if pcinfo.IsTypeInitializer then + if (cinfo.GetParameters()).Length <> 0 then failwith "Type initializer should not have parameters" + tb.DefineTypeInitializer() + else + let cb = tb.DefineConstructor(cinfo.Attributes, CallingConventions.Standard, [| for p in cinfo.GetParameters() -> transType p.ParameterType |]) + for (i,p) in cinfo.GetParameters() |> Seq.mapi (fun i x -> (i,x)) do + cb.DefineParameter(i+1, ParameterAttributes.None, p.Name) |> ignore + cb + ctorMap.[pcinfo] <- cb + | _ -> () + + if ptd.IsEnum then + tb.DefineField("value__", ptd.GetEnumUnderlyingType(), FieldAttributes.Public ||| FieldAttributes.SpecialName ||| FieldAttributes.RTSpecialName) + |> ignore + + for finfo in ptd.GetFields(ALL) do + let fieldInfo = + match finfo with + | :? ProvidedField as pinfo -> + Some (pinfo.Name, transType finfo.FieldType, finfo.Attributes, pinfo.GetCustomAttributesDataImpl(), None) + | :? ProvidedLiteralField as pinfo -> + Some (pinfo.Name, transType finfo.FieldType, finfo.Attributes, pinfo.GetCustomAttributesDataImpl(), Some (pinfo.GetRawConstantValue())) + | _ -> None + match fieldInfo with + | Some (name, ty, attr, cattr, constantVal) when not (fieldMap.ContainsKey finfo) -> + let fb = tb.DefineField(name, ty, attr) + if constantVal.IsSome then + fb.SetConstant constantVal.Value + defineCustomAttrs fb.SetCustomAttribute cattr + fieldMap.[finfo] <- fb + | _ -> () + for minfo in ptd.GetMethods(ALL) do + match minfo with + | :? ProvidedMethod as pminfo when not (methMap.ContainsKey pminfo) -> + let mb = tb.DefineMethod(minfo.Name, minfo.Attributes, transType minfo.ReturnType, [| for p in minfo.GetParameters() -> transType p.ParameterType |]) + for (i, p) in minfo.GetParameters() |> Seq.mapi (fun i x -> (i,x :?> ProvidedParameter)) do + // TODO: check why F# compiler doesn't emit default value when just p.Attributes is used (thus bad metadata is emitted) +// let mutable attrs = ParameterAttributes.None +// +// if p.IsOut then attrs <- attrs ||| ParameterAttributes.Out +// if p.HasDefaultParameterValue then attrs <- attrs ||| ParameterAttributes.Optional + + let pb = mb.DefineParameter(i+1, p.Attributes, p.Name) + if p.HasDefaultParameterValue then + do + let ctor = typeof.GetConstructor([|typeof|]) + let builder = new CustomAttributeBuilder(ctor, [|p.RawDefaultValue|]) + pb.SetCustomAttribute builder + do + let ctor = typeof.GetConstructor([||]) + let builder = new CustomAttributeBuilder(ctor, [||]) + pb.SetCustomAttribute builder + pb.SetConstant p.RawDefaultValue + methMap.[pminfo] <- mb + | _ -> () + + for ityp in ptd.GetInterfaceImplementations() do + tb.AddInterfaceImplementation ityp) + + // phase 3 - emit member code + iterateTypes (fun tb ptd -> + match ptd with + | None -> () + | Some ptd -> + let cattr = ptd.GetCustomAttributesDataImpl() + defineCustomAttrs tb.SetCustomAttribute cattr + // Allow at most one constructor, and use its arguments as the fields of the type + let ctors = + ptd.GetConstructors(ALL) // exclude type initializer + |> Seq.choose (function :? ProvidedConstructor as pcinfo when not pcinfo.IsTypeInitializer -> Some pcinfo | _ -> None) + |> Seq.toList + let implictCtorArgs = + match ctors |> List.filter (fun x -> x.IsImplicitCtor) with + | [] -> [] + | [ pcinfo ] -> [ for p in pcinfo.GetParameters() -> p ] + | _ -> failwith "at most one implicit constructor allowed" + + let implicitCtorArgsAsFields = + [ for ctorArg in implictCtorArgs -> + tb.DefineField(ctorArg.Name, transType ctorArg.ParameterType, FieldAttributes.Private) ] + + + + // Emit the constructor (if any) + for pcinfo in ctors do + assert ctorMap.ContainsKey pcinfo + let cb = ctorMap.[pcinfo] + let cattr = pcinfo.GetCustomAttributesDataImpl() + defineCustomAttrs cb.SetCustomAttribute cattr + let ilg = cb.GetILGenerator() + let locals = Dictionary() + let parameterVars = + [| yield Var("this", pcinfo.DeclaringType) + for p in pcinfo.GetParameters() do + yield Var(p.Name, p.ParameterType) |] + + let codeGen = CodeGenerator(assemblyMainModule, uniqueLambdaTypeName, implicitCtorArgsAsFields, transType, transField, transMeth, transCtor, isLiteralEnumField, ilg, locals, parameterVars) + let parameters = + [| for v in parameterVars -> Expr.Var v |] + match pcinfo.GetBaseConstructorCallInternal true with + | None -> + ilg.Emit(OpCodes.Ldarg_0) + let cinfo = ptd.BaseType.GetConstructor(BindingFlags.Public ||| BindingFlags.NonPublic ||| BindingFlags.Instance, null, [| |], null) + ilg.Emit(OpCodes.Call,cinfo) + | Some f -> + // argExprs should always include 'this' + let (cinfo,argExprs) = f (Array.toList parameters) + for argExpr in argExprs do + codeGen.EmitExpr (ExpectedStackState.Value, argExpr) + ilg.Emit(OpCodes.Call,cinfo) + + if pcinfo.IsImplicitCtor then + for ctorArgsAsFieldIdx,ctorArgsAsField in List.mapi (fun i x -> (i,x)) implicitCtorArgsAsFields do + ilg.Emit(OpCodes.Ldarg_0) + ilg.Emit(OpCodes.Ldarg, ctorArgsAsFieldIdx+1) + ilg.Emit(OpCodes.Stfld, ctorArgsAsField) + else + let code = pcinfo.GetInvokeCodeInternal true parameters + codeGen.EmitExpr (ExpectedStackState.Empty, code) + ilg.Emit(OpCodes.Ret) + + match ptd.GetConstructors(ALL) |> Seq.tryPick (function :? ProvidedConstructor as pc when pc.IsTypeInitializer -> Some pc | _ -> None) with + | None -> () + | Some pc -> + let cb = ctorMap.[pc] + let ilg = cb.GetILGenerator() + let cattr = pc.GetCustomAttributesDataImpl() + defineCustomAttrs cb.SetCustomAttribute cattr + let expr = pc.GetInvokeCodeInternal true [||] + let codeGen = CodeGenerator(assemblyMainModule, uniqueLambdaTypeName, implicitCtorArgsAsFields, transType, transField, transMeth, transCtor, isLiteralEnumField, ilg, new Dictionary<_, _>(), [| |]) + codeGen.EmitExpr (ExpectedStackState.Empty, expr) + ilg.Emit OpCodes.Ret + + // Emit the methods + for minfo in ptd.GetMethods(ALL) do + match minfo with + | :? ProvidedMethod as pminfo -> + let mb = methMap.[pminfo] + let ilg = mb.GetILGenerator() + let cattr = pminfo.GetCustomAttributesDataImpl() + defineCustomAttrs mb.SetCustomAttribute cattr + + let parameterVars = + [| if not pminfo.IsStatic then + yield Var("this", pminfo.DeclaringType) + for p in pminfo.GetParameters() do + yield Var(p.Name, p.ParameterType) |] + let parameters = + [| for v in parameterVars -> Expr.Var v |] + + let expr = pminfo.GetInvokeCodeInternal true parameters + + let locals = Dictionary() + //printfn "Emitting linqCode for %s::%s, code = %s" pminfo.DeclaringType.FullName pminfo.Name (try linqCode.ToString() with _ -> "") + + + let expectedState = if (minfo.ReturnType = typeof) then ExpectedStackState.Empty else ExpectedStackState.Value + let codeGen = CodeGenerator(assemblyMainModule, uniqueLambdaTypeName, implicitCtorArgsAsFields, transType, transField, transMeth, transCtor, isLiteralEnumField, ilg, locals, parameterVars) + codeGen.EmitExpr (expectedState, expr) + ilg.Emit OpCodes.Ret + | _ -> () + + for (bodyMethInfo,declMethInfo) in ptd.GetMethodOverrides() do + let bodyMethBuilder = methMap.[bodyMethInfo] + tb.DefineMethodOverride(bodyMethBuilder,declMethInfo) + + for evt in ptd.GetEvents(ALL) |> Seq.choose (function :? ProvidedEvent as pe -> Some pe | _ -> None) do + let eb = tb.DefineEvent(evt.Name, evt.Attributes, evt.EventHandlerType) + defineCustomAttrs eb.SetCustomAttribute (evt.GetCustomAttributesDataImpl()) + eb.SetAddOnMethod(methMap.[evt.GetAddMethod(true) :?> _]) + eb.SetRemoveOnMethod(methMap.[evt.GetRemoveMethod(true) :?> _]) + // TODO: add raiser + + for pinfo in ptd.GetProperties(ALL) |> Seq.choose (function :? ProvidedProperty as pe -> Some pe | _ -> None) do + let pb = tb.DefineProperty(pinfo.Name, pinfo.Attributes, transType pinfo.PropertyType, [| for p in pinfo.GetIndexParameters() -> transType p.ParameterType |]) + let cattr = pinfo.GetCustomAttributesDataImpl() + defineCustomAttrs pb.SetCustomAttribute cattr + if pinfo.CanRead then + let minfo = pinfo.GetGetMethod(true) + pb.SetGetMethod (methMap.[minfo :?> ProvidedMethod ]) + if pinfo.CanWrite then + let minfo = pinfo.GetSetMethod(true) + pb.SetSetMethod (methMap.[minfo :?> ProvidedMethod ])) + + // phase 4 - complete types + + let resolveHandler = ResolveEventHandler(fun _ args -> + // On Mono args.Name is full name of the type, on .NET - just name (no namespace) + typeMap.Values + |> Seq.filter (fun tb -> tb.FullName = args.Name || tb.Name = args.Name) + |> Seq.iter (fun tb -> tb.CreateType() |> ignore) + + assemblyMainModule.Assembly) + + try + AppDomain.CurrentDomain.add_TypeResolve resolveHandler + iterateTypes (fun tb _ -> tb.CreateType() |> ignore) + finally + AppDomain.CurrentDomain.remove_TypeResolve resolveHandler + +#if FX_NO_LOCAL_FILESYSTEM +#else + assembly.Save (Path.GetFileName assemblyFileName) +#endif + + let assemblyLoadedInMemory = assemblyMainModule.Assembly + + iterateTypes (fun _tb ptd -> + match ptd with + | None -> () + | Some ptd -> ptd.SetAssembly assemblyLoadedInMemory) + +#if FX_NO_LOCAL_FILESYSTEM +#else + member __.GetFinalBytes() = + let assemblyBytes = File.ReadAllBytes assemblyFileName + let _assemblyLoadedInMemory = System.Reflection.Assembly.Load(assemblyBytes,null,System.Security.SecurityContextSource.CurrentAppDomain) + //printfn "final bytes in '%s'" assemblyFileName + File.Delete assemblyFileName + assemblyBytes +#endif + +type ProvidedAssembly(assemblyFileName: string) = + let theTypes = ResizeArray<_>() + let assemblyGenerator = AssemblyGenerator(assemblyFileName) + let assemblyLazy = + lazy + assemblyGenerator.Generate(theTypes |> Seq.toList) + assemblyGenerator.Assembly +#if FX_NO_LOCAL_FILESYSTEM +#else + let theAssemblyBytesLazy = + lazy + assemblyGenerator.GetFinalBytes() + + do + GlobalProvidedAssemblyElementsTable.theTable.Add(assemblyGenerator.Assembly, theAssemblyBytesLazy) + +#endif + + let add (providedTypeDefinitions:ProvidedTypeDefinition list, enclosingTypeNames: string list option) = + for pt in providedTypeDefinitions do + if pt.IsErased then invalidOp ("The provided type "+pt.Name+"is marked as erased and cannot be converted to a generated type. Set 'IsErased' to false on the ProvidedTypeDefinition") + theTypes.Add(pt,enclosingTypeNames) + pt.SetAssemblyLazy assemblyLazy + + member x.AddNestedTypes (providedTypeDefinitions, enclosingTypeNames) = add (providedTypeDefinitions, Some enclosingTypeNames) + member x.AddTypes (providedTypeDefinitions) = add (providedTypeDefinitions, None) +#if FX_NO_LOCAL_FILESYSTEM +#else + static member RegisterGenerated (fileName:string) = + //printfn "registered assembly in '%s'" fileName + let assemblyBytes = System.IO.File.ReadAllBytes fileName + let assembly = Assembly.Load(assemblyBytes,null,System.Security.SecurityContextSource.CurrentAppDomain) + GlobalProvidedAssemblyElementsTable.theTable.Add(assembly, Lazy<_>.CreateFromValue assemblyBytes) + assembly +#endif + +#endif // NO_GENERATIVE + +module Local = + + let makeProvidedNamespace (namespaceName:string) (types:ProvidedTypeDefinition list) = + let types = [| for ty in types -> ty :> Type |] + {new IProvidedNamespace with + member __.GetNestedNamespaces() = [| |] + member __.NamespaceName = namespaceName + member __.GetTypes() = types |> Array.copy + member __.ResolveTypeName typeName : System.Type = + match types |> Array.tryFind (fun ty -> ty.Name = typeName) with + | Some ty -> ty + | None -> null + } + + +#if FX_NO_LOCAL_FILESYSTEM +type TypeProviderForNamespaces(namespacesAndTypes : list<(string * list)>) = +#else +type TypeProviderForNamespaces(namespacesAndTypes : list<(string * list)>) as this = +#endif + let otherNamespaces = ResizeArray>() + + let providedNamespaces = + lazy [| for (namespaceName,types) in namespacesAndTypes do + yield Local.makeProvidedNamespace namespaceName types + for (namespaceName,types) in otherNamespaces do + yield Local.makeProvidedNamespace namespaceName types |] + + let invalidateE = new Event() + + let disposing = Event() + +#if FX_NO_LOCAL_FILESYSTEM +#else + let probingFolders = ResizeArray() + let handler = ResolveEventHandler(fun _ args -> this.ResolveAssembly(args)) + do AppDomain.CurrentDomain.add_AssemblyResolve handler +#endif + + new (namespaceName:string,types:list) = new TypeProviderForNamespaces([(namespaceName,types)]) + new () = new TypeProviderForNamespaces([]) + + [] + member __.Disposing = disposing.Publish + +#if FX_NO_LOCAL_FILESYSTEM + interface System.IDisposable with + member x.Dispose() = + disposing.Trigger(x, EventArgs.Empty) +#else + abstract member ResolveAssembly : args : System.ResolveEventArgs -> Assembly + + default __.ResolveAssembly(args) = + let expectedName = (AssemblyName(args.Name)).Name + ".dll" + let expectedLocationOpt = + probingFolders + |> Seq.map (fun f -> IO.Path.Combine(f, expectedName)) + |> Seq.tryFind IO.File.Exists + match expectedLocationOpt with + | Some f -> Assembly.LoadFrom f + | None -> null + + member __.RegisterProbingFolder (folder) = + // use GetFullPath to ensure that folder is valid + ignore(IO.Path.GetFullPath folder) + probingFolders.Add folder + + member __.RegisterRuntimeAssemblyLocationAsProbingFolder (config : TypeProviderConfig) = + config.RuntimeAssembly + |> IO.Path.GetDirectoryName + |> this.RegisterProbingFolder + + interface System.IDisposable with + member x.Dispose() = + disposing.Trigger(x, EventArgs.Empty) + AppDomain.CurrentDomain.remove_AssemblyResolve handler +#endif + + member __.AddNamespace (namespaceName,types:list<_>) = otherNamespaces.Add (namespaceName,types) + + // FSharp.Data addition: this method is used by Debug.fs + member __.Namespaces = Seq.readonly otherNamespaces + + member this.Invalidate() = invalidateE.Trigger(this,EventArgs()) + + member __.GetStaticParametersForMethod(mb: MethodBase) = + // printfn "In GetStaticParametersForMethod" + match mb with + | :? ProvidedMethod as t -> t.GetStaticParameters() + | _ -> [| |] + + member __.ApplyStaticArgumentsForMethod(mb: MethodBase, mangledName, objs) = + // printfn "In ApplyStaticArgumentsForMethod" + match mb with + | :? ProvidedMethod as t -> t.ApplyStaticArguments(mangledName, objs) :> MethodBase + | _ -> failwith (sprintf "ApplyStaticArguments: static parameters for method %s are unexpected" mb.Name) + + interface ITypeProvider with + + [] + override __.Invalidate = invalidateE.Publish + + override __.GetNamespaces() = Array.copy providedNamespaces.Value + + member __.GetInvokerExpression(methodBase, parameters) = + let rec getInvokerExpression (methodBase : MethodBase) parameters = + match methodBase with + | :? ProvidedMethod as m when (match methodBase.DeclaringType with :? ProvidedTypeDefinition as pt -> pt.IsErased | _ -> true) -> + m.GetInvokeCodeInternal false parameters + |> expand + | :? ProvidedConstructor as m when (match methodBase.DeclaringType with :? ProvidedTypeDefinition as pt -> pt.IsErased | _ -> true) -> + m.GetInvokeCodeInternal false parameters + |> expand + // Otherwise, assume this is a generative assembly and just emit a call to the constructor or method + | :? ConstructorInfo as cinfo -> + Expr.NewObjectUnchecked(cinfo, Array.toList parameters) + | :? System.Reflection.MethodInfo as minfo -> + if minfo.IsStatic then + Expr.CallUnchecked(minfo, Array.toList parameters) + else + Expr.CallUnchecked(parameters.[0], minfo, Array.toList parameters.[1..]) + | _ -> failwith ("TypeProviderForNamespaces.GetInvokerExpression: not a ProvidedMethod/ProvidedConstructor/ConstructorInfo/MethodInfo, name=" + methodBase.Name + " class=" + methodBase.GetType().FullName) + and expand expr = + match expr with + | NewObject(ctor, args) -> getInvokerExpression ctor [| for arg in args -> expand arg|] + | Call(inst, mi, args) -> + let args = + [| + match inst with + | Some inst -> yield expand inst + | _ -> () + yield! List.map expand args + |] + getInvokerExpression mi args + | ShapeCombinationUnchecked(shape, args) -> RebuildShapeCombinationUnchecked(shape, List.map expand args) + | ShapeVarUnchecked v -> Expr.Var v + | ShapeLambdaUnchecked(v, body) -> Expr.Lambda(v, expand body) + getInvokerExpression methodBase parameters +#if FX_NO_CUSTOMATTRIBUTEDATA + + member __.GetMemberCustomAttributesData(methodBase) = + match methodBase with + | :? ProvidedTypeDefinition as m -> m.GetCustomAttributesDataImpl() + | :? ProvidedMethod as m -> m.GetCustomAttributesDataImpl() + | :? ProvidedProperty as m -> m.GetCustomAttributesDataImpl() + | :? ProvidedConstructor as m -> m.GetCustomAttributesDataImpl() + | :? ProvidedEvent as m -> m.GetCustomAttributesDataImpl() + | :? ProvidedLiteralField as m -> m.GetCustomAttributesDataImpl() + | :? ProvidedField as m -> m.GetCustomAttributesDataImpl() + | _ -> [| |] :> IList<_> + + member __.GetParameterCustomAttributesData(methodBase) = + match methodBase with + | :? ProvidedParameter as m -> m.GetCustomAttributesDataImpl() + | _ -> [| |] :> IList<_> + + +#endif + override __.GetStaticParameters(ty) = + match ty with + | :? ProvidedTypeDefinition as t -> + if ty.Name = t.Name (* REVIEW: use equality? *) then + t.GetStaticParameters() + else + [| |] + | _ -> [| |] + + override __.ApplyStaticArguments(ty,typePathAfterArguments:string[],objs) = + let typePathAfterArguments = typePathAfterArguments.[typePathAfterArguments.Length-1] + match ty with + | :? ProvidedTypeDefinition as t -> (t.MakeParametricType(typePathAfterArguments,objs) :> Type) + | _ -> failwith (sprintf "ApplyStaticArguments: static params for type %s are unexpected" ty.FullName) + +#if NO_GENERATIVE + override __.GetGeneratedAssemblyContents(_assembly) = + failwith "no generative assemblies" +#else +#if FX_NO_LOCAL_FILESYSTEM + override __.GetGeneratedAssemblyContents(_assembly) = + // TODO: this is very fake, we rely on the fact it is never needed + match System.Windows.Application.GetResourceStream(System.Uri("FSharp.Core.dll",System.UriKind.Relative)) with + | null -> failwith "FSharp.Core.dll not found as Manifest Resource, we're just trying to read some random .NET assembly, ok?" + | resStream -> + use stream = resStream.Stream + let len = stream.Length + let buf = Array.zeroCreate (int len) + let rec loop where rem = + let n = stream.Read(buf, 0, int rem) + if n < rem then loop (where + n) (rem - n) + loop 0 (int len) + buf + + //failwith "no file system" +#else + override __.GetGeneratedAssemblyContents(assembly:Assembly) = + //printfn "looking up assembly '%s'" assembly.FullName + match GlobalProvidedAssemblyElementsTable.theTable.TryGetValue assembly with + | true,bytes -> bytes.Force() + | _ -> + let bytes = System.IO.File.ReadAllBytes assembly.ManifestModule.FullyQualifiedName + GlobalProvidedAssemblyElementsTable.theTable.[assembly] <- Lazy<_>.CreateFromValue bytes + bytes +#endif +#endif diff --git a/Rezoom.SQL.Provider/ProvidedTypes.fsi b/Rezoom.SQL.Provider/ProvidedTypes.fsi new file mode 100644 index 0000000..145bb9a --- /dev/null +++ b/Rezoom.SQL.Provider/ProvidedTypes.fsi @@ -0,0 +1,526 @@ +// Copyright (c) Microsoft Corporation 2005-2014 and other contributors. +// This sample code is provided "as is" without warranty of any kind. +// We disclaim all warranties, either express or implied, including the +// warranties of merchantability and fitness for a particular purpose. +// +// This file contains a set of helper types and methods for providing types in an implementation +// of ITypeProvider. +// +// This code has been modified and is appropriate for use in conjunction with the F# 3.0-4.0 releases + + +namespace ProviderImplementation.ProvidedTypes + +open System +open System.Reflection +open System.Linq.Expressions +open Microsoft.FSharp.Quotations +open Microsoft.FSharp.Core.CompilerServices + +/// Represents an erased provided parameter +type ProvidedParameter = + inherit ParameterInfo + // [] + new : parameterName: string * parameterType: Type * ?isOut:bool * ?optionalValue:obj -> ProvidedParameter + member IsParamArray : bool with get,set + member IsReflectedDefinition : bool with get,set + +/// Represents a provided static parameter. +type ProvidedStaticParameter = + inherit ParameterInfo + // [] + new : parameterName: string * parameterType:Type * ?parameterDefaultValue:obj -> ProvidedStaticParameter + + /// Add XML documentation information to this provided constructor + member AddXmlDoc : xmlDoc: string -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + +/// Represents an erased provided constructor. +type ProvidedConstructor = + inherit ConstructorInfo + + /// Create a new provided constructor. It is not initially associated with any specific provided type definition. + // [] + new : parameters: ProvidedParameter list -> ProvidedConstructor + + /// Add a 'Obsolete' attribute to this provided constructor + member AddObsoleteAttribute : message: string * ?isError: bool -> unit + + /// Add XML documentation information to this provided constructor + member AddXmlDoc : xmlDoc: string -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + + /// Add XML documentation information to this provided constructor, where the documentation is re-computed every time it is required. + member AddXmlDocComputed : xmlDocFunction: (unit -> string) -> unit + + /// Set the quotation used to compute the implementation of invocations of this constructor. + member InvokeCode : (Expr list -> Expr) with set + + /// This method is used by Debug.fs + member internal GetInvokeCodeInternal : bool -> (Expr [] -> Expr) + + /// Set the target and arguments of the base constructor call. Only used for generated types. + member BaseConstructorCall : (Expr list -> ConstructorInfo * Expr list) with set + + /// Set a flag indicating that the constructor acts like an F# implicit constructor, so the + /// parameters of the constructor become fields and can be accessed using Expr.GlobalVar with the + /// same name. + member IsImplicitCtor : bool with get,set + + /// Add definition location information to the provided constructor. + member AddDefinitionLocation : line:int * column:int * filePath:string -> unit + + member IsTypeInitializer : bool with get,set + +type ProvidedMethod = + inherit MethodInfo + + /// Create a new provided method. It is not initially associated with any specific provided type definition. + // [] + new : methodName:string * parameters: ProvidedParameter list * returnType: Type -> ProvidedMethod + + /// Add XML documentation information to this provided method + member AddObsoleteAttribute : message: string * ?isError: bool -> unit + + /// Add XML documentation information to this provided constructor + member AddXmlDoc : xmlDoc: string -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + /// The documentation is re-computed every time it is required. + member AddXmlDocComputed : xmlDocFunction: (unit -> string) -> unit + + member AddMethodAttrs : attributes:MethodAttributes -> unit + + /// Set the method attributes of the method. By default these are simple 'MethodAttributes.Public' + member SetMethodAttrs : attributes:MethodAttributes -> unit + + /// Get or set a flag indicating if the property is static. + member IsStaticMethod : bool with get, set + + /// Set the quotation used to compute the implementation of invocations of this method. + member InvokeCode : (Expr list -> Expr) with set + + // this method is used by Debug.fs + member internal GetInvokeCodeInternal : bool -> (Expr [] -> Expr) + + /// Add definition location information to the provided type definition. + member AddDefinitionLocation : line:int * column:int * filePath:string -> unit + + /// Add a custom attribute to the provided method definition. + member AddCustomAttribute : CustomAttributeData -> unit + + /// Define the static parameters available on a statically parameterized method + member DefineStaticParameters : parameters: ProvidedStaticParameter list * instantiationFunction: (string -> obj[] -> ProvidedMethod) -> unit + +/// Represents an erased provided property. +type ProvidedProperty = + inherit PropertyInfo + + /// Create a new provided property. It is not initially associated with any specific provided type definition. + // [] + new : propertyName: string * propertyType: Type * ?parameters:ProvidedParameter list -> ProvidedProperty + + /// Add a 'Obsolete' attribute to this provided property + member AddObsoleteAttribute : message: string * ?isError: bool -> unit + + /// Add XML documentation information to this provided constructor + member AddXmlDoc : xmlDoc: string -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + /// The documentation is re-computed every time it is required. + member AddXmlDocComputed : xmlDocFunction: (unit -> string) -> unit + + /// Get or set a flag indicating if the property is static. + member IsStatic : bool with get,set + + /// Set the quotation used to compute the implementation of gets of this property. + member GetterCode : (Expr list -> Expr) with set + + /// Set the function used to compute the implementation of sets of this property. + member SetterCode : (Expr list -> Expr) with set + + /// Add definition location information to the provided type definition. + member AddDefinitionLocation : line:int * column:int * filePath:string -> unit + + /// Add a custom attribute to the provided property definition. + member AddCustomAttribute : CustomAttributeData -> unit + +/// Represents an erased provided property. +type ProvidedEvent = + inherit EventInfo + + /// Create a new provided type. It is not initially associated with any specific provided type definition. + new : propertyName: string * eventHandlerType: Type -> ProvidedEvent + + /// Add XML documentation information to this provided constructor + member AddXmlDoc : xmlDoc: string -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + /// The documentation is re-computed every time it is required. + member AddXmlDocComputed : xmlDocFunction: (unit -> string) -> unit + + /// Get or set a flag indicating if the property is static. + member IsStatic : bool with set + + /// Set the quotation used to compute the implementation of gets of this property. + member AdderCode : (Expr list -> Expr) with set + + /// Set the function used to compute the implementation of sets of this property. + member RemoverCode : (Expr list -> Expr) with set + + /// Add definition location information to the provided type definition. + member AddDefinitionLocation : line:int * column:int * filePath:string -> unit + +/// Represents an erased provided field. +type ProvidedLiteralField = + inherit FieldInfo + + /// Create a new provided field. It is not initially associated with any specific provided type definition. + // [] + new : fieldName: string * fieldType: Type * literalValue: obj -> ProvidedLiteralField + + /// Add a 'Obsolete' attribute to this provided field + member AddObsoleteAttribute : message: string * ?isError: bool -> unit + + /// Add XML documentation information to this provided field + member AddXmlDoc : xmlDoc: string -> unit + + /// Add XML documentation information to this provided field, where the computation of the documentation is delayed until necessary + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + + /// Add XML documentation information to this provided field, where the computation of the documentation is delayed until necessary + /// The documentation is re-computed every time it is required. + member AddXmlDocComputed : xmlDocFunction: (unit -> string) -> unit + + /// Add definition location information to the provided field. + member AddDefinitionLocation : line:int * column:int * filePath:string -> unit + +/// Represents an erased provided field. +type ProvidedField = + inherit FieldInfo + + /// Create a new provided field. It is not initially associated with any specific provided type definition. + // [] + new : fieldName: string * fieldType: Type -> ProvidedField + + /// Add a 'Obsolete' attribute to this provided field + member AddObsoleteAttribute : message: string * ?isError: bool -> unit + + /// Add XML documentation information to this provided field + member AddXmlDoc : xmlDoc: string -> unit + + /// Add XML documentation information to this provided field, where the computation of the documentation is delayed until necessary + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + + /// Add XML documentation information to this provided field, where the computation of the documentation is delayed until necessary + /// The documentation is re-computed every time it is required. + member AddXmlDocComputed : xmlDocFunction: (unit -> string) -> unit + + /// Add definition location information to the provided field definition. + member AddDefinitionLocation : line:int * column:int * filePath:string -> unit + + member SetFieldAttributes : attributes : FieldAttributes -> unit + +/// Represents the type constructor in a provided symbol type. +[] +type ProvidedSymbolKind = + /// Indicates that the type constructor is for a single-dimensional array + | SDArray + /// Indicates that the type constructor is for a multi-dimensional array + | Array of int + /// Indicates that the type constructor is for pointer types + | Pointer + /// Indicates that the type constructor is for byref types + | ByRef + /// Indicates that the type constructor is for named generic types + | Generic of Type + /// Indicates that the type constructor is for abbreviated types + | FSharpTypeAbbreviation of (Assembly * string * string[]) + +/// Represents an array or other symbolic type involving a provided type as the argument. +/// See the type provider spec for the methods that must be implemented. +/// Note that the type provider specification does not require us to implement pointer-equality for provided types. +[] +type ProvidedSymbolType = + inherit Type + + /// Returns the kind of this symbolic type + member Kind : ProvidedSymbolKind + + /// Return the provided types used as arguments of this symbolic type + member Args : list + + /// For example, kg + member IsFSharpTypeAbbreviation: bool + + /// For example, int or int + member IsFSharpUnitAnnotated : bool + +/// Helpers to build symbolic provided types +[] +type ProvidedTypeBuilder = + + /// Like typ.MakeGenericType, but will also work with unit-annotated types + static member MakeGenericType: genericTypeDefinition: Type * genericArguments: Type list -> Type + + /// Like methodInfo.MakeGenericMethod, but will also work with unit-annotated types and provided types + static member MakeGenericMethod: genericMethodDefinition: MethodInfo * genericArguments: Type list -> MethodInfo + +[] +/// Used internally for ProvidedTypesContext +type internal ZProvidedTypeBuilder = + new : convToTgt: (Type -> Type) -> ZProvidedTypeBuilder + member MakeGenericType: genericTypeDefinition: Type * genericArguments: Type list -> Type + member MakeGenericMethod: genericMethodDefinition: MethodInfo * genericArguments: Type list -> MethodInfo + +/// Helps create erased provided unit-of-measure annotations. +[] +type ProvidedMeasureBuilder = + + /// The ProvidedMeasureBuilder for building measures. + static member Default : ProvidedMeasureBuilder + + /// Gets the measure indicating the "1" unit of measure, that is the unitless measure. + member One : Type + + /// Returns the measure indicating the product of two units of measure, e.g. kg * m + member Product : measure1: Type * measure1: Type -> Type + + /// Returns the measure indicating the inverse of two units of measure, e.g. 1 / s + member Inverse : denominator: Type -> Type + + /// Returns the measure indicating the ratio of two units of measure, e.g. kg / m + member Ratio : numerator: Type * denominator: Type -> Type + + /// Returns the measure indicating the square of a unit of measure, e.g. m * m + member Square : ``measure``: Type -> Type + + /// Returns the measure for an SI unit from the F# core library, where the string is in capitals and US spelling, e.g. Meter + member SI : unitName:string -> Type + + /// Returns a type where the type has been annotated with the given types and/or units-of-measure. + /// e.g. float, Vector + member AnnotateType : basic: Type * argument: Type list -> Type + + +/// Represents a provided type definition. +type ProvidedTypeDefinition = + inherit Type + + /// Create a new provided type definition in a namespace. + // [] + new : assembly: Assembly * namespaceName: string * className: string * baseType: Type option -> ProvidedTypeDefinition + + /// Create a new provided type definition, to be located as a nested type in some type definition. + // [] + new : className : string * baseType: Type option -> ProvidedTypeDefinition + + + internal new : assembly: Assembly * namespaceName: string * className: string * baseType: Type option * convToTgt: (Type -> Type) -> ProvidedTypeDefinition + internal new : className : string * baseType: Type option * convToTgt: (Type -> Type) -> ProvidedTypeDefinition + + + /// Add the given type as an implemented interface. + member AddInterfaceImplementation : interfaceType: Type -> unit + + /// Add the given function as a set of on-demand computed interfaces. + member AddInterfaceImplementationsDelayed : interfacesFunction:(unit -> Type list)-> unit + + /// Specifies that the given method body implements the given method declaration. + member DefineMethodOverride : methodInfoBody: ProvidedMethod * methodInfoDeclaration: MethodInfo -> unit + + /// Add a 'Obsolete' attribute to this provided type definition + member AddObsoleteAttribute : message: string * ?isError: bool -> unit + + /// Add XML documentation information to this provided constructor + member AddXmlDoc : xmlDoc: string -> unit + + /// Set the base type + member SetBaseType : Type -> unit + + /// Set the base type to a lazily evaluated value. Use this to delay realization of the base type as late as possible. + member SetBaseTypeDelayed : baseTypeFunction:(unit -> Type) -> unit + + /// Set underlying type for generated enums + member SetEnumUnderlyingType : Type -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary. + /// The documentation is only computed once. + member AddXmlDocDelayed : xmlDocFunction: (unit -> string) -> unit + + /// Add XML documentation information to this provided constructor, where the computation of the documentation is delayed until necessary + /// The documentation is re-computed every time it is required. + member AddXmlDocComputed : xmlDocFunction: (unit -> string) -> unit + + /// Set the attributes on the provided type. This fully replaces the default TypeAttributes. + member SetAttributes : TypeAttributes -> unit + + /// Reset the enclosing type (for generated nested types) + member ResetEnclosingType: enclosingType:Type -> unit + + /// Add a method, property, nested type or other member to a ProvidedTypeDefinition + member AddMember : memberInfo:MemberInfo -> unit + + /// Add a set of members to a ProvidedTypeDefinition + member AddMembers : memberInfos:list<#MemberInfo> -> unit + + /// Add a member to a ProvidedTypeDefinition, delaying computation of the members until required by the compilation context. + member AddMemberDelayed : memberFunction:(unit -> #MemberInfo) -> unit + + /// Add a set of members to a ProvidedTypeDefinition, delaying computation of the members until required by the compilation context. + member AddMembersDelayed : membersFunction:(unit -> list<#MemberInfo>) -> unit + +#if NO_GENERATIVE +#else + /// Add the types of the generated assembly as generative types, where types in namespaces get hierarchically positioned as nested types. + member AddAssemblyTypesAsNestedTypesDelayed : assemblyFunction:(unit -> Assembly) -> unit +#endif + + /// Define the static parameters available on a statically parameterized type + member DefineStaticParameters : parameters: ProvidedStaticParameter list * instantiationFunction: (string -> obj[] -> ProvidedTypeDefinition) -> unit + + /// Add definition location information to the provided type definition. + member AddDefinitionLocation : line:int * column:int * filePath:string -> unit + + /// Suppress Object entries in intellisense menus in instances of this provided type + member HideObjectMethods : bool with set + + /// Disallows the use of the null literal. + member NonNullable : bool with set + + /// Get or set a flag indicating if the ProvidedTypeDefinition is erased + member IsErased : bool with get,set + + /// Get or set a flag indicating if the ProvidedTypeDefinition has type-relocation suppressed + [] + member SuppressRelocation : bool with get,set + + // This method is used by Debug.fs + member MakeParametricType : name:string * args:obj[] -> ProvidedTypeDefinition + + /// Add a custom attribute to the provided type definition. + member AddCustomAttribute : CustomAttributeData -> unit + + /// Emulate the F# type provider type erasure mechanism to get the + /// actual (erased) type. We erase ProvidedTypes to their base type + /// and we erase array of provided type to array of base type. In the + /// case of generics all the generic type arguments are also recursively + /// replaced with the erased-to types + static member EraseType : typ:Type -> Type + + /// Get or set a utility function to log the creation of root Provided Type. Used to debug caching/invalidation. + static member Logger : (string -> unit) option ref + +#if NO_GENERATIVE +#else +/// A provided generated assembly +type ProvidedAssembly = + /// Create a provided generated assembly + new : assemblyFileName:string -> ProvidedAssembly + + /// Emit the given provided type definitions as part of the assembly + /// and adjust the 'Assembly' property of all provided type definitions to return that + /// assembly. + /// + /// The assembly is only emitted when the Assembly property on the root type is accessed for the first time. + /// The host F# compiler does this when processing a generative type declaration for the type. + member AddTypes : types : ProvidedTypeDefinition list -> unit + + /// + /// Emit the given nested provided type definitions as part of the assembly. + /// and adjust the 'Assembly' property of all provided type definitions to return that + /// assembly. + /// + /// A path of type names to wrap the generated types. The generated types are then generated as nested types. + member AddNestedTypes : types : ProvidedTypeDefinition list * enclosingGeneratedTypeNames: string list -> unit + +#if FX_NO_LOCAL_FILESYSTEM +#else + /// Register that a given file is a provided generated assembly + static member RegisterGenerated : fileName:string -> Assembly +#endif + +#endif + + +/// A base type providing default implementations of type provider functionality when all provided +/// types are of type ProvidedTypeDefinition. +type TypeProviderForNamespaces = + + /// Initializes a type provider to provide the types in the given namespace. + new : namespaceName:string * types: ProvidedTypeDefinition list -> TypeProviderForNamespaces + + /// Initializes a type provider + new : unit -> TypeProviderForNamespaces + + /// Invoked by the type provider to add a namespace of provided types in the specification of the type provider. + member AddNamespace : namespaceName:string * types: ProvidedTypeDefinition list -> unit + + /// Invoked by the type provider to get all provided namespaces with their provided types. + member Namespaces : seq + + /// Invoked by the type provider to invalidate the information provided by the provider + member Invalidate : unit -> unit + + /// Invoked by the host of the type provider to get the static parameters for a method. + member GetStaticParametersForMethod : MethodBase -> ParameterInfo[] + + /// Invoked by the host of the type provider to apply the static argumetns for a method. + member ApplyStaticArgumentsForMethod : MethodBase * string * obj[] -> MethodBase + +#if FX_NO_LOCAL_FILESYSTEM +#else + /// AssemblyResolve handler. Default implementation searches .dll file in registered folders + abstract ResolveAssembly : ResolveEventArgs -> Assembly + default ResolveAssembly : ResolveEventArgs -> Assembly + + /// Registers custom probing path that can be used for probing assemblies + member RegisterProbingFolder : folder: string -> unit + + /// Registers location of RuntimeAssembly (from TypeProviderConfig) as probing folder + member RegisterRuntimeAssemblyLocationAsProbingFolder : config: TypeProviderConfig -> unit + +#endif + + [] + member Disposing : IEvent + + interface ITypeProvider + + +module internal UncheckedQuotations = + + type Expr with + static member NewDelegateUnchecked : ty:Type * vs:Var list * body:Expr -> Expr + static member NewObjectUnchecked : cinfo:ConstructorInfo * args:Expr list -> Expr + static member NewArrayUnchecked : elementType:Type * elements:Expr list -> Expr + static member CallUnchecked : minfo:MethodInfo * args:Expr list -> Expr + static member CallUnchecked : obj:Expr * minfo:MethodInfo * args:Expr list -> Expr + static member ApplicationUnchecked : f:Expr * x:Expr -> Expr + static member PropertyGetUnchecked : pinfo:PropertyInfo * args:Expr list -> Expr + static member PropertyGetUnchecked : obj:Expr * pinfo:PropertyInfo * ?args:Expr list -> Expr + static member PropertySetUnchecked : pinfo:PropertyInfo * value:Expr * ?args:Expr list -> Expr + static member PropertySetUnchecked : obj:Expr * pinfo:PropertyInfo * value:Expr * args:Expr list -> Expr + static member FieldGetUnchecked : pinfo:FieldInfo -> Expr + static member FieldGetUnchecked : obj:Expr * pinfo:FieldInfo -> Expr + static member FieldSetUnchecked : pinfo:FieldInfo * value:Expr -> Expr + static member FieldSetUnchecked : obj:Expr * pinfo:FieldInfo * value:Expr -> Expr + static member TupleGetUnchecked : e:Expr * n:int -> Expr + static member LetUnchecked : v:Var * e:Expr * body:Expr -> Expr + + type Shape + val ( |ShapeCombinationUnchecked|ShapeVarUnchecked|ShapeLambdaUnchecked| ) : e:Expr -> Choice<(Shape * Expr list),Var, (Var * Expr)> + val RebuildShapeCombinationUnchecked : Shape * args:Expr list -> Expr diff --git a/Rezoom.SQL.Provider/ProvidedTypesTrickery.fs b/Rezoom.SQL.Provider/ProvidedTypesTrickery.fs new file mode 100644 index 0000000..9791c07 --- /dev/null +++ b/Rezoom.SQL.Provider/ProvidedTypesTrickery.fs @@ -0,0 +1,35 @@ +module internal Microsoft.FSharp.Quotations.DerivedPatterns +open System +open System.Text +open System.IO +open System.Reflection +open System.Linq.Expressions +open System.Collections.Generic +open Microsoft.FSharp.Quotations +open Microsoft.FSharp.Quotations.Patterns +open Microsoft.FSharp.Quotations.DerivedPatterns +open Microsoft.FSharp.Core.CompilerServices + +let private metadataToken (minfo : MethodInfo) = + try minfo.MetadataToken with + | _ -> -1 + +let private methodsMatch isg1 gmd minfo1 minfo2 = + metadataToken minfo1 = metadataToken minfo2 + && if isg1 then minfo2.IsGenericMethod && gmd = minfo2.GetGenericMethodDefinition() + else minfo1 = minfo2 + +[] +let (|SpecificCall|_|) templateParameter = + match templateParameter with + | (Lambdas(_, Call(_, minfo1, _)) | Call(_, minfo1, _)) -> + // precompute these two + let isg1 = minfo1.IsGenericMethod + let gmd = if isg1 then minfo1.GetGenericMethodDefinition() else null + (fun tm -> + match tm with + | Call(obj, minfo2, args) when methodsMatch isg1 gmd minfo1 minfo2 -> + Some(obj, (minfo2.GetGenericArguments() |> Array.toList), args) + | _ -> None) + | _ -> + invalidArg "templateParameter" "unrecognized method call" \ No newline at end of file diff --git a/Rezoom.SQL.Provider/Provider.fs b/Rezoom.SQL.Provider/Provider.fs new file mode 100644 index 0000000..f349800 --- /dev/null +++ b/Rezoom.SQL.Provider/Provider.fs @@ -0,0 +1,68 @@ +namespace Rezoom.SQL.Provider +open System +open System.Collections.Generic +open System.IO +open System.Reflection +open Microsoft.FSharp.Core.CompilerServices +open Microsoft.FSharp.Quotations +open ProviderImplementation.ProvidedTypes +open Rezoom.SQL +open Rezoom.SQL.Provider.TypeGeneration + +[] +type public Provider(cfg : TypeProviderConfig) as this = + inherit TypeProviderForNamespaces() + + // Get the assembly and namespace used to house the provided types. + let thisAssembly = Assembly.LoadFrom(cfg.RuntimeAssembly) + let tmpAssembly = ProvidedAssembly(Path.GetTempFileName()) + let rootNamespace = "Rezoom.SQL.Provider" + + let modelCache = new UserModelCache() + let generateType typeName model case = + let tmpAssembly = ProvidedAssembly(Path.GetTempFileName()) + let model = modelCache.Load(cfg.ResolutionFolder, model) + let ty = + { Assembly = thisAssembly + Namespace = rootNamespace + TypeName = typeName + UserModel = model + Case = case + } |> generateType + tmpAssembly.AddTypes([ ty ]) + ty + + let sqlTy = + let sqlTy = + ProvidedTypeDefinition(thisAssembly, rootNamespace, "SQL", Some typeof, IsErased = false) + let staticParams = + [ ProvidedStaticParameter("sql", typeof) + ProvidedStaticParameter("model", typeof, "") + ] + let buildSQLFromStaticParams typeName (parameterValues : obj array) = + match parameterValues with + | [| :? string as sql; :? string as model |] -> generateType typeName model (GenerateSQL sql) + | _ -> failwith "Invalid parameters (expected 2 strings: sql, model)" + sqlTy.DefineStaticParameters(staticParams, buildSQLFromStaticParams) + sqlTy + + let modelTy = + let modelTy = + ProvidedTypeDefinition(thisAssembly, rootNamespace, "SQLModel", Some typeof, IsErased = false) + let staticParams = [ ProvidedStaticParameter("model", typeof, "") ] + let buildModelFromStaticParams typeName (parameterValues : obj array) = + match parameterValues with + | [| :? string as model |] -> generateType typeName model GenerateModel + | _ -> failwith "Invalid parameters (expected 1 string: model)" + modelTy.DefineStaticParameters(staticParams, buildModelFromStaticParams) + modelTy + + do + let tys = [ sqlTy; modelTy ] + tmpAssembly.AddTypes(tys) + this.AddNamespace(rootNamespace, tys) + modelCache.Invalidated.Add(fun _ -> this.Invalidate()) + this.Disposing.Add(fun _ -> modelCache.Dispose()) + +[] +do () \ No newline at end of file diff --git a/Rezoom.SQL.Provider/Rezoom.SQL.Provider.fsproj b/Rezoom.SQL.Provider/Rezoom.SQL.Provider.fsproj new file mode 100644 index 0000000..ad03f35 --- /dev/null +++ b/Rezoom.SQL.Provider/Rezoom.SQL.Provider.fsproj @@ -0,0 +1,101 @@ + + + + + Debug + AnyCPU + 2.0 + 7b1765cb-23f8-419a-9cc6-3da319ed066f + Library + Rezoom.SQL.Provider + Rezoom.SQL.Provider + v4.6 + 4.4.0.0 + true + Rezoom.SQL.Provider + + + true + full + false + false + bin\Debug\ + DEBUG;TRACE + 3 + bin\Debug\Rezoom.SQL.Provider.XML + Program + C:\Program Files (x86)\Microsoft SDKs\F#\4.0\Framework\v4.0\Fsi.exe + user.fsx + D:\src\Rezoom\Rezoom.SQL.Provider\ + + + pdbonly + true + true + bin\Release\ + TRACE + 3 + bin\Release\Rezoom.SQL.Provider.XML + Program + + + 11 + + + + + $(MSBuildExtensionsPath32)\..\Microsoft SDKs\F#\3.0\Framework\v4.0\Microsoft.FSharp.Targets + + + + + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\FSharp\Microsoft.FSharp.Targets + + + + + + + + + + + + + + + + + + + Rezoom.SQL.Compiler + {87fcd04a-1f90-4d53-a428-cf5f5c532a22} + True + + + Rezoom + {d98acbeb-a039-4340-a7c5-6ed2b677268b} + True + + + + True + + + + + + + Rezoom.SQL.Mapping + {6b6a06c5-157a-4fe3-8b4c-2a1ae6a15333} + True + + + + diff --git a/Rezoom.SQL.Provider/TypeGeneration.fs b/Rezoom.SQL.Provider/TypeGeneration.fs new file mode 100644 index 0000000..af975c9 --- /dev/null +++ b/Rezoom.SQL.Provider/TypeGeneration.fs @@ -0,0 +1,303 @@ +module Rezoom.SQL.Provider.TypeGeneration +open System +open System.Data +open System.Data.Common +open System.Collections.Generic +open System.IO +open System.Text.RegularExpressions +open System.Reflection +open FSharp.Core.CompilerServices +open FSharp.Quotations +open FSharp.Reflection +open ProviderImplementation.ProvidedTypes +open ProviderImplementation.ProvidedTypes.UncheckedQuotations +open Rezoom +open Rezoom.SQL.Mapping +open Rezoom.SQL.Compiler + +type GenerateTypeCase = + | GenerateSQL of string + | GenerateModel + +type GenerateType = + { UserModel : UserModel + Assembly : Assembly + Namespace : string + TypeName : string + Case : GenerateTypeCase + } + +let private parameterIndexer (pars : BindParameter seq) = + let dict = + pars |> Seq.indexed |> Seq.map (fun (a, b) -> (b, a)) |> dict + { new IParameterIndexer with + member __.ParameterIndex(par) = dict.[par] + } + +let private toFragmentExpr (fragment : CommandFragment) = + match fragment with + | LocalName n -> <@@ LocalName (%%Quotations.Expr.Value(n)) @@> + | CommandText t -> <@@ CommandText (%%Quotations.Expr.Value(t)) @@> + | Parameter i -> <@@ Parameter (%%Quotations.Expr.Value(i)) @@> + | Whitespace -> <@@ Whitespace @@> + +let private toFragmentArrayExpr (fragments : CommandFragment IReadOnlyList) = + Expr.NewArray(typeof, fragments |> Seq.map toFragmentExpr |> Seq.toList) + +/// Lowercase initial uppercase characters. +let private toCamelCase (str : string) = + Regex.Replace(str, @"^\p{Lu}+", fun m -> m.Value.ToLowerInvariant()) + +let private toRowTypeName (name : string) = + // Must sanitize to remove things like * from the name. + Regex.Replace(name, @"[^_a-zA-Z0-9]", fun m -> string (char (int m.Value.[0] % 26 + int 'A'))) + "Row" + +type private BlueprintNoKeyAttributeData() = + inherit CustomAttributeData() + override __.Constructor = typeof.GetConstructor(Type.EmptyTypes) + override __.ConstructorArguments = [||] :> IList<_> + override __.NamedArguments = [||] :> IList<_> + +type private BlueprintKeyAttributeData() = + inherit CustomAttributeData() + override __.Constructor = typeof.GetConstructor(Type.EmptyTypes) + override __.ConstructorArguments = [||] :> IList<_> + override __.NamedArguments = [||] :> IList<_> + +type private BlueprintColumnNameAttributeData(name : string) = + inherit CustomAttributeData() + override __.Constructor = typeof.GetConstructor([| typeof |]) + override __.ConstructorArguments = + [| CustomAttributeTypedArgument(typeof, name) + |] :> IList<_> + override __.NamedArguments = [||] :> IList<_> + +let rec private generateRowTypeFromColumns (model : UserModel) name (columnMap : CompileTimeColumnMap) = + let ty = + ProvidedTypeDefinition + ( name + , Some typeof + , IsErased = false + , HideObjectMethods = true + ) + if not columnMap.HasSubMaps then + ty.AddCustomAttribute(BlueprintNoKeyAttributeData()) + let fields = ResizeArray() + let addField pk (name : string) (fieldTy : Type) = + let fieldTy, propName = + if name.EndsWith("*") then + typedefof<_ IReadOnlyList>.MakeGenericType(fieldTy), name.Substring(0, name.Length - 1) + elif name.EndsWith("?") then + typedefof<_ option>.MakeGenericType(fieldTy), name.Substring(0, name.Length - 1) + else fieldTy, name + let camel = toCamelCase propName + let field = ProvidedField("_" + camel, fieldTy) + field.SetFieldAttributes(FieldAttributes.Private) + let getter = ProvidedProperty(propName, fieldTy) + if pk then + getter.AddCustomAttribute(BlueprintKeyAttributeData()) + if name <> propName then + getter.AddCustomAttribute(BlueprintColumnNameAttributeData(name)) + getter.GetterCode <- + function + | [ this ] -> Expr.FieldGet(this, field) + | _ -> failwith "Invalid getter argument list" + ty.AddMembers [ field :> MemberInfo; getter :> _ ] + fields.Add(camel, field) + for KeyValue(name, (_, column)) in columnMap.Columns do + let info = column.Expr.Info + addField info.PrimaryKey name <| info.Type.CLRType(useOptional = (model.Config.Optionals = Config.FsStyle)) + for KeyValue(name, subMap) in columnMap.SubMaps do + let subTy = generateRowTypeFromColumns model (toRowTypeName name) subMap + ty.AddMember(subTy) + addField false name subTy + let ctorParams = [ for camel, field in fields -> ProvidedParameter(camel, field.FieldType) ] + let ctor = ProvidedConstructor(ctorParams) + ctor.InvokeCode <- + function + | this :: pars -> + Seq.zip fields pars + |> Seq.fold + (fun exp ((_, field), par) -> Expr.Sequential(exp, Expr.FieldSet(this, field, par))) + (Quotations.Expr.Value(())) + | _ -> failwith "Invalid ctor argument list" + ty.AddMember(ctor) + ty + +let private generateRowType (model : UserModel) (name : string) (query : ColumnType QueryExprInfo) = + CompileTimeColumnMap.Parse(query.Columns) + |> generateRowTypeFromColumns model name + +let private maskOfTables (model : UserModel) (tables : (Name * Name) seq) = + let mutable mask = BitMask.Zero + for table in tables do + match model.TableIds.Value |> Map.tryFind table with + | None -> () + | Some id -> + mask <- mask.WithBit(id % 128, true) + mask + +let private generateCommandMethod + (generate : GenerateType) (command : CommandEffect) (retTy : Type) (callMeth : MethodInfo) = + let backend = generate.UserModel.Backend + let parameters = command.Parameters |> Seq.sortBy fst |> Seq.toList + let indexer = parameterIndexer (parameters |> Seq.map fst) + let commandData = + let fragments = backend.ToCommandFragments(indexer, command.Statements) |> toFragmentArrayExpr + let identity = generate.Namespace + generate.TypeName + let resultSetCount = command.ResultSets() |> Seq.length + let cacheable, dependencies, invalidations = + match command.CacheInfo.Value with + | Some info -> + ( info.Idempotent + , maskOfTables generate.UserModel info.ReadTables + , maskOfTables generate.UserModel info.WriteTables + ) + | None -> false, BitMask.Full, BitMask.Full // assume the worst + <@@ { ConnectionName = %%Quotations.Expr.Value(generate.UserModel.ConnectionName) + Identity = %%Quotations.Expr.Value(identity) + Fragments = (%%fragments : _ array) :> _ IReadOnlyList + Cacheable = %%Quotations.Expr.Value(cacheable) + DependencyMask = + BitMask + ( %%Quotations.Expr.Value(dependencies.HighBits) + , %%Quotations.Expr.Value(dependencies.LowBits)) + InvalidationMask = + BitMask + ( %%Quotations.Expr.Value(invalidations.HighBits) + , %%Quotations.Expr.Value(invalidations.LowBits)) + ResultSetCount = Some (%%Quotations.Expr.Value(resultSetCount)) + } @@> + let useOptional = generate.UserModel.Config.Optionals = Config.FsStyle + let methodParameters = + [ for NamedParameter name, ty in parameters -> + ProvidedParameter(name.Value, ty.CLRType(useOptional)) + ] + let meth = ProvidedMethod("Command", methodParameters, retTy) + meth.SetMethodAttrs(MethodAttributes.Static ||| MethodAttributes.Public) + meth.InvokeCode <- + fun args -> + let arr = + Expr.NewArray + ( typeof + , (args, parameters) ||> List.map2 (fun ex (_, ty) -> + let tx = backend.ParameterTransform(ty) + Expr.NewTuple([ tx.ValueTransform ex; Quotations.Expr.Value(tx.ParameterType) ])) + ) + Expr.CallUnchecked(callMeth, [ commandData; arr ]) + meth + +let generateSQLType (generate : GenerateType) (sql : string) = + let commandEffect = CommandEffect.OfSQL(generate.UserModel.Model, generate.TypeName, sql) + let commandCtor = typeof + let cmd (r : Type) = typedefof<_ Command>.MakeGenericType(r) + let lst (r : Type) = typedefof<_ IReadOnlyList>.MakeGenericType(r) + let rowTypes, commandCtorMethod, commandType = + let genRowType = generateRowType generate.UserModel + match commandEffect.ResultSets() |> Seq.toList with + | [] -> + [] + , commandCtor.GetMethod("Command0") + , cmd typeof + | [ resultSet ] -> + let rowType = genRowType "Row" resultSet + [ rowType ] + , commandCtor.GetMethod("Command1").MakeGenericMethod(lst rowType) + , cmd (lst rowType) + | [ resultSet1; resultSet2 ] -> + let rowType1 = genRowType "Row1" resultSet1 + let rowType2 = genRowType "Row2" resultSet2 + [ rowType1; rowType2 ] + , commandCtor.GetMethod("Command2").MakeGenericMethod(lst rowType1, lst rowType2) + , cmd <| typedefof>.MakeGenericType(lst rowType1, lst rowType2) + | [ resultSet1; resultSet2; resultSet3 ] -> + let rowType1 = genRowType "Row1" resultSet1 + let rowType2 = genRowType "Row2" resultSet2 + let rowType3 = genRowType "Row3" resultSet3 + [ rowType1; rowType2; rowType3 ] + , commandCtor.GetMethod("Command3").MakeGenericMethod(lst rowType1, lst rowType2, lst rowType3) + , cmd <| typedefof>.MakeGenericType(lst rowType1, lst rowType2, lst rowType3) + | [ resultSet1; resultSet2; resultSet3; resultSet4 ] -> + let rowType1 = genRowType "Row1" resultSet1 + let rowType2 = genRowType "Row2" resultSet2 + let rowType3 = genRowType "Row3" resultSet3 + let rowType4 = genRowType "Row4" resultSet4 + [ rowType1; rowType2; rowType3; rowType4 ] + , commandCtor.GetMethod("Command4").MakeGenericMethod + (lst rowType1, lst rowType2, lst rowType3, lst rowType4) + , cmd <| + typedefof>.MakeGenericType + (lst rowType1, lst rowType2, lst rowType3, lst rowType4) + | sets -> + failwithf "Too many (%d) result sets from command." (List.length sets) + let provided = + ProvidedTypeDefinition + ( generate.Assembly + , generate.Namespace + , generate.TypeName + , Some typeof + , IsErased = false + , HideObjectMethods = true + ) + provided.AddMembers rowTypes + provided.AddMember <| generateCommandMethod generate commandEffect commandType commandCtorMethod + provided + +let generateModelType (generate : GenerateType) = + let backend = generate.UserModel.Backend + let provided = + ProvidedTypeDefinition + ( generate.Assembly + , generate.Namespace + , generate.TypeName + , Some typeof + , IsErased = false + , HideObjectMethods = true + ) + let migrationsField = + ProvidedField + ( "_migrations" + , typeof + ) + migrationsField.SetFieldAttributes(FieldAttributes.Static ||| FieldAttributes.Private) + provided.AddMember <| migrationsField + let staticCtor = + ProvidedConstructor([], IsTypeInitializer = true) + staticCtor.InvokeCode <- fun _ -> + Expr.FieldSet + ( migrationsField + , Expr.NewArray + ( typeof + , generate.UserModel.Migrations |> Seq.map Migrations.quotationizeMigrationTree |> Seq.toList + )) + provided.AddMember <| staticCtor + provided.AddMember <| + ProvidedProperty + ( "Migrations" + , typeof + , GetterCode = fun _ -> Expr.FieldGet(migrationsField) + , IsStatic = true + ) + do + let pars = + [ ProvidedParameter("config", typeof) + ProvidedParameter("conn", typeof) + ] + let meth = ProvidedMethod("Migrate", pars, typeof) + meth.IsStaticMethod <- true + meth.InvokeCode <- function + | [config; conn] -> + <@@ + let migrationBackend : Migrations.IMigrationBackend = + (%%(upcast backend.MigrationBackend)) (%%conn : DbConnection) + let migrations : string Migrations.MigrationTree array = %%Expr.FieldGet(migrationsField) + Migrations.runMigrations %%config migrationBackend migrations + @@> + | _ -> failwith "Invalid migrate argument list" + provided.AddMember meth + provided + +let generateType (generate : GenerateType) = + match generate.Case with + | GenerateSQL sql -> generateSQLType generate sql + | GenerateModel -> generateModelType generate \ No newline at end of file diff --git a/Rezoom.SQL.Provider/UserModelCache.fs b/Rezoom.SQL.Provider/UserModelCache.fs new file mode 100644 index 0000000..bebd9dd --- /dev/null +++ b/Rezoom.SQL.Provider/UserModelCache.fs @@ -0,0 +1,40 @@ +namespace Rezoom.SQL.Provider +open System +open System.Collections.Generic +open System.IO +open Rezoom.SQL.Compiler + +type UserModelCache() as this = + let watchers = Dictionary() + let cache = Dictionary() + let invalidated = Event() + + let addWatcher path invalidateKey = + let succ, watcher = watchers.TryGetValue(path) + let watcher = + if succ then watcher else + let watcher = new Watcher(path) + watcher.Invalidated.Add(fun _ -> invalidated.Trigger(this, EventArgs.Empty)) + watchers.Add(path, watcher) + watcher + watcher.Invalidating.Add(fun _ -> ignore <| cache.Remove(invalidateKey)) // remove from cache on changes + + [] + member __.Invalidated = invalidated.Publish + + member this.Load(resolutionFolder, modelPath) = + let key = (resolutionFolder, modelPath) + let succ, cachedModel = cache.TryGetValue(key) + if succ then cachedModel else + let model = UserModel.Load(resolutionFolder, modelPath) + cache.[key] <- model + addWatcher model.ConfigDirectory key + addWatcher model.MigrationsDirectory key + model + + member this.Dispose() = + for KeyValue(_, w) in watchers do + w.Dispose() + watchers.Clear() + interface IDisposable with + member this.Dispose() = this.Dispose() diff --git a/Rezoom.SQL.Provider/Watcher.fs b/Rezoom.SQL.Provider/Watcher.fs new file mode 100644 index 0000000..7049396 --- /dev/null +++ b/Rezoom.SQL.Provider/Watcher.fs @@ -0,0 +1,39 @@ +namespace Rezoom.SQL.Provider +open System +open System.Threading +open System.IO +open Rezoom.SQL.Compiler + +type Watcher(path : string) as this = + let fs = new FileSystemWatcher(path, IncludeSubdirectories = true) + let invalidating = Event() + let invalidated = Event() + let timer = new Timer(fun _ -> // we use a timer to buffer changes so we don't invalidate many times in a couple ms + invalidating.Trigger(this, EventArgs.Empty) + invalidated.Trigger(this, EventArgs.Empty)) + static let isRelevant (path : string) = + path.EndsWith(".SQL", StringComparison.OrdinalIgnoreCase) + || path.EndsWith(UserModel.ConfigFileName, StringComparison.OrdinalIgnoreCase) + let handler (ev : FileSystemEventArgs) = + if isRelevant ev.FullPath then + ignore <| timer.Change(TimeSpan.FromMilliseconds(100.0), Timeout.InfiniteTimeSpan) + + do + fs.Created.Add(handler) + fs.Deleted.Add(handler) + fs.Changed.Add(handler) + fs.Renamed.Add(handler) + fs.EnableRaisingEvents <- true + + member __.Path = path + + [] + member __.Invalidating = invalidating.Publish + [] + member __.Invalidated = invalidated.Publish + member __.Dispose() = + fs.Dispose() + timer.Dispose() + + interface IDisposable with + member this.Dispose() = this.Dispose() \ No newline at end of file diff --git a/Rezoom.SQL.Provider/user-migrations/V1.initial.sql b/Rezoom.SQL.Provider/user-migrations/V1.initial.sql new file mode 100644 index 0000000..e2c89b6 --- /dev/null +++ b/Rezoom.SQL.Provider/user-migrations/V1.initial.sql @@ -0,0 +1,19 @@ +create table Users + ( Id int primary key + , Name string(128) null + , Email string(128) + , Password binary(64) + , Salt binary(64) + ); + +create Table Groups + ( Id int primary key + , Name string(128) + ); + +create table UserGroupMaps + ( UserId int references Users(Id) + , GroupId int references Groups(Id) + , primary key(UserId, GroupId) + ); + diff --git a/Rezoom.SQL.Provider/user.fsx b/Rezoom.SQL.Provider/user.fsx new file mode 100644 index 0000000..a0351b1 --- /dev/null +++ b/Rezoom.SQL.Provider/user.fsx @@ -0,0 +1,41 @@ +#I "bin/Debug" +#r "FSharp.Core.dll" +#r "FParsec.dll" +#r "FParsecCS.dll" +#r "FParsec-Pipes.dll" +#r "LicenseToCIL.dll" +#r "Rezoom.SQL.dll" +#r "Rezoom.SQL.Mapping.dll" +#r "Rezoom.SQL.Provider.dll" + +open Rezoom.SQL.Provider +open Rezoom.SQL.Mapping + +type M = SQLModel + +type Query = SQL<""" + select * from Users u where u.id = @id +""", "user-migrations"> + +type QueryInPar = SQL<""" + select * from Users u where u.id in @id +""", "user-migrations"> + +type QueryWithNullablePar = SQL<""" + select * from Users u + where u.Name is @name +""", "user-migrations"> + + +let q : Command<_> = Query.Command(id = 1) +printfn "%O" <| q.GetType() +printfn "%A" <| q.Fragments + +let qIn : Command<_> = QueryInPar.Command(id = [|1|]) +printfn "%O" <| qIn.GetType() +printfn "%A" <| qIn.Fragments + +let qNull : Command<_> = QueryWithNullablePar.Command(name = Some "test") +printfn "%O" <| qNull.GetType() +printfn "%A" <| qNull.Fragments + diff --git a/Rezoom.SQL.Test/AssemblyInfo.fs b/Rezoom.SQL.Test/AssemblyInfo.fs new file mode 100644 index 0000000..89a5f63 --- /dev/null +++ b/Rezoom.SQL.Test/AssemblyInfo.fs @@ -0,0 +1,41 @@ +namespace Rezoom.SQL.Test.AssemblyInfo + +open System.Reflection +open System.Runtime.CompilerServices +open System.Runtime.InteropServices + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[] +[] +[] +[] +[] +[] +[] +[] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [] +[] +[] + +do + () \ No newline at end of file diff --git a/Rezoom.SQL.Test/Environment.fs b/Rezoom.SQL.Test/Environment.fs new file mode 100644 index 0000000..8f06efe --- /dev/null +++ b/Rezoom.SQL.Test/Environment.fs @@ -0,0 +1,42 @@ +[] +module Rezoom.SQL.Test.Environment +open NUnit.Framework +open FsUnit +open System +open System.Reflection +open System.IO +open System.Collections.Generic +open Rezoom.SQL.Compiler + +let userModelByName name = + let assemblyFolder = Path.GetDirectoryName(Uri(Assembly.GetExecutingAssembly().CodeBase).LocalPath) + let resolutionFolder = Path.Combine(assemblyFolder, "../../" + name) + UserModel.Load(resolutionFolder, ".") + +let userModel1() = userModelByName "user-model-1" + +let userModel2() = userModelByName "user-model-2" + +let expectError (msg : string) (sql : string) = + let userModel = userModel1() + try + ignore <| CommandEffect.OfSQL(userModel.Model, "anonymous", sql) + failwith "Should've thrown an exception!" + with + | :? SourceException as exn -> + printfn "\"%s\"" exn.Message + Assert.AreEqual(msg, exn.Reason.Trim()) + +let dispenserParameterIndexer() = + let dict = Dictionary() + let mutable last = -1 + { new IParameterIndexer with + member __.ParameterIndex(par) = + let succ, value = dict.TryGetValue(par) + if succ then value + else + last <- last + 1 + dict.[par] <- last + last + } + diff --git a/Rezoom.SQL.Test/Rezoom.SQL.Test.fsproj b/Rezoom.SQL.Test/Rezoom.SQL.Test.fsproj new file mode 100644 index 0000000..055bada --- /dev/null +++ b/Rezoom.SQL.Test/Rezoom.SQL.Test.fsproj @@ -0,0 +1,123 @@ + + + + + Debug + AnyCPU + 2.0 + aa699897-f692-4ed0-9865-98b6b4c713db + Library + Rezoom.SQL.Test + Rezoom.SQL.Test + v4.6 + 4.4.0.0 + true + Rezoom.SQL.Test + + + true + full + false + false + bin\Debug\ + DEBUG;TRACE + 3 + bin\Debug\Rezoom.SQL.Test.XML + + + pdbonly + true + true + bin\Release\ + TRACE + 3 + bin\Release\Rezoom.SQL.Test.XML + + + 11 + + + + + $(MSBuildExtensionsPath32)\..\Microsoft SDKs\F#\3.0\Framework\v4.0\Microsoft.FSharp.Targets + + + + + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\FSharp\Microsoft.FSharp.Targets + + + + + + + + + + + + + + + + + + + + + + + + + + + + Rezoom.SQL.Compiler + {87fcd04a-1f90-4d53-a428-cf5f5c532a22} + True + + + Rezoom + {d98acbeb-a039-4340-a7c5-6ed2b677268b} + True + + + ..\packages\FSharp.Core.4.0.0.1\lib\net40\FSharp.Core.dll + True + + + ..\packages\FsUnit.2.3.2\lib\net45\FsUnit.NUnit.dll + True + + + ..\packages\LicenseToCIL.0.2.2\lib\net46\LicenseToCIL.dll + True + + + + ..\packages\FsUnit.2.3.2\lib\net45\NHamcrest.dll + True + + + ..\packages\NUnit.3.5.0\lib\net45\nunit.framework.dll + True + + + + + + + + Rezoom.SQL.Mapping + {6b6a06c5-157a-4fe3-8b4c-2a1ae6a15333} + True + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestAggregateErrors.fs b/Rezoom.SQL.Test/TestAggregateErrors.fs new file mode 100644 index 0000000..0616fd0 --- /dev/null +++ b/Rezoom.SQL.Test/TestAggregateErrors.fs @@ -0,0 +1,30 @@ +module Rezoom.SQL.Test.TestAggregateErrors +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping + +[] +let ``aggregates without group must not be found with non-aggregates`` () = + expectError Error.columnNotAggregated + """ + select sum(Id) as Sum, Id from Users + """ + +[] +let ``aggregates with group by must not contain non-grouped column references`` () = + expectError Error.columnNotGroupedBy + """ + select Id, Name + from Users + group by Id + """ + +[] +let ``aggregates may not appear in where clause`` () = + expectError Error.aggregateInWhereClause + """ + select count(*) as x from Users + where count(*) > 0 + """ \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestBlueprints.fs b/Rezoom.SQL.Test/TestBlueprints.fs new file mode 100644 index 0000000..cfde573 --- /dev/null +++ b/Rezoom.SQL.Test/TestBlueprints.fs @@ -0,0 +1,105 @@ +module Rezoom.SQL.Test.Blueprints +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Mapping +open System + +type Folder = + { + Id : int + ParentFolder : Folder + ChildFolders : Folder list + } + +[] +let ``folder blueprint makes sense`` () = + let blue = Blueprint.ofType typeof + match blue.Cardinality with + | One { Shape = Composite folder } -> + match folder.Columns.["ParentFolder"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for parent folder" + | Some childFolders -> + Assert.IsTrue("ChildFolders".Equals(childFolders.Name, StringComparison.OrdinalIgnoreCase)) + Assert.IsTrue(obj.ReferenceEquals(childFolders, folder.Columns.["ChildFolders"])) + + match folder.Columns.["ChildFolders"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for child folders" + | Some parent -> + Assert.IsTrue("ParentFolder".Equals(parent.Name, StringComparison.OrdinalIgnoreCase)) + | _ -> failwith "Wrong cardinality/shape" + +type UserFriendMap = + { + Friend1 : User + Friend2 : User + } + +and User = + { + Id : int + Friend1Maps : UserFriendMap list + Friend2Maps : UserFriendMap list + } + +[] +let ``user blueprint makes sense`` () = + let blue = Blueprint.ofType typeof + match blue.Cardinality with + | One { Shape = Composite user } -> + match user.Columns.["Friend1Maps"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for friend1maps" + | Some friend1Maps -> + Assert.IsTrue("Friend1".Equals(friend1Maps.Name, StringComparison.OrdinalIgnoreCase)) + match user.Columns.["Friend2Maps"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for friend2maps" + | Some friend2Maps -> + Assert.IsTrue("Friend2".Equals(friend2Maps.Name, StringComparison.OrdinalIgnoreCase)) + | _ -> failwith "Wrong cardinality/shape" + +[] +let ``friend map blueprint makes sense`` () = + let blue = Blueprint.ofType typeof + match blue.Cardinality with + | One { Shape = Composite friendMap } -> + match friendMap.Columns.["Friend1"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for friend1" + | Some friend1Maps -> + Assert.IsTrue("Friend1Maps".Equals(friend1Maps.Name, StringComparison.OrdinalIgnoreCase)) + match friendMap.Columns.["Friend2"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for friend1" + | Some friend2Maps -> + Assert.IsTrue("Friend2Maps".Equals(friend2Maps.Name, StringComparison.OrdinalIgnoreCase)) + | _ -> failwith "Wrong cardinality/shape" + +type Foo = + { + FooId : int + ChildBars : Bar array + } +and Bar = + { + BarId : int + ParentFoo : Foo + } + +[] +let ``foo blueprint makes sense`` () = + let blue = Blueprint.ofType typeof + match blue.Cardinality with + | One { Shape = Composite fooMap } -> + match fooMap.Columns.["ChildBars"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for ChildBars" + | Some parentFoo -> + Assert.IsTrue("ParentFoo".Equals(parentFoo.Name, StringComparison.OrdinalIgnoreCase)) + | _ -> failwith "Wrong cardinality/shape" + +[] +let ``bar blueprint makes sense`` () = + let blue = Blueprint.ofType typeof + match blue.Cardinality with + | One { Shape = Composite barMap } -> + match barMap.Columns.["ParentFoo"].ReverseRelationship.Value with + | None -> failwith "No reverse relationship for ParentFoo" + | Some childBars -> + Assert.IsTrue("ChildBars".Equals(childBars.Name, StringComparison.OrdinalIgnoreCase)) + | _ -> failwith "Wrong cardinality/shape" \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestCompositeReaders.fs b/Rezoom.SQL.Test/TestCompositeReaders.fs new file mode 100644 index 0000000..fef616c --- /dev/null +++ b/Rezoom.SQL.Test/TestCompositeReaders.fs @@ -0,0 +1,208 @@ +module Rezoom.SQL.Test.CompositeReaders +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.CodeGeneration +open System +open System.ComponentModel.DataAnnotations + +type User = + { + UserId : int + Name : string + } + +type Folder = + { + FolderId : int + Children : Folder array + } + +type Person = + { + PersonId : int + Name : string + Parent : Person + } + +type CompositeKeyType = + { + [] + FooId : int + [] + BarId : int + MapName : string + } + +type Employee = + { Employer : Person option + EmployeeId : int + Name : string + } + +[] +let ``read nothing`` () = + let colMap = + [| + "UserId", ColumnType.Int32 + "Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + let users = reader.ToEntity() + Assert.AreEqual([||], users) + +[] +let ``read user`` () = + let colMap = + [| + "UserId", ColumnType.Int32 + "Name", ColumnType.String + |] |> ColumnMap.Parse + let row = ObjectRow(1, "jim") + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(row) + let user = reader.ToEntity() + Assert.IsNotNull(user) + Assert.AreEqual(1, user.UserId) + Assert.AreEqual("jim", user.Name) + +[] +let ``read many users`` () = + let colMap = + [| + "UserId", ColumnType.Int32 + "Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(ObjectRow(1, "jim")) + reader.Read(ObjectRow(1, "jim")) + reader.Read(ObjectRow(2, "jerry")) + let users = reader.ToEntity() + Assert.AreEqual( + [ + { UserId = 1; Name = "jim" } + { UserId = 2; Name = "jerry" } + ], + users) + +[] +let ``read employee (optional nav)`` () = + let colMap = + [| + "EmployeeId", ColumnType.Int32 + "Name", ColumnType.String + "Employer$PersonId", ColumnType.Int32 + "Employer$Name", ColumnType.String + |] |> ColumnMap.Parse + let row = ObjectRow(1, "jim", 2, "michael") + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(row) + let jim = reader.ToEntity() + Assert.IsNotNull(jim) + Assert.AreEqual(1, jim.EmployeeId) + Assert.AreEqual("jim", jim.Name) + match jim.Employer with + | None -> failwith "shouldn't be None" + | Some michael -> + Assert.AreEqual(2, michael.PersonId) + Assert.AreEqual("michael", michael.Name) + +[] +let ``read folder 1 level deep`` () = + let colMap = + [| + "FolderId", ColumnType.Int32 + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(ObjectRow(1)) + let folder = reader.ToEntity() + Assert.IsNotNull(folder) + Assert.AreEqual(1, folder.FolderId) + Assert.AreEqual(0, folder.Children.Length) + +[] +let ``read folder 2 levels deep`` () = + let colMap = + [| + "FolderId", ColumnType.Int32 + "Children.FolderId", ColumnType.Int32 + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(ObjectRow(1, 2)) + reader.Read(ObjectRow(1, 3)) + let folder = reader.ToEntity() + Assert.IsNotNull(folder) + Assert.AreEqual(1, folder.FolderId) + Assert.AreEqual(2, folder.Children.Length) + Assert.AreEqual(2, folder.Children.[0].FolderId) + Assert.AreEqual(3, folder.Children.[1].FolderId) + Assert.AreEqual(0, folder.Children.[0].Children.Length) + Assert.AreEqual(0, folder.Children.[1].Children.Length) + +[] +let ``read person 1 level deep`` () = + let colMap = + [| + "PersonId", ColumnType.Int32 + "Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(ObjectRow(1, "ben")) + let person = reader.ToEntity() + Assert.IsNotNull(person) + Assert.AreEqual(1, person.PersonId) + Assert.AreEqual("ben", person.Name) + Assert.IsNull(person.Parent) + +[] +let ``read person 2 levels deep`` () = + let colMap = + [| + "PersonId", ColumnType.Int32 + "Name", ColumnType.String + "Parent.PersonId", ColumnType.Int32 + "Parent.Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(ObjectRow(1, "ben", 2, "pat")) + let person = reader.ToEntity() + Assert.IsNotNull(person) + Assert.AreEqual(1, person.PersonId) + Assert.AreEqual("ben", person.Name) + Assert.IsNotNull(person.Parent) + Assert.AreEqual(2, person.Parent.PersonId) + Assert.AreEqual("pat", person.Parent.Name) + +[] +let ``read objects with composite keys`` () = + let colMap = + [| + "FooId", ColumnType.Int32 + "BarId", ColumnType.Int32 + "MapName", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(ObjectRow(1, 1, "a")) + reader.Read(ObjectRow(1, 1, "b")) // should be ignored + reader.Read(ObjectRow(1, 2, "c")) + reader.Read(ObjectRow(2, 1, "d")) + reader.Read(ObjectRow(2, 2, "e")) + let composites = reader.ToEntity() + Assert.AreEqual + ([ + { FooId = 1; BarId = 1; MapName = "a" } + { FooId = 1; BarId = 2; MapName = "c" } + { FooId = 2; BarId = 1; MapName = "d" } + { FooId = 2; BarId = 2; MapName = "e" } + ], composites) + + \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestManyPrimitives.fs b/Rezoom.SQL.Test/TestManyPrimitives.fs new file mode 100644 index 0000000..595f7c5 --- /dev/null +++ b/Rezoom.SQL.Test/TestManyPrimitives.fs @@ -0,0 +1,102 @@ +module Rezoom.SQL.Test.ManyPrimitives +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.CodeGeneration +open System +open System.Collections.Generic + +type Friend = + { Id : int + Name : string + Aliases : string array + } + +[] +let ``read friend`` () = + let colMap = + [| "Id", ColumnType.Int32 + "Name", ColumnType.String + "Aliases", ColumnType.String + |] |> ColumnMap.Parse + let rows = + [ ObjectRow(3, "Robert", "Bob") + ObjectRow(3, "Robert", "Bobby") + ObjectRow(3, "Robert", "Rob") + ObjectRow(3, "Robert", "Robby") + ] + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + for row in rows do + reader.Read(row) + let friend = reader.ToEntity() + Assert.IsNotNull(friend) + Assert.AreEqual(3, friend.Id) + Assert.AreEqual("Robert", friend.Name) + Assert.AreEqual(4, friend.Aliases.Length) + Assert.IsTrue([| "Bob"; "Bobby"; "Rob"; "Robby" |] = friend.Aliases) + + +type StringPair = // notice no key properties + { Left : string + Right : string + } + +[] +let ``read string pairs`` () = + let colMap = + [| "Left", ColumnType.String + "Right", ColumnType.String + |] |> ColumnMap.Parse + let rows = + [ ObjectRow("a", "1") + ObjectRow("b", "2") + ObjectRow("b", "2") // duplicate should appear in results + ObjectRow("a", "1") + ] + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + for row in rows do + reader.Read(row) + let pairs = reader.ToEntity() + Assert.AreEqual + ( [| { Left = "a"; Right = "1" } + { Left = "b"; Right = "2" } + { Left = "b"; Right = "2" } + { Left = "a"; Right = "1" } + |] + , pairs + ) + +[] +type IgnoredIds = + { [] + Le : string + Ri : string + } + +[] +let ``ignored ids`` () = + let colMap = + [| "Le", ColumnType.String + "Ri", ColumnType.String + |] |> ColumnMap.Parse + let rows = + [ ObjectRow("a", "1") + ObjectRow("b", "2") + ObjectRow("b", "2") // duplicate should appear in results + ObjectRow("a", "1") + ] + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + for row in rows do + reader.Read(row) + let pairs = reader.ToEntity() + Assert.AreEqual + ( [| { Le = "a"; Ri = "1" } + { Le = "b"; Ri = "2" } + { Le = "b"; Ri = "2" } + { Le = "a"; Ri = "1" } + |] + , pairs |> Array.ofSeq + ) \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestModel.fs b/Rezoom.SQL.Test/TestModel.fs new file mode 100644 index 0000000..50f8a4b --- /dev/null +++ b/Rezoom.SQL.Test/TestModel.fs @@ -0,0 +1,10 @@ +module Rezoom.SQL.Test.TestModel +open NUnit.Framework +open FsUnit +open Rezoom.SQL + +[] +let ``model 2 loads`` () = + let model = userModel2() + let schema = model.Model.Schemas.[model.Model.DefaultSchema] + Assert.AreEqual(4, schema.Objects.Count) \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestNavProperties.fs b/Rezoom.SQL.Test/TestNavProperties.fs new file mode 100644 index 0000000..591c368 --- /dev/null +++ b/Rezoom.SQL.Test/TestNavProperties.fs @@ -0,0 +1,109 @@ +module Rezoom.SQL.Test.TestNavProperties +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping + +let columns (sql : string) expected = + let userModel = userModel1() + let parsed = CommandEffect.OfSQL(userModel.Model, "anonymous", sql) + let sets = parsed.ResultSets() |> Seq.toArray + if sets.Length <> 1 then failwith "expected 1 result set" + let cols = sets.[0].Columns |> Seq.map (fun c -> c.ColumnName.Value, c.Expr.Info.Type.ToString()) |> Seq.toList + printfn "%A" cols + Assert.AreEqual(expected, cols) + +[] +let ``1 user many groups`` () = + columns + """ + select u.Id, u.Name, many Groups(g.Id, g.Name) + from Users u + join UserGroupMaps ugm on ugm.UserId = u.Id + join Groups g on g.Id = ugm.GroupId + """ + [ "Id", "INT" + "Name", "STRING?" + "Groups*$Id", "INT" + "Groups*$Name", "STRING?" + ] + +[] +let ``1 user 1 group`` () = + columns + """ + select u.Id, u.Name, one Group(g.Id, g.Name) + from Users u + join UserGroupMaps ugm on ugm.UserId = u.Id + join Groups g on g.Id = ugm.GroupId + """ + [ "Id", "INT" + "Name", "STRING?" + "Group$Id", "INT" + "Group$Name", "STRING?" + ] + +[] +let ``1 user many groups left join no nav`` () = + columns + """ + select u.Id, u.Name, g.Id as GroupId, g.Name as GroupName + from Users u + left join UserGroupMaps ugm on ugm.UserId = u.Id + left join Groups g on g.Id = ugm.GroupId + """ + [ "Id", "INT" + "Name", "STRING?" + "GroupId", "INT?" + "GroupName", "STRING?" + ] + +[] +let ``1 user many groups left join nav`` () = + columns + """ + select u.Id, u.Name, many Groups(g.Id, g.Name) + from Users u + left join UserGroupMaps ugm on ugm.UserId = u.Id + left join Groups g on g.Id = ugm.GroupId + """ + [ "Id", "INT" + "Name", "STRING?" + "Groups*$Id", "INT" + "Groups*$Name", "STRING?" + ] + +[] +let ``1 user many maps many groups left join nav`` () = + columns + """ + select u.Id, u.Name, many Maps(ugm.UserId, ugm.GroupId, one Group(g.Id, g.Name)) + from Users u + left join UserGroupMaps ugm on ugm.UserId = u.Id + left join Groups g on g.Id = ugm.GroupId + """ + [ "Id", "INT" + "Name", "STRING?" + "Maps*$UserId", "INT" + "Maps*$GroupId", "INT" + "Maps*$Group$Id", "INT" + "Maps*$Group$Name", "STRING?" + ] + +[] +let ``1 user many maps many groups left join nav/nonav`` () = + columns + """ + select u.Id, u.Name, many Maps(ugm.UserId, ugm.GroupId, g.Id as GroupGroupId, g.Name as GroupGroupName) + from Users u + left join UserGroupMaps ugm on ugm.UserId = u.Id + left join Groups g on g.Id = ugm.GroupId + """ + [ "Id", "INT" + "Name", "STRING?" + "Maps*$UserId", "INT" + "Maps*$GroupId", "INT" + "Maps*$GroupGroupId", "INT?" // note that this is back to nullable now + "Maps*$GroupGroupName", "STRING?" + ] \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestNullInference.fs b/Rezoom.SQL.Test/TestNullInference.fs new file mode 100644 index 0000000..62e831e --- /dev/null +++ b/Rezoom.SQL.Test/TestNullInference.fs @@ -0,0 +1,153 @@ +module Rezoom.SQL.Test.TestNullInference +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping + +let expect (sql : string) expectedColumns expectedParams = + let userModel = userModel1() + let parsed = CommandEffect.OfSQL(userModel.Model, "anonymous", sql) + let sets = parsed.ResultSets() |> Seq.toArray + if sets.Length <> 1 then failwith "expected 1 result set" + let cols = sets.[0].Columns |> Seq.map (fun c -> c.ColumnName.Value, c.Expr.Info.Type.ToString()) |> Seq.toList + printfn "%A" cols + Assert.AreEqual(expectedColumns, cols) + let pars = + parsed.Parameters + |> Seq.map (fun (NamedParameter name, ty) -> name.Value, ty.ToString()) + |> Seq.toList + printfn "%A" pars + Assert.AreEqual(expectedParams, pars) + +[] +let ``coalesce forces all but last arg nullable`` () = + expect + """ + select coalesce(@a, @b, @c, @d) as c + """ + [ "c", "" + ] + [ for p in "abc" -> (string p, "?") + yield "d", "" + ] + +[] +let ``coalesce(a + b, 1)`` () = + expect + """ + select coalesce(@a + @b, 1) as c + """ + [ "c", "" + ] + [ "a", "?" + "b", "?" + ] + +[] +let ``coalesce(a + b, null)`` () = + expect + """ + select coalesce(@a + @b, null) as c + """ + [ "c", "?" + ] + [ "a", "?" + "b", "?" + ] + +[] +let ``coalesce(nullable(a) + b, 1)`` () = + expect + """ + select coalesce(nullable(@a) + @b, 1) as c + """ + [ "c", "" + ] + [ "a", "?" + "b", "" + ] + +[] +let ``coalesce(a + nullable(b), 1)`` () = + expect + """ + select coalesce(@a + nullable(@b), 1) as c + """ + [ "c", "" + ] + [ "a", "" + "b", "?" + ] + +[] +let ``case nullable`` () = + expect + """ + select case when 1=1 then 1 else null end as c + """ + [ "c", "?" + ] [] + +[] +let ``case not nullable`` () = + expect + """ + select case when null then 1 else 0 end as c + """ + [ "c", "" + ] [] + +[] +let ``case not handled means null`` () = + expect + """ + select case when 1=0 then 1 end as c + """ + [ + "c", "?" + ] [] + +[] +let ``insert into nullable column with parameter should be nullable`` () = + expect + """ + insert into Users(Id, Name) + values (@x, @y); + select 0 as ignore; + """ + [ "ignore", "" + ] + [ "x", "INT" + "y", "STRING?" + ] + +[] +let ``insert into nullable column with parameter from select should be nullable`` () = + expect + """ + insert into Users(Id, Name) + select @x, @y; + select 0 as ignore; + """ + [ "ignore", "" + ] + [ "x", "INT" + "y", "STRING?" + ] + +[] +let ``update into nullable column with parameter should be nullable`` () = + expect + """ + update Users + set Id = @x + , Name = @y + where true; + select 0 as ignore; + """ + [ "ignore", "" + ] + [ "x", "INT" + "y", "STRING?" + ] \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestParserErrors.fs b/Rezoom.SQL.Test/TestParserErrors.fs new file mode 100644 index 0000000..b9c0bea --- /dev/null +++ b/Rezoom.SQL.Test/TestParserErrors.fs @@ -0,0 +1,16 @@ +module Rezoom.SQL.Test.TestParserErrors +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL +open Rezoom.SQL.Mapping + +[] +let ``invalid CTE`` () = + expectError "SQ000: Expecting: whitespace, ')', '--' or '/*'" + """ + with cte ( as + select * from Users u + ) + select * from cte + """ \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestPrimitiveReaders.fs b/Rezoom.SQL.Test/TestPrimitiveReaders.fs new file mode 100644 index 0000000..091fa97 --- /dev/null +++ b/Rezoom.SQL.Test/TestPrimitiveReaders.fs @@ -0,0 +1,191 @@ +module Rezoom.SQL.Test.PrimitiveReaders +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.CodeGeneration +open System + +type Enum16 = + | One16 = 1s + | Two16 = 32767s + +type Enum32U = + | One32U = 1u + | Two32U = 4294967295u + +type Enum64 = + | One64 = 1L + | Two64 = 9223372036854775807L + +type Enum64U = + | One64U = 1UL + | Two64U = 18446744073709551615UL + +let testXCore inRow (expected : 'a) ctype = + let colMap = + [| + Guid.NewGuid().ToString("N"), ctype + |] |> ColumnMap.Parse + let row = ObjectRow([| box <| inRow expected |]) + let reader = ReaderTemplate<'a>.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(row) + let mat = reader.ToEntity() + Assert.AreEqual(expected, mat) +let testCore expected ctype = testXCore id expected ctype +let testRef (expected : 'a) ctype = + testCore expected ctype + testCore expected ColumnType.Object + testCore (null : 'a) ctype + testCore (null : 'a) ColumnType.Object + testCore (None : 'a option) ColumnType.Object + testXCore Option.get (Some expected) ColumnType.Object +let test (expected : 'a) ctype = + testCore expected ctype + testCore expected ColumnType.Object + testCore (Nullable<'a>(expected)) ctype + testCore (Nullable<'a>(expected)) ColumnType.Object + testCore (Nullable<'a>()) ctype + testCore (Nullable<'a>()) ColumnType.Object + testCore (None : 'a option) ColumnType.Object + testXCore Option.get (Some expected) ColumnType.Object + +[] +let ``read string`` () = + testRef "thirteen" ColumnType.String + +[] +let ``read byte array`` () = + testRef [|0uy;1uy;2uy;3uy|] ColumnType.Object + +[] +let ``read int32`` () = + test 13 ColumnType.Int32 +[] +let ``read int16`` () = + test 13s ColumnType.Int16 +[] +let ``read int64`` () = + test 13L ColumnType.Int64 + +[] +let ``read uint32`` () = + test 13u ColumnType.UInt32 +[] +let ``read uint16`` () = + test 13us ColumnType.UInt16 +[] +let ``read uint64`` () = + test 13UL ColumnType.UInt64 + +[] +let ``read byte`` () = + test 13uy ColumnType.Byte +[] +let ``read sbyte`` () = + test 13y ColumnType.SByte + +[] +let ``read single`` () = + test 13.5f ColumnType.Single + +[] +let ``read double`` () = + test 13.5 ColumnType.Double + +[] +let ``read decimal`` () = + test 13.5m ColumnType.Decimal + +[] +let ``read DateTime`` () = + test DateTime.UtcNow ColumnType.DateTime + +[] +let ``read DateTimeKind enum (via TryParser)`` () = + let mutable e = DateTimeKind.Unspecified + let succ = PrimitiveConverters.EnumTryParser.TryParse("Local", &e) + Assert.IsTrue(succ) + Assert.AreEqual(DateTimeKind.Local, e) + let succ = PrimitiveConverters.EnumTryParser.TryParse("Test", &e) + Assert.IsFalse(succ) + +[] +let ``read Enum16 (via TryParser)`` () = + let mutable e = Enum16.One16 + let succ = PrimitiveConverters.EnumTryParser.TryParse("Two16", &e) + Assert.IsTrue(succ) + Assert.AreEqual(Enum16.Two16, e) + +[] +let ``read Enum32U (via TryParser)`` () = + let mutable e = Enum32U.One32U + let succ = PrimitiveConverters.EnumTryParser.TryParse("Two32U", &e) + Assert.IsTrue(succ) + Assert.AreEqual(Enum32U.Two32U, e) + +[] +let ``read Enum64 (via TryParser)`` () = + let mutable e = Enum64.One64 + let succ = PrimitiveConverters.EnumTryParser.TryParse("Two64", &e) + Assert.IsTrue(succ) + Assert.AreEqual(Enum64.Two64, e) + +[] +let ``read Enum64U (via TryParser)`` () = + let mutable e = Enum64U.One64U + let succ = PrimitiveConverters.EnumTryParser.TryParse("Two64U", &e) + Assert.IsTrue(succ) + Assert.AreEqual(Enum64U.Two64U, e) + +[] +let ``read DateTimeKind`` () = + test DateTimeKind.Local ColumnType.Int32 + +[] +let ``read enums from string``() = + let happy (expected : 'a) (str : string) = + let colMap = + [| + Guid.NewGuid().ToString("N"), ColumnType.String + |] |> ColumnMap.Parse + let row = ObjectRow(str :> obj) + let reader = ReaderTemplate<'a>.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(row) + let mat = reader.ToEntity() + Assert.AreEqual(expected :> obj, mat :> obj) + + let reader = ReaderTemplate<'a Nullable>.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(row) + let mat = reader.ToEntity() + Assert.AreEqual(expected :> obj, mat :> obj) + + let reader = ReaderTemplate<'a Nullable>.Template().CreateReader() + reader.ProcessColumns(colMap) + reader.Read(ObjectRow(null : obj)) + let mat = reader.ToEntity() + Assert.IsNull(mat) + let sad (example : 'a) (str : string) = + let colMap = + [| + Guid.NewGuid().ToString("N"), ColumnType.String + |] |> ColumnMap.Parse + let row = ObjectRow(str :> obj) + let reader = ReaderTemplate<'a>.Template().CreateReader() + reader.ProcessColumns(colMap) + Assert.IsTrue( + try + reader.Read(row) + ignore <| reader.ToEntity() + false + with + | exn -> true) + + happy DateTimeKind.Local "Local" + happy DateTimeKind.Utc "Utc" + happy StringComparison.InvariantCultureIgnoreCase "InvariantCultureIgnoreCase" + happy Enum64U.Two64U "Two64U" + sad DateTimeKind.Unspecified "Something" + sad StringComparison.CurrentCulture "" \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestQueryParents.fs b/Rezoom.SQL.Test/TestQueryParents.fs new file mode 100644 index 0000000..08a7c7e --- /dev/null +++ b/Rezoom.SQL.Test/TestQueryParents.fs @@ -0,0 +1,231 @@ +module Rezoom.SQL.Test.QueryParents +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Mapping +open Rezoom.SQL.Mapping.CodeGeneration + +type RecordFolder = + { + Id : int + Name : string + ChildFolders : RecordFolder array + ParentFolder : RecordFolder + } + +[] +type ClassFolder() = + member val Id = 0 with get, set + member val Name = "" with get, set + member val ChildFolders = null : ClassFolder array with get, set + member val ParentFolder = null : ClassFolder with get, set + +[] +[] +let ``self-referential record equality stack overflows`` () = + let colMap = + [| + "Id", ColumnType.Int32 + "Name", ColumnType.String + "ChildFolders$Id", ColumnType.Int32 + "ChildFolders$Name", ColumnType.String + "ChildFolders$ChildFolders$Id", ColumnType.Int32 + "ChildFolders$ChildFolders$Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + for objectRow in + [| + ObjectRow(1, "A", 2, "A.1", 7, "A.1.1") + ObjectRow(1, "A", 3, "A.2", 8, "A.2.1") + ObjectRow(4, "B", 5, "B.1", 9, "B.1.1") + ObjectRow(4, "B", 6, "B.2", 10, "B.2.1") + |] do reader.Read(objectRow) + let folders1 = reader.ToEntity() + + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + for objectRow in + [| + ObjectRow(1, "A", 2, "A.1", 7, "A.1.1") + ObjectRow(1, "A", 3, "A.2", 8, "A.2.1") + ObjectRow(4, "B", 5, "B.1", 9, "B.1.1") + ObjectRow(4, "B", 6, "B.2", 10, "B.2.1") + |] do reader.Read(objectRow) + let folders2 = reader.ToEntity() + let bottom1 = folders1.Head.ChildFolders.[0].ChildFolders.[0] + let bottom2 = folders2.Head.ChildFolders.[0].ChildFolders.[0] + Assert.AreEqual(bottom1, bottom2) + // This will stack overflow, because the equality comparison goes: + // are we equal? --> + // ^ are our Ids equal? yes. + // | are our Names equal? yes. + // ^ are our Children equal? yes. + // | are our Parents equal? --> + // ^ are their Ids equal? yes. + // | are their Names equal? yes. + // ^ are their Children equal? -->+ + // | | + // +-<--<--<--<--<--<--<--<--<--<--<--<--+ + + // There is nothing we can do about this other than advise against using records this way. + // This test serves as a handy way of demonstrating the problem. + +[] +let ``read record folders with parent backreferences`` () = + let colMap = + [| + "Id", ColumnType.Int32 + "Name", ColumnType.String + "ChildFolders$Id", ColumnType.Int32 + "ChildFolders$Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + for objectRow in + [| + ObjectRow(1, "A", 2, "A.1") + ObjectRow(1, "A", 3, "A.2") + ObjectRow(4, "B", 5, "B.1") + ObjectRow(4, "B", 6, "B.2") + |] do reader.Read(objectRow) + let folders = reader.ToEntity() + Assert.IsNotNull(folders) + Assert.AreEqual(2, folders.Length) + + Assert.IsNotNull(folders.[0]) + Assert.IsNull(folders.[0].ParentFolder) + Assert.AreEqual(1, folders.[0].Id) + Assert.AreEqual("A", folders.[0].Name) + Assert.IsNotNull(folders.[0].ChildFolders) + Assert.AreEqual(2, folders.[0].ChildFolders.Length) + + Assert.IsNotNull(folders.[0].ChildFolders.[0]) + Assert.AreEqual(2, folders.[0].ChildFolders.[0].Id) + Assert.AreEqual("A.1", folders.[0].ChildFolders.[0].Name) + Assert.AreEqual(0, folders.[0].ChildFolders.[0].ChildFolders.Length) + Assert.IsNotNull(folders.[0].ChildFolders.[0].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[0], folders.[0].ChildFolders.[0].ParentFolder)) + + Assert.IsNotNull(folders.[0].ChildFolders.[1]) + Assert.AreEqual(3, folders.[0].ChildFolders.[1].Id) + Assert.AreEqual("A.2", folders.[0].ChildFolders.[1].Name) + Assert.AreEqual(0, folders.[0].ChildFolders.[1].ChildFolders.Length) + Assert.IsNotNull(folders.[0].ChildFolders.[1].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[0], folders.[0].ChildFolders.[1].ParentFolder)) + + Assert.IsNotNull(folders.[1]) + Assert.IsNull(folders.[1].ParentFolder) + Assert.AreEqual(4, folders.[1].Id) + Assert.AreEqual("B", folders.[1].Name) + Assert.IsNotNull(folders.[1].ChildFolders) + Assert.AreEqual(2, folders.[1].ChildFolders.Length) + + Assert.IsNotNull(folders.[1].ChildFolders.[0]) + Assert.AreEqual(5, folders.[1].ChildFolders.[0].Id) + Assert.AreEqual("B.1", folders.[1].ChildFolders.[0].Name) + Assert.AreEqual(0, folders.[1].ChildFolders.[0].ChildFolders.Length) + Assert.IsNotNull(folders.[1].ChildFolders.[0].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[1], folders.[1].ChildFolders.[0].ParentFolder)) + + Assert.IsNotNull(folders.[1].ChildFolders.[1]) + Assert.AreEqual(6, folders.[1].ChildFolders.[1].Id) + Assert.AreEqual("B.2", folders.[1].ChildFolders.[1].Name) + Assert.AreEqual(0, folders.[1].ChildFolders.[1].ChildFolders.Length) + Assert.IsNotNull(folders.[1].ChildFolders.[1].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[1], folders.[1].ChildFolders.[1].ParentFolder)) + +[] +let ``read class folders with parent backreferences`` () = + let colMap = + [| + "Id", ColumnType.Int32 + "Name", ColumnType.String + "ChildFolders$Id", ColumnType.Int32 + "ChildFolders$Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + let next = ReaderTemplate.Template().CreateReader() + reader.ImpartKnowledgeToNext(next) + for objectRow in + [| + ObjectRow(1, "A", 2, "A.1") + ObjectRow(1, "A", 3, "A.2") + ObjectRow(4, "B", 5, "B.1") + ObjectRow(4, "B", 6, "B.2") + |] do reader.Read(objectRow) + let folders = reader.ToEntity() + Assert.IsNotNull(folders) + Assert.AreEqual(2, folders.Length) + + Assert.IsNotNull(folders.[0]) + Assert.IsNull(folders.[0].ParentFolder) + Assert.AreEqual(1, folders.[0].Id) + Assert.AreEqual("A", folders.[0].Name) + Assert.IsNotNull(folders.[0].ChildFolders) + Assert.AreEqual(2, folders.[0].ChildFolders.Length) + + Assert.IsNotNull(folders.[0].ChildFolders.[0]) + Assert.AreEqual(2, folders.[0].ChildFolders.[0].Id) + Assert.AreEqual("A.1", folders.[0].ChildFolders.[0].Name) + Assert.AreEqual(0, folders.[0].ChildFolders.[0].ChildFolders.Length) + Assert.IsNotNull(folders.[0].ChildFolders.[0].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[0], folders.[0].ChildFolders.[0].ParentFolder)) + + Assert.IsNotNull(folders.[0].ChildFolders.[1]) + Assert.AreEqual(3, folders.[0].ChildFolders.[1].Id) + Assert.AreEqual("A.2", folders.[0].ChildFolders.[1].Name) + Assert.AreEqual(0, folders.[0].ChildFolders.[1].ChildFolders.Length) + Assert.IsNotNull(folders.[0].ChildFolders.[1].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[0], folders.[0].ChildFolders.[1].ParentFolder)) + + Assert.IsNotNull(folders.[1]) + Assert.IsNull(folders.[1].ParentFolder) + Assert.AreEqual(4, folders.[1].Id) + Assert.AreEqual("B", folders.[1].Name) + Assert.IsNotNull(folders.[1].ChildFolders) + Assert.AreEqual(2, folders.[1].ChildFolders.Length) + + Assert.IsNotNull(folders.[1].ChildFolders.[0]) + Assert.AreEqual(5, folders.[1].ChildFolders.[0].Id) + Assert.AreEqual("B.1", folders.[1].ChildFolders.[0].Name) + Assert.AreEqual(0, folders.[1].ChildFolders.[0].ChildFolders.Length) + Assert.IsNotNull(folders.[1].ChildFolders.[0].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[1], folders.[1].ChildFolders.[0].ParentFolder)) + + Assert.IsNotNull(folders.[1].ChildFolders.[1]) + Assert.AreEqual(6, folders.[1].ChildFolders.[1].Id) + Assert.AreEqual("B.2", folders.[1].ChildFolders.[1].Name) + Assert.AreEqual(0, folders.[1].ChildFolders.[1].ChildFolders.Length) + Assert.IsNotNull(folders.[1].ChildFolders.[1].ParentFolder) + Assert.IsTrue(obj.ReferenceEquals(folders.[1], folders.[1].ChildFolders.[1].ParentFolder)) + +[] +let ``record folder joined parents shouldn't have children populated`` () = + let colMap = + [| + "Id", ColumnType.Int32 + "Name", ColumnType.String + "ParentFolder$Id", ColumnType.Int32 + "ParentFolder$Name", ColumnType.String + |] |> ColumnMap.Parse + let reader = ReaderTemplate.Template().CreateReader() + reader.ProcessColumns(colMap) + for objectRow in + [| + ObjectRow(2, "A.1", 1, "A") + ObjectRow(3, "A.2", 1, "A") + ObjectRow(5, "B.1", 4, "B") + ObjectRow(6, "B.2", 4, "B") + |] do reader.Read(objectRow) + let folders = reader.ToEntity() + Assert.IsNotNull(folders) + Assert.AreEqual(4, folders.Length) + Assert.AreEqual(2, folders.[0].Id) + Assert.AreEqual("A.1", folders.[0].Name) + Assert.AreEqual(0, folders.[0].ChildFolders.Length) + Assert.IsNotNull(folders.[0].ParentFolder) + Assert.AreEqual(1, folders.[0].ParentFolder.Id) + Assert.AreEqual("A", folders.[0].ParentFolder.Name) + Assert.IsNull(folders.[0].ParentFolder.ParentFolder) + Assert.AreEqual(0, folders.[0].ParentFolder.ChildFolders.Length) \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestRoundTrip.fs b/Rezoom.SQL.Test/TestRoundTrip.fs new file mode 100644 index 0000000..2504281 --- /dev/null +++ b/Rezoom.SQL.Test/TestRoundTrip.fs @@ -0,0 +1,201 @@ +module Rezoom.SQL.Test.TestRoundTrip +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping + +let roundtrip (sql : string) = + let userModel = userModel1() + let parsed = CommandEffect.OfSQL(userModel.Model, "anonymous", sql) + let indexer = { new IParameterIndexer with member __.ParameterIndex(_) = 0 } + let backend = DefaultBackend() :> IBackend + let fragments = backend.ToCommandFragments(indexer, parsed.Statements) + let str = CommandFragment.Stringize(fragments) + Console.WriteLine(str) + let parsedBack = CommandEffect.OfSQL(userModel.Model, "readback", str) + let fragmentsBack = backend.ToCommandFragments(indexer, parsedBack.Statements) + let strBack = CommandFragment.Stringize(fragmentsBack) + Console.WriteLine(String('-', 80)) + Console.WriteLine(strBack) + Assert.AreEqual(str, strBack) + +[] +let ``select`` () = + roundtrip """ + select * from Users u where u.Id = 1 + """ +[] +let ``fancy select`` () = + roundtrip """ + select g.*, u.* + from Users u + left join UserGroupMaps gm on gm.UserId = u.Id + left join Groups g on g.Id = gm.GroupId + where g.Name like '%grp%' escape '%' + """ + +[] +let ``insert`` () = + roundtrip """ + insert into Users(id, name) + values (0, 'ben') + """ + +[] +let ``insert from select`` () = + roundtrip """ + insert into Groups + select * from Groups + """ + +[] +let ``delete`` () = + roundtrip """ + delete from Users where Email like '%earthlink.net' + """ + +[] +let ``drop`` () = + roundtrip """ + drop table main.Users + """ + +[] +let ``create table with column list and fk`` () = + roundtrip """ + create table Foo + ( bar int primary key + , baz float32 + , foreign key (bar, baz) references Users(Email, Name) + ); + """ + +[] +let ``alter table add column`` () = + roundtrip """ + alter table UserGroupMaps + add Tag int null + """ + +[] +let ``alter table rename to`` () = + roundtrip """ + alter table UserGroupMaps rename to UserGroupAssociations + """ + +[] +let ``create temp view`` () = + roundtrip """ + create temp view CoolUsers as select * from Users where name not like '%szany%' + """ + +[] +let ``create temp view with column names`` () = + roundtrip """ + create temp view CoolUsers(id, name) as select 1, '' from users where name not like '%szany%' + """ + +[] +let ``create temp view and select from it`` () = + roundtrip """ + create temp view CoolUsers(id, name) as select 1, '' from users where name not like '%szany%'; + select * from CoolUsers; + """ + +[] +let ``create table with composite PK`` () = + roundtrip """ + create table Maps(UserId int, GroupId int, primary key(UserId, GroupId)) + """ + +[] +let ``many nav property`` () = + roundtrip """ + select u.*, many Groups(g.*) + from Users u + left join UserGroupMaps gm on gm.UserId = u.Id + left join Groups g on g.Id = gm.GroupId + """ + +[] +let ``one nav property`` () = + roundtrip """ + select u.*, one Group(g.*) + from Users u + left join UserGroupMaps gm on gm.UserId = u.Id + left join Groups g on g.Id = gm.GroupId + """ + +[] +let ``date literals`` () = + roundtrip """ + select * + from Users u + where 2016-10-16 > 2015-01-01 + """ + roundtrip """ + select * + from Users u + where 2016-10-16T04:30:31 > 2016-10-16T18:14:19.123 + """ + roundtrip """ + select * + from Users u + where 2016-10-16T04:30:31+01:30 > 2016-10-16T18:14:19.123-04:00 + """ + +[] +let ``join subqueries`` () = + roundtrip """ + select * from + (select u.Id from Users u) us + join + (select g.Id from Groups g) gs + on us.Id = gs.Id + """ + +[] +let ``simple CTE`` () = + roundtrip """ + with + a(x, y) as + ( select Id, 1 from Users ) + select * from a; + """ + +[] +let ``recursive CTE`` () = + roundtrip """ + with recursive + nums(x) as ( + select 1 + union all + select x+1 from nums + limit 1000000 + ) + select x from nums; + """ + +[] +let ``recursive CTE with implicit column names`` () = + roundtrip """ + with recursive + nums as ( + select 1 as myname + union all + select myname+1 from nums + limit 1000000 + ) + select myname from nums; + """ + +[] +let ``table with self-referential constraints`` () = + roundtrip """ + create table Folders + ( Id int primary key autoincrement + , Name string(80) + , ParentId int references Folders(Id) + ); + """ \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestTSQL.fs b/Rezoom.SQL.Test/TestTSQL.fs new file mode 100644 index 0000000..5824547 --- /dev/null +++ b/Rezoom.SQL.Test/TestTSQL.fs @@ -0,0 +1,57 @@ +module Rezoom.SQL.Test.TestTSQL +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping + +let translate (sql : string) (expectedTSQL : string) = + let userModel = + let userModel = userModel1() + let backend = TSQL.TSQLBackend() :> IBackend + { userModel with + Backend = backend + Model = { userModel.Model with Builtin = backend.InitialModel.Builtin } + } + let parsed = CommandEffect.OfSQL(userModel.Model, "anonymous", sql) + let indexer = { new IParameterIndexer with member __.ParameterIndex(_) = 0 } + let fragments = userModel.Backend.ToCommandFragments(indexer, parsed.Statements) + let str = CommandFragment.Stringize(fragments) + Console.WriteLine(str) + Assert.AreEqual(expectedTSQL, str) + +[] +let ``at at proc translation`` () = + translate + """select datefirst() as d""" + """SELECT @@DATEFIRST AS [d];""" + +[] +let ``datepart translation`` () = + translate + """select dateadd('day', 1, sysutcdatetime()) d""" + """SELECT dateadd(day,1,sysutcdatetime()) AS [d];""" + +[] +let ``bool to first class`` ()= + translate + """select 1 < 0 as b""" + """SELECT CAST((CASE WHEN (1 < 0) THEN 1 ELSE 0 END) AS BIT) AS [b];""" + +[] +let ``first class to bool`` ()= + translate + """select 1 as col from Users where true""" + """SELECT 1 AS [col] FROM [Users] WHERE ((1)<>0);""" + +[] +let ``iif with predicate`` ()= + translate + """select IIF(1 > 0, 'a', 'b') as choice""" + """SELECT IIF((1 > 0),'a','b') AS [choice];""" + +[] +let ``iif with first class value`` () = + translate + """select IIF(false, 'a', 'b') as choice""" + """SELECT IIF(((0)<>0),'a','b') AS [choice];""" diff --git a/Rezoom.SQL.Test/TestTypeErrors.fs b/Rezoom.SQL.Test/TestTypeErrors.fs new file mode 100644 index 0000000..6388cfd --- /dev/null +++ b/Rezoom.SQL.Test/TestTypeErrors.fs @@ -0,0 +1,73 @@ +module Rezoom.SQL.Test.TestTypeErrors +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping + +[] +let ``incompatible types can't be compared for equality`` () = + expectError (Error.cannotUnify "INT" "STRING") + """ + select g.*, u.* + from Users u + left join UserGroupMaps gm on gm.UserId = u.Id + left join Groups g on g.Id = 'a' + where g.Name like '%grp%' escape '%' + """ + +[] +let ``unioned queries must have the same number of columns`` () = + expectError (Error.expectedKnownColumnCount 2 3) + """ + select 1 a, 2 b, 3 c + union all + select 4, 5 + """ + +[] +let ``updates must set actual columns`` () = + expectError (Error.noSuchColumnToSet "Users" "Nane") + """ + update Users + set Id = 1, Nane = '' + where Id > 5 + """ + +[] +let ``updated column types must match`` () = + expectError (Error.cannotUnify "INT" "STRING") + """ + update Users + set Id = 'five' + """ + +[] +let ``inserted column types must match`` () = + expectError (Error.cannotUnify "INT" "STRING") + """ + insert into Users(Id, Name) values ('one', 'jim') + """ + +[] +let ``inserted columns must exist`` () = + expectError (Error.noSuchColumn "Goober") + """ + insert into Users(Goober, Booger) values ('one', 'jim') + """ + +[] +let ``sum argument must be numeric`` () = + expectError (Error.cannotUnify "" "STRING") + """ + select sum(Name) as Sum from Users + """ + +[] +let ``can't use list-parameter as a scalar result`` () = + expectError (Error.cannotUnify "" "[INT]") + """ + select @p as x + from Users + where Id in @p + """ \ No newline at end of file diff --git a/Rezoom.SQL.Test/TestTypeInference.fs b/Rezoom.SQL.Test/TestTypeInference.fs new file mode 100644 index 0000000..51e25c2 --- /dev/null +++ b/Rezoom.SQL.Test/TestTypeInference.fs @@ -0,0 +1,145 @@ +module Rezoom.SQL.Test.TestTypeInference +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler + +let zeroModel = + { Schemas = + [ Schema.Empty(Name("main")) + Schema.Empty(Name("temp")) + ] |> List.map (fun s -> s.SchemaName, s) |> Map.ofList + DefaultSchema = Name("main") + TemporarySchema = Name("temp") + Builtin = { Functions = Map.empty } + } + +[] +let ``simple select`` () = + let cmd = CommandEffect.OfSQL(zeroModel, "anonymous", @" + create table Users(id int primary key null, name string(128) null, email string(128) null); + select * from Users + ") + Assert.AreEqual(0, cmd.Parameters.Count) + let results = cmd.ResultSets() |> toReadOnlyList + Assert.AreEqual(1, results.Count) + let cs = results.[0].Columns + Assert.IsTrue(cs.[1].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("id"), cs.[1].ColumnName) + Assert.AreEqual({ Nullable = true; Type = IntegerType Integer32 }, cs.[1].Expr.Info.Type) + Assert.IsFalse(cs.[2].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("name"), cs.[2].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[2].Expr.Info.Type) + Assert.IsFalse(cs.[0].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("email"), cs.[0].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[0].Expr.Info.Type) + +[] +let ``simple select with parameter`` () = + let cmd = CommandEffect.OfSQL(zeroModel, "anonymous", @" + create table Users(id int primary key null, name string(128) null, email string(128) null); + select * from Users u + where u.id = @id + ") + Assert.AreEqual(1, cmd.Parameters.Count) + Assert.AreEqual + ( (NamedParameter (Name("id")), { Nullable = false; Type = IntegerType Integer32 }) + , cmd.Parameters.[0]) + let results = cmd.ResultSets() |> toReadOnlyList + Assert.AreEqual(1, results.Count) + let cs = results.[0].Columns + Assert.IsTrue(cs.[1].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("id"), cs.[1].ColumnName) + Assert.AreEqual({ Nullable = true; Type = IntegerType Integer32 }, cs.[1].Expr.Info.Type) + Assert.IsFalse(cs.[2].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("name"), cs.[2].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[2].Expr.Info.Type) + Assert.IsFalse(cs.[0].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("email"), cs.[0].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[0].Expr.Info.Type) + +[] +let ``simple select with parameter nullable id`` () = + let cmd = CommandEffect.OfSQL(zeroModel, "anonymous", @" + create table Users(id int primary key null, name string(128) null, email string(128) null); + select * from Users u + where u.id is @id + ") + Assert.AreEqual(1, cmd.Parameters.Count) + Assert.AreEqual + ( (NamedParameter (Name("id")), { Nullable = true; Type = IntegerType Integer32 }) + , cmd.Parameters.[0]) + let results = cmd.ResultSets() |> toReadOnlyList + Assert.AreEqual(1, results.Count) + let cs = results.[0].Columns + Assert.IsTrue(cs.[1].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("id"), cs.[1].ColumnName) + Assert.AreEqual({ Nullable = true; Type = IntegerType Integer32 }, cs.[1].Expr.Info.Type) + Assert.IsFalse(cs.[2].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("name"), cs.[2].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[2].Expr.Info.Type) + Assert.IsFalse(cs.[0].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("email"), cs.[0].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[0].Expr.Info.Type) + +[] +let ``simple select with parameter not null`` () = + let cmd = + CommandEffect.OfSQL(zeroModel, "anonymous", @" + create table Users(id int primary key, name string(128) null, email string(128) null); + select * from Users u + where u.id = @id + ") + Assert.AreEqual(1, cmd.Parameters.Count) + Assert.AreEqual + ( (NamedParameter (Name("id")), { Nullable = false; Type = IntegerType Integer32 }) + , cmd.Parameters.[0]) + let results = cmd.ResultSets() |> toReadOnlyList + Assert.AreEqual(1, results.Count) + let cs = results.[0].Columns + Assert.IsTrue(cs.[1].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("id"), cs.[1].ColumnName) + Assert.AreEqual({ Nullable = false; Type = IntegerType Integer32 }, cs.[1].Expr.Info.Type) + Assert.IsFalse(cs.[2].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("name"), cs.[2].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[2].Expr.Info.Type) + Assert.IsFalse(cs.[0].Expr.Info.PrimaryKey) + Assert.AreEqual(Name("email"), cs.[0].ColumnName) + Assert.AreEqual({ Nullable = true; Type = StringType }, cs.[0].Expr.Info.Type) + +[] +let ``select where id in param`` () = + let cmd = + CommandEffect.OfSQL(zeroModel, "anonymous", @" + create table Users(id int primary key, name string(128), email string(128)); + select * from Users u + where u.id in @id + ") + Assert.AreEqual(1, cmd.Parameters.Count) + +[] +let ``coalesce not null`` () = + let model = userModel1() + let cmd = + CommandEffect.OfSQL(model.Model, "anonymous", @" + select coalesce(u.Name, u.Email, @default) as c + from Users u + where u.id in @id + ") + printfn "%A" cmd.Parameters + Assert.AreEqual(2, cmd.Parameters.Count) + Assert.IsFalse((snd cmd.Parameters.[0]).Nullable) + Assert.IsFalse((snd cmd.Parameters.[1]).Nullable) + +[] +let ``coalesce null`` () = + let model = userModel1() + let cmd = + CommandEffect.OfSQL(model.Model, "anonymous", @" + select coalesce(u.Name, @default, u.Email) as c + from Users u + where u.id in @id + ") + printfn "%A" cmd.Parameters + Assert.AreEqual(2, cmd.Parameters.Count) + Assert.IsTrue((snd cmd.Parameters.[0]).Nullable) + Assert.IsFalse((snd cmd.Parameters.[1]).Nullable) diff --git a/Rezoom.SQL.Test/TestVendorStatements.fs b/Rezoom.SQL.Test/TestVendorStatements.fs new file mode 100644 index 0000000..58b44a7 --- /dev/null +++ b/Rezoom.SQL.Test/TestVendorStatements.fs @@ -0,0 +1,79 @@ +module Rezoom.SQL.Test.TestVendorStatements +open System +open NUnit.Framework +open FsUnit +open Rezoom.SQL.Compiler +open Rezoom.SQL.Mapping + +let vendor (sql : string) expected = + let userModel = userModel1() + let parsed = CommandEffect.OfSQL(userModel.Model, "anonymous", sql) + let indexer = dispenserParameterIndexer() + let fragments = userModel.Backend.ToCommandFragments(indexer, parsed.Statements) |> List.ofSeq + printfn "%A" fragments + if fragments <> expected then + failwith "Mismatch" + +[] +let ``vendor without exprs or imaginary`` () = + vendor """ + vendor sqlite { + this is raw text + } + """ + [ CommandText " + this is raw text + ;" + ] + +[] +let ``vendor without imaginary`` () = + vendor """ + vendor sqlite { + raw text {@param1} more raw {@param2} + } + """ + [ CommandText " + raw text " + Parameter 0 + CommandText " more raw " + Parameter 1 + CommandText " + ;" + ] + +[] +let ``vendor with imaginary`` () = + vendor """ + vendor sqlite { + raw text {@param1} more raw {@param2} + } imagine { + select Id from Users + } + """ + [ CommandText " + raw text " + Parameter 0 + CommandText " more raw " + Parameter 1 + CommandText " + ;" + ] + +[] +let ``vendor with wacky delimiters`` () = + vendor """ + vendor sqlite [:<# + raw text [:<# @param1 #>:] more raw [:<# @param2 #>:] + #>:] imagine [:<# + select Id from Users + #>:] + """ + [ CommandText " + raw text " + Parameter 0 + CommandText " more raw " + Parameter 1 + CommandText " + ;" + ] diff --git a/Rezoom.SQL.Test/app.config b/Rezoom.SQL.Test/app.config new file mode 100644 index 0000000..c130c89 --- /dev/null +++ b/Rezoom.SQL.Test/app.config @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Test/packages.config b/Rezoom.SQL.Test/packages.config new file mode 100644 index 0000000..b930637 --- /dev/null +++ b/Rezoom.SQL.Test/packages.config @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/Rezoom.SQL.Test/user-model-1/V1.initial.sql b/Rezoom.SQL.Test/user-model-1/V1.initial.sql new file mode 100644 index 0000000..ca11deb --- /dev/null +++ b/Rezoom.SQL.Test/user-model-1/V1.initial.sql @@ -0,0 +1,18 @@ +create table Users + ( Id int primary key + , Name string(128) null + , Email string(128) null + , Password binary(64) null + , Salt binary(64) null + ); + +create table Groups + ( Id int primary key + , Name string(128) null + ); + +create table UserGroupMaps + ( UserId int primary key references Users(Id) + , GroupId int primary key references Groups(Id) + ); + diff --git a/Rezoom.SQL.Test/user-model-1/V2.view.sql b/Rezoom.SQL.Test/user-model-1/V2.view.sql new file mode 100644 index 0000000..4b61eca --- /dev/null +++ b/Rezoom.SQL.Test/user-model-1/V2.view.sql @@ -0,0 +1,4 @@ +create view ViewUsers(Id) as +select Id * 1 +from Users +where Name like '%stuff%' diff --git a/Rezoom.SQL.Test/user-model-1/rzsql.json b/Rezoom.SQL.Test/user-model-1/rzsql.json new file mode 100644 index 0000000..3eade6f --- /dev/null +++ b/Rezoom.SQL.Test/user-model-1/rzsql.json @@ -0,0 +1,3 @@ +{ "backend": "sqlite" +, "migrations": "." +} diff --git a/Rezoom.SQL.Test/user-model-2/V1.initial-employer.sql b/Rezoom.SQL.Test/user-model-2/V1.initial-employer.sql new file mode 100644 index 0000000..85a52d4 --- /dev/null +++ b/Rezoom.SQL.Test/user-model-2/V1.initial-employer.sql @@ -0,0 +1,8 @@ +create table Companies + ( Id int primary key + , Name string(128) null + ); + +alter table Users +add column EmployerId int +references Companies(Id); diff --git a/Rezoom.SQL.Test/user-model-2/V1.initial-groups.sql b/Rezoom.SQL.Test/user-model-2/V1.initial-groups.sql new file mode 100644 index 0000000..0c0c929 --- /dev/null +++ b/Rezoom.SQL.Test/user-model-2/V1.initial-groups.sql @@ -0,0 +1,10 @@ +create table Groups + ( Id int primary key + , Name string(128) null + ); + +create table UserGroupMaps + ( UserId int primary key references Users(Id) + , GroupId int primary key references Groups(Id) + ); + diff --git a/Rezoom.SQL.Test/user-model-2/V1.initial.sql b/Rezoom.SQL.Test/user-model-2/V1.initial.sql new file mode 100644 index 0000000..a6c1e1e --- /dev/null +++ b/Rezoom.SQL.Test/user-model-2/V1.initial.sql @@ -0,0 +1,8 @@ +create table Users + ( Id int primary key + , Name string(128) null + , Email string(128) null + , Password binary(64) null + , Salt binary(64) null + ); + diff --git a/Rezoom.SQL.Test/user-model-2/V2.company+groups.sql b/Rezoom.SQL.Test/user-model-2/V2.company+groups.sql new file mode 100644 index 0000000..38f8bc0 --- /dev/null +++ b/Rezoom.SQL.Test/user-model-2/V2.company+groups.sql @@ -0,0 +1,5 @@ +create table CompanyGroupMaps + ( CompanyId int primary key references Companies(Id) + , GroupId int primary key references Groups(Id) + ); + diff --git a/Rezoom.SQL.Test/user-model-2/rzsql.json b/Rezoom.SQL.Test/user-model-2/rzsql.json new file mode 100644 index 0000000..3eade6f --- /dev/null +++ b/Rezoom.SQL.Test/user-model-2/rzsql.json @@ -0,0 +1,3 @@ +{ "backend": "sqlite" +, "migrations": "." +} diff --git a/Rezoom.Test/Environment.fs b/Rezoom.Test/Environment.fs index e9a7376..50fc06d 100644 --- a/Rezoom.Test/Environment.fs +++ b/Rezoom.Test/Environment.fs @@ -6,38 +6,33 @@ open System open System.Collections open System.Collections.Generic -type TestContext() = - let batches = new List>() - let mutable inProgress = false - member __.Prepare(query : string) = - if not inProgress then - batches.Add(new List()) - inProgress <- true - let batch = batches.[batches.Count - 1] - let index = batch.Count - batch.Add(query) - () - member __.Execute() = - inProgress <- false // end this batch +type TestExecutionLog() = + inherit ExecutionLog() + let steps = ResizeArray() + override __.OnBeginStep() = + steps.Add(ResizeArray()) + override __.OnPreparedErrand(errand) = + steps.[steps.Count - 1].Add(string errand.CacheInfo.Identity) member __.Batches() = - batches - |> Seq.map List.ofSeq - |> List.ofSeq + steps |> Seq.map List.ofSeq |> Seq.filter (not << List.isEmpty) |> List.ofSeq type TestRequest<'a>(idem : bool, query : string, pre : unit -> unit, post : string -> 'a) = inherit SynchronousErrand<'a>() new (query, pre, post) = TestRequest<_>(true, query, pre, post) - override __.Mutation = not idem - override __.Idempotent = idem - override __.DataSource = box typeof - override __.Identity = box query + override __.CacheInfo = + { new CacheInfo() with + override __.DependencyMask = BitMask(0UL, 1UL) + override __.InvalidationMask = + if idem then BitMask.Zero + else BitMask.Full + override __.Cacheable = idem + override __.Category = upcast typeof.Assembly + override __.Identity = upcast query + } override __.Prepare(serviceContext : ServiceContext) = - let db = serviceContext.GetService>().Service pre() - db.Prepare(query) fun () -> - db.Execute() post query exception PrepareFailure of string @@ -46,16 +41,16 @@ exception ArtificialFailure of string let explode str = raise <| ArtificialFailure str -let sendWith query post = TestRequest<_>(query, id, post).ToPlan() +let sendWith query post = Plan.ofErrand <| TestRequest<_>(query, id, post) let send query = sendWith query id -let mutateWith query post = TestRequest<_>(false, query, id, post).ToPlan() +let mutateWith query post = Plan.ofErrand <| TestRequest<_>(false, query, id, post) let mutate query = mutateWith query id let failingPrepare msg query = TestRequest<_> ( query , fun () -> raise <| PrepareFailure msg , fun _ -> Unchecked.defaultof<_> - ) |> fun r -> r.ToPlan() + ) |> Plan.ofErrand let failingRetrieve msg query = sendWith query (fun _ -> raise <| RetrieveFailure msg) type ExpectedResult<'a> = @@ -63,28 +58,27 @@ type ExpectedResult<'a> = | Bad of (exn -> bool) type ExpectedResultTest<'a> = - { - Task : unit -> 'a Plan + { Task : unit -> 'a Plan Batches : string list list Result : ExpectedResult<'a> } let testSpeed expectedResult = - use execContext = new ExecutionContext(new ZeroServiceFactory()) - let result = execContext.Execute(expectedResult.Task()).Result + let log = TestExecutionLog() + let result = (execute { ExecutionConfig.Default with Log = log } (expectedResult.Task())).Result match expectedResult.Result with | Good x when x = result -> () | _ -> failwith "Invalid result for speed test (try running this as a regular test)" let test expectedResult = - use execContext = new ExecutionContext(new ZeroServiceFactory()) + let log = TestExecutionLog() + let execContext = execute { ExecutionConfig.Default with Log = log } let result = try - execContext.Execute(expectedResult.Task()).Result |> Choice1Of2 + (execContext (expectedResult.Task())).Result |> Choice1Of2 with | ex -> Choice2Of2 ex - let testContext = execContext.GetService>().Service - let batches = testContext.Batches() + let batches = log.Batches() if batches <> expectedResult.Batches then failwithf "Batches do not match (actual: %A)" batches diff --git a/Rezoom.Test/Rezoom.Test.fsproj b/Rezoom.Test/Rezoom.Test.fsproj index 631ad4b..e183b9e 100644 --- a/Rezoom.Test/Rezoom.Test.fsproj +++ b/Rezoom.Test/Rezoom.Test.fsproj @@ -1,4 +1,4 @@ - + @@ -33,37 +33,6 @@ 3 bin\Release\Rezoom.Test.XML - - - - - True - - - - - - - - - - - - - - - - - Rezoom - {d98acbeb-a039-4340-a7c5-6ed2b677268b} - True - - - Rezoom.Execution - {9db721d3-da97-4be3-b60b-9b7a682e803e} - True - - 11 @@ -80,6 +49,44 @@ + + + + + + + + + + + + + + ..\packages\FSharp.Core.4.0.0.1\lib\net40\FSharp.Core.dll + True + + + ..\packages\FsUnit.2.3.2\lib\net45\FsUnit.NUnit.dll + True + + + + ..\packages\FsUnit.2.3.2\lib\net45\NHamcrest.dll + True + + + ..\packages\NUnit.3.5.0\lib\net45\nunit.framework.dll + True + + + + + + Rezoom + {d98acbeb-a039-4340-a7c5-6ed2b677268b} + True + + - + \ No newline at end of file diff --git a/Rezoom.Test/TestCaching.fs b/Rezoom.Test/TestCaching.fs index 2cd25f8..800e857 100644 --- a/Rezoom.Test/TestCaching.fs +++ b/Rezoom.Test/TestCaching.fs @@ -1,100 +1,118 @@ -namespace Rezoom.Test +module Rezoom.Test.TestCaching open Rezoom -open Microsoft.VisualStudio.TestTools.UnitTesting +open NUnit.Framework +open FsUnit -[] -type TestCaching() = - [] - member __.TestStrictCachedPair() = - { - Task = fun () -> - plan { - let! q1 = send "q" - let! q2 = send "q" - return q1 + q2 - } - Batches = - [ - [ "q" ] - ] - Result = Good "qq" - } |> test - - [] - member __.TestConcurrentCachedPair() = - { - Task = fun () -> - plan { - let! q1, q2 = send "q", send "q" - return q1 + q2 - } - Batches = - [ - [ "q" ] - ] - Result = Good "qq" - } |> test +[] +let ``strict cached pair`` () = + { Task = fun () -> + plan { + let! q1 = send "q" + let! q2 = send "q" + return q1 + q2 + } + Batches = + [ [ "q" ] + ] + Result = Good "qq" + } |> test + +[] +let ``concurrent cached pair`` () = + { Task = fun () -> + plan { + let! q1, q2 = send "q", send "q" + return q1 + q2 + } + Batches = + [ [ "q" ] + ] + Result = Good "qq" + } |> test + +[] +let ``chaining cached concurrency`` () = + let testTask x = + plan { + let! a = send (x + "1") + let! b = send (x + "2") + let! c = send (x + "3") + return a + b + c + } + { Task = fun () -> + plan { + let! x1, x2, x3 = + testTask "x", testTask "x", testTask "x" + return x1 + " " + x2 + " " + x3 + } + Batches = + [ [ "x1" ] + [ "x2" ] + [ "x3" ] + ] + Result = Good "x1x2x3 x1x2x3 x1x2x3" + } |> test - [] - member __.TestChainingCachedConcurrency() = - let testTask x = +[] +let ``still valid after other`` () = + { Task = fun () -> plan { - let! a = send (x + "1") - let! b = send (x + "2") - let! c = send (x + "3") - return a + b + c + let! q1 = send "q" + let! q2 = send "q" + let! m = send "x" + let! q3 = send "q" + return q1 + q2 + m + q3 } - { - Task = fun () -> - plan { - let! x1, x2, x3 = - testTask "x", testTask "x", testTask "x" - return x1 + " " + x2 + " " + x3 - } - Batches = - [ - [ "x1" ] - [ "x2" ] - [ "x3" ] - ] - Result = Good "x1x2x3 x1x2x3 x1x2x3" - } |> test + Batches = + [ [ "q" ] + [ "x" ] + ] + Result = Good "qqxq" + } |> test - [] - member __.TestStillValid() = - { - Task = fun () -> - plan { - let! q1 = send "q" - let! q2 = send "q" - let! m = send "x" - let! q3 = send "q" - return q1 + q2 + m + q3 - } - Batches = - [ - [ "q" ] - [ "x" ] - ] - Result = Good "qqxq" - } |> test +[] +let ``invalidation invalidates`` () = + { Task = fun () -> + plan { + let! q1 = send "q" + let! q2 = send "q" + let! m = mutate "x" + let! q3 = send "q" + let! q4 = send "q" + return q1 + q2 + m + q3 + q4 + } + Batches = + [ [ "q" ] + [ "x" ] + [ "q" ] + ] + Result = Good "qqxqq" + } |> test - [] - member __.TestInvalidation() = - { - Task = fun () -> - plan { - let! q1 = send "q" - let! q2 = send "q" - let! m = mutate "x" - let! q3 = send "q" - return q1 + q2 + m + q3 - } - Batches = - [ - [ "q" ] - [ "x" ] - [ "q" ] - ] - Result = Good "qqxq" - } |> test \ No newline at end of file +[] +let ``deferred execution`` () = + let px = + plan { + let! q = send "q" + let! x = send "x" + return x + } + let py = send "y" + { Task = fun () -> + plan { + let! q = send "q" + // when px and py are batched together, at first there is a step with both + // q (from px) and y (from py) pending. + // q will be pulled from the cache, but rather than just executing y, we should + // defer y and advance px so we can batch x and y together. + let! x, y = px, py + let! z = send "z" + return q + x + y + z + } + Batches = + [ [ "q" ] + [ "x"; "y" ] // x and y batched together + [ "z" ] + ] + Result = Good "qxyz" + } |> test \ No newline at end of file diff --git a/Rezoom.Test/TestConcurrency.fs b/Rezoom.Test/TestConcurrency.fs index bc9adbc..078ecd1 100644 --- a/Rezoom.Test/TestConcurrency.fs +++ b/Rezoom.Test/TestConcurrency.fs @@ -1,62 +1,55 @@ -namespace Rezoom.Test +module Rezoom.Test.TestConcurrency open Rezoom -open Microsoft.VisualStudio.TestTools.UnitTesting +open NUnit.Framework +open FsUnit -[] -type TestConcurrency() = - [] - member __.TestStrictPair() = - { - Task = fun () -> - plan { - let! q = send "q" - let! r = send "r" - return q + r - } - Batches = - [ - [ "q" ] - [ "r" ] - ] - Result = Good "qr" - } |> test +[] +let ``strict pair`` () = + { Task = fun () -> + plan { + let! q = send "q" + let! r = send "r" + return q + r + } + Batches = + [ [ "q" ] + [ "r" ] + ] + Result = Good "qr" + } |> test - [] - member __.TestConcurrentPair() = - { - Task = fun () -> - plan { - let! q, r = send "q", send "r" - return q + r - } - Batches = - [ - [ "q"; "r" ] - ] - Result = Good "qr" - } |> test +[] +let ``concurrent pair`` () = + { Task = fun () -> + plan { + let! q, r = send "q", send "r" + return q + r + } + Batches = + [ [ "q"; "r" ] + ] + Result = Good "qr" + } |> test - [] - member __.TestChainingConcurrency() = - let testTask x = +[] +let ``chaining concurrency`` () = + let testTask x = + plan { + let! a = send (x + "1") + let! b = send (x + "2") + let! c = send (x + "3") + return a + b + c + } + { Task = fun () -> plan { - let! a = send (x + "1") - let! b = send (x + "2") - let! c = send (x + "3") - return a + b + c + let! x, y, z = + testTask "x", testTask "y", testTask "z" + return x + " " + y + " " + z } - { - Task = fun () -> - plan { - let! x, y, z = - testTask "x", testTask "y", testTask "z" - return x + " " + y + " " + z - } - Batches = - [ - [ "x1"; "y1"; "z1" ] - [ "x2"; "y2"; "z2" ] - [ "x3"; "y3"; "z3" ] - ] - Result = Good "x1x2x3 y1y2y3 z1z2z3" - } |> test \ No newline at end of file + Batches = + [ [ "x1"; "y1"; "z1" ] + [ "x2"; "y2"; "z2" ] + [ "x3"; "y3"; "z3" ] + ] + Result = Good "x1x2x3 y1y2y3 z1z2z3" + } |> test \ No newline at end of file diff --git a/Rezoom.Test/TestExceptionCatching.fs b/Rezoom.Test/TestExceptionCatching.fs index a53473f..11ad8c7 100644 --- a/Rezoom.Test/TestExceptionCatching.fs +++ b/Rezoom.Test/TestExceptionCatching.fs @@ -1,159 +1,148 @@ -namespace Rezoom.Test -open Rezoom +module Rezoom.Test.TestExceptionCaching open System -open Microsoft.VisualStudio.TestTools.UnitTesting - -[] -type TestExceptionCatching() = - [] - member __.TestSimpleCatch() = - { - Task = fun () -> - plan { - try - explode "fail" - return 2 - with - | ArtificialFailure "fail" -> return 1 - } - Batches = [] - Result = Good 1 - } |> test - - [] - member __.TestBoundCatch() = - { - Task = fun () -> - plan { - try - let! x = send "x" - explode "fail" - return 2 - with - | ArtificialFailure "fail" -> return 1 - } - Batches = [["x"]] - Result = Good 1 - } |> test - - [] - member __.TestSimpleFailingPrepare() = - { - Task = fun () -> - plan { - try - let! x = failingPrepare "fail" "x" - return 2 - with - | PrepareFailure "fail" -> return 1 - } - Batches = [] - Result = Good 1 - } |> test - - [] - member __.TestSimpleFailingRetrieve() = - { - Task = fun () -> - plan { - try - let! x = failingRetrieve "fail" "x" - return 2 - with - | RetrieveFailure "fail" -> return 1 - } - Batches = [["x"]] - Result = Good 1 - } |> test +open Rezoom +open NUnit.Framework +open FsUnit - [] - member __.TestConcurrentCatching() = - let catching query = +[] +let ``simple catch`` () = + { Task = fun () -> plan { - let guid = Guid.NewGuid().ToString() try - let! x = failingRetrieve guid query - return x + explode "fail" + return 2 with - | RetrieveFailure msg when msg = guid -> return "bad" - + | ArtificialFailure "fail" -> return 1 } - let good query = + Batches = [] + Result = Good 1 + } |> test + +[] +let ``bound catch`` () = + { Task = fun () -> plan { - let! result = send query - return result + try + let! x = send "x" + explode "fail" + return 2 + with + | ArtificialFailure "fail" -> return 1 } - { - Task = fun () -> - plan { - let! x, y, z = - catching "x", catching "y", good "z" - return x + y + z - } - Batches = - [ - [ "x"; "y"; "z" ] - ] - Result = Good "badbadz" - } |> test + Batches = [["x"]] + Result = Good 1 + } |> test - [] - member __.TestConcurrentLoopCatching() = - let catching query = +[] +let ``simple failing prepare`` () = + { Task = fun () -> plan { - let guid = Guid.NewGuid().ToString() try - let! x = failingRetrieve guid query - return () + let! x = failingPrepare "fail" "x" + return 2 with - | RetrieveFailure msg when msg = guid -> return () + | PrepareFailure "fail" -> return 1 + } + Batches = [] + Result = Good 1 + } |> test +[] +let ``simple failing retrieve`` () = + { Task = fun () -> + plan { + try + let! x = failingRetrieve "fail" "x" + return 2 + with + | RetrieveFailure "fail" -> return 1 } - let good query = + Batches = [["x"]] + Result = Good 1 + } |> test + +[] +let ``concurrent catching`` () = + let catching query = + plan { + let guid = Guid.NewGuid().ToString() + try + let! x = failingRetrieve guid query + return x + with + | RetrieveFailure msg when msg = guid -> return "bad" + + } + let good query = + plan { + let! result = send query + return result + } + { Task = fun () -> plan { - let! result = send query - return () + let! x, y, z = + catching "x", catching "y", good "z" + return x + y + z } - { - Task = fun () -> - plan { - for q in batch ["x"; "y"; "z"] do - if q = "y" then - do! good q - else - do! catching q - return () - } - Batches = - [ - [ "x"; "y"; "z" ] - ] - Result = Good () - } |> test + Batches = + [ [ "x"; "y"; "z" ] + ] + Result = Good "badbadz" + } |> test - [] - member __.TestConcurrentNonCatching() = - let notCatching query = +[] +let ``concurrent loop catching`` () = + let catching query = + plan { + let guid = Guid.NewGuid().ToString() + try + let! x = failingRetrieve guid query + return () + with + | RetrieveFailure msg when msg = guid -> return () + + } + let good query = + plan { + let! result = send query + return () + } + { Task = fun () -> plan { - let! x = failingRetrieve "fail" query - return x + for q in batch ["x"; "y"; "z"] do + if q = "y" then + do! good q + else + do! catching q + return () } - let good query = + Batches = + [ [ "x"; "y"; "z" ] + ] + Result = Good () + } |> test + +[] +let ``concurrent non-catching`` () = + let notCatching query = + plan { + let! x = failingRetrieve "fail" query + return x + } + let good query = + plan { + let! result = send query + let! next = send "jim" + return result + next + } + { Task = fun () -> plan { - let! result = send query - let! next = send "jim" - return result + next + let! x, y, z = + notCatching "x", notCatching "y", good "z" + return x + y + z } - { - Task = fun () -> - plan { - let! x, y, z = - notCatching "x", notCatching "y", good "z" - return x + y + z - } - Batches = - [ - [ "x"; "y"; "z" ] - ] - Result = Bad (fun ex -> true) - } |> test \ No newline at end of file + Batches = + [ [ "x"; "y"; "z" ] + ] + Result = Bad (fun ex -> true) + } |> test \ No newline at end of file diff --git a/Rezoom.Test/TestExceptionFinally.fs b/Rezoom.Test/TestExceptionFinally.fs index adf9e9c..1ae7402 100644 --- a/Rezoom.Test/TestExceptionFinally.fs +++ b/Rezoom.Test/TestExceptionFinally.fs @@ -1,213 +1,381 @@ -namespace Rezoom.Test -open Rezoom +module Rezoom.Test.TestExceptionFinally open System -open Microsoft.VisualStudio.TestTools.UnitTesting - -[] -type TestExceptionFinally() = - [] - member __.TestFinallyNoThrow() = - let mutable ran = 0 - { - Task = fun () -> - plan { - try - let! q = send "q" - let! r = send "r" - return q + r - finally - ran <- ran + 1 - } - Batches = - [ - [ "q" ] - [ "r" ] - ] - Result = Good "qr" - } |> test - if ran <> 1 then - failwithf "ran was %d" ran - - [] - member __.TestSimpleFinally() = - let mutable ran = 0 - { - Task = fun () -> - plan { - try - explode "fail" - return 2 - finally - ran <- ran + 1 - } - Batches = [] - Result = Bad (fun _ -> ran = 1) - } |> test - - [] - member __.TestBoundFinally() = - let mutable ran = 0 - { - Task = fun () -> - plan { - try - let! x = send "x" - explode "fail" - return 2 - finally - ran <- ran + 1 - } - Batches = [["x"]] - Result = Bad (fun _ -> ran = 1) - } |> test - - [] - member __.TestSimpleFailingPrepare() = - let mutable ran = 0 - { - Task = fun () -> - plan { - try - let! x = failingPrepare "fail" "x" - return 2 - finally - ran <- ran + 1 - } - Batches = [] - Result = Bad (fun _ -> ran = 1) - } |> test - - [] - member __.TestSimpleFailingRetrieve() = - let mutable ran = 0 - { - Task = fun () -> - plan { - try - let! x = failingRetrieve "fail" "x" - return 2 - finally - ran <- ran + 1 - } - Batches = [["x"]] - Result = Bad (fun _ -> ran = 1) - } |> test - - [] - member __.TestUsingThrow() = - let mutable ran = 0 - { - Task = fun () -> - plan { - use d = { new IDisposable with member x.Dispose() = ran <- ran + 1 } - let! x = failingRetrieve "fail" "x" +open Rezoom +open NUnit.Framework +open FsUnit + +[] +let ``finally no throw`` () = + let mutable ran = 0 + { Task = fun () -> + plan { + try + let! q = send "q" + let! r = send "r" + return q + r + finally + ran <- ran + 1 + } + Batches = + [ [ "q" ] + [ "r" ] + ] + Result = Good "qr" + } |> test + if ran <> 1 then + failwithf "ran was %d" ran + +[] +let ``simple finally`` () = + let mutable ran = 0 + { Task = fun () -> + plan { + try + explode "fail" return 2 - } - Batches = [["x"]] - Result = Bad (fun _ -> ran = 1) - } |> test - - [] - member __.TestUsingNoThrow() = - let mutable ran = 0 - { - Task = fun () -> - plan { - use d = { new IDisposable with member x.Dispose() = ran <- ran + 1 } + finally + ran <- ran + 1 + } + Batches = [] + Result = Bad (fun _ -> ran = 1) + } |> test + +[] +let ``bound finally`` () = + let mutable ran = 0 + { Task = fun () -> + plan { + try let! x = send "x" + explode "fail" return 2 - } - Batches = [["x"]] - Result = Good 2 - } |> test - - [] - member __.TestNestedFinally() = - let mutable counter = 0 - let mutable first = 0 - let mutable next = 0 - { - Task = fun () -> - plan { - try - try - let! x = failingRetrieve "fail" "x" - return 2 - finally - counter <- counter + 1 - first <- counter - finally - counter <- counter + 1 - next <- counter - } - Batches = [["x"]] - Result = Bad (fun _ -> - counter = 2 && first = 1 && next = 2) - } |> test + finally + ran <- ran + 1 + } + Batches = [["x"]] + Result = Bad (fun _ -> ran = 1) + } |> test - [] - member __.TestConcurrentAbortion() = - let mutable ranFinally = false - let deadly query = +[] +let ``finally failing prepare`` () = + let mutable ran = 0 + { Task = fun () -> plan { - let! x = failingRetrieve "fail" query - return x + try + let! x = failingPrepare "fail" "x" + return 2 + finally + ran <- ran + 1 } - let good query = + Batches = [] + Result = Bad (fun _ -> ran = 1) + } |> test + +[] +let ``finally failing retrieve`` () = + let mutable ran = 0 + { Task = fun () -> plan { try - let! result = send query - let! next = send "jim" - return result + next + let! x = failingRetrieve "fail" "x" + return 2 finally - ranFinally <- true - } - { - Task = fun () -> - plan { - let! x, y, z = - deadly "x", deadly "y", good "z" - return x + y + z - } - Batches = - [ - [ "x"; "y"; "z" ] - ] - Result = Bad (fun ex -> ranFinally) - } |> test - - [] - member __.TestConcurrentLoopAbortion() = - let mutable ranFinally = false - let deadly query = - plan { - let! x = failingRetrieve "fail" query - return () + ran <- ran + 1 + } + Batches = [["x"]] + Result = Bad (fun _ -> ran = 1) + } |> test + +[] +let ``using with throw`` () = + let mutable ran = 0 + { Task = fun () -> + plan { + use d = { new IDisposable with member x.Dispose() = ran <- ran + 1 } + let! x = failingRetrieve "fail" "x" + return 2 } - let good query = + Batches = [["x"]] + Result = Bad (fun _ -> ran = 1) + } |> test + +[] +let ``using without throw`` () = + let mutable ran = 0 + { Task = fun () -> + plan { + use d = { new IDisposable with member x.Dispose() = ran <- ran + 1 } + let! x = send "x" + return 2 + } + Batches = [["x"]] + Result = Good 2 + } |> test + +[] +let ``nested finally`` () = + let mutable counter = 0 + let mutable first = 0 + let mutable next = 0 + { Task = fun () -> plan { try - let! result = send query - let! next = send "jim" - return () + try + let! x = failingRetrieve "fail" "x" + return 2 + finally + counter <- counter + 1 + first <- counter finally - ranFinally <- true - } - { - Task = fun () -> - plan { - for q in batch ["x"; "y"; "z"] do - if q = "y" then - do! good q - else - do! deadly q - return () - } - Batches = - [ - [ "x"; "y"; "z" ] - ] - Result = Bad (fun _ -> - ranFinally) - } |> test + counter <- counter + 1 + next <- counter + } + Batches = [["x"]] + Result = Bad (fun _ -> + counter = 2 && first = 1 && next = 2) + } |> test + +[] +let ``concurrent retrieval abortion good last`` () = + let mutable ranFinally = false + let deadly query = + plan { + let! x = failingRetrieve "fail" query + return x + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return result + next + finally + ranFinally <- true + } + { Task = fun () -> + plan { + let! x, y, z = + deadly "x", deadly "y", good "z" + return x + y + z + } + Batches = + [ [ "x"; "y"; "z" ] + ] + Result = Bad (fun ex -> ranFinally) + } |> test + +[] +let ``concurrent retrieval abortion good first`` () = + let mutable ranFinally = false + let deadly query = + plan { + let! x = failingRetrieve "fail" query + return x + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return result + next + finally + ranFinally <- true + } + { Task = fun () -> + plan { + let! x, y, z = + good "x", deadly "y", deadly "z" + return x + y + z + } + Batches = + [ [ "x"; "y"; "z" ] + ] + Result = Bad (fun ex -> ranFinally) + } |> test + +[] +let ``concurrent retrieval abortion good middle`` () = + let mutable ranFinally = false + let deadly query = + plan { + let! x = failingRetrieve "fail" query + return x + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return result + next + finally + ranFinally <- true + } + { Task = fun () -> + plan { + let! x, y, z = + deadly "x", good "y", deadly "z" + return x + y + z + } + Batches = + [ [ "x"; "y"; "z" ] + ] + Result = Bad (fun ex -> ranFinally) + } |> test + +[] +let ``concurrent logic abortion good last`` () = + let mutable ranFinally = false + let deadly query = + plan { + failwith "exn" + let! x = failingRetrieve "fail" query + return x + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return result + next + finally + ranFinally <- true + } + { Task = fun () -> + plan { + let! x, y, z = + deadly "x", deadly "y", good "z" + return x + y + z + } + Batches = + [ + ] + Result = Bad (fun ex -> ranFinally) + } |> test + +[] +let ``concurrent logic abortion good first`` () = + let mutable ranFinally = false + let deadly query = + plan { + failwith "exn" + let! x = failingRetrieve "fail" query + return x + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return result + next + finally + ranFinally <- true + } + { Task = fun () -> + plan { + let! x, y, z = + good "x", deadly "y", deadly "z" + return x + y + z + } + Batches = + [ + ] + Result = Bad (fun ex -> ranFinally) + } |> test + +[] +let ``concurrent logic abortion good middle`` () = + let mutable ranFinally = false + let deadly query = + plan { + failwith "exn" + let! x = failingRetrieve "fail" query + return x + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return result + next + finally + ranFinally <- true + } + { Task = fun () -> + plan { + let! x, y, z = + deadly "x", good "y", deadly "z" + return x + y + z + } + Batches = + [ + ] + Result = Bad (fun ex -> ranFinally) + } |> test + +[] +let ``concurrent loop retrieval abortion`` () = + let mutable ranFinally = false + let deadly query = + plan { + let! x = failingRetrieve "fail" query + return () + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return () + finally + ranFinally <- true + } + { Task = fun () -> + plan { + for q in batch ["x"; "y"; "z"] do + if q = "y" then + do! good q + else + do! deadly q + return () + } + Batches = + [ [ "x"; "y"; "z" ] + ] + Result = Bad (fun _ -> + ranFinally) + } |> test + +[] +let ``concurrent loop logic abortion`` () = + let mutable ranFinally = false + let deadly query = + plan { + failwith "logic" + let! x = failingRetrieve "fail" query + return () + } + let good query = + plan { + try + let! result = send query + let! next = send "jim" + return () + finally + ranFinally <- true + } + { Task = fun () -> + plan { + for q in batch ["x"; "y"; "z"] do + if q = "y" then + do! good q + else + do! deadly q + return () + } + Batches = + [ + ] + Result = Bad (fun _ -> + ranFinally) + } |> test \ No newline at end of file diff --git a/Rezoom.Test/TestPerformance.fs b/Rezoom.Test/TestPerformance.fs index 02436c0..89b2ec9 100644 --- a/Rezoom.Test/TestPerformance.fs +++ b/Rezoom.Test/TestPerformance.fs @@ -1,98 +1,90 @@ -namespace Rezoom.Test +module Rezoom.Test.TestPerformance open Rezoom open System open System.Text open System.Diagnostics -open Microsoft.VisualStudio.TestTools.UnitTesting +open NUnit.Framework +open FsUnit -[] -type TestPerformance() = - static let ret1 = plan { return 1 } - static let time f = - let sw = new Stopwatch() - sw.Start() - let mutable iterations = 0L - while sw.ElapsedMilliseconds < 1000L do - testSpeed f - iterations <- iterations + 1L - sw.Stop() - printfn "%s iterations in %O" (iterations.ToString("#,###")) sw.Elapsed - [] - member __.TestSingleReturn() = - time <| - { - Task = fun () -> ret1 - Batches = [] - Result = Good 1 - } +let ret1 = plan { return 1 } +let time f = + let sw = new Stopwatch() + sw.Start() + let iterations = 10 * 1000 + for i = 1 to iterations do + testSpeed f + sw.Stop() + printfn "%d iterations per second" (int64 iterations * 1000L / sw.ElapsedMilliseconds) + +[] +let ``single return`` () = + time <| + { Task = fun () -> ret1 + Batches = [] + Result = Good 1 + } - [] - member __.TestNestedReturn() = - time <| - { - Task = fun () -> plan { +[] +let ``nested return`` () = + time <| + { Task = fun () -> plan { + return! plan { return! plan { - return! plan { - return! ret1 - } + return! ret1 } } - Batches = [] - Result = Good 1 } + Batches = [] + Result = Good 1 + } - [] - member __.TestBindChain() = - time <| - { - Task = fun () -> plan { - let! one1 = ret1 - let! one2 = ret1 - let! one3 = ret1 - return one1 + one2 + one3 - } - Batches = [] - Result = Good 3 +[] +let ``bind chain`` () = + time <| + { Task = fun () -> plan { + let! one1 = ret1 + let! one2 = ret1 + let! one3 = ret1 + return one1 + one2 + one3 } + Batches = [] + Result = Good 3 + } - [] - member __.TestBindChainWithRequests() = - time <| - { - Task = fun () -> plan { - let! _ = send "x" - let! one1 = ret1 - let! _ = send "y" - let! one2 = ret1 - let! _ = send "z" - let! one3 = ret1 - return one1 + one2 + one3 - } - Batches = - [ - ["x"] - ["y"] - ["z"] - ] - Result = Good 3 +[] +let ``bind chain with requests`` () = + time <| + { Task = fun () -> plan { + let! _ = send "x" + let! one1 = ret1 + let! _ = send "y" + let! one2 = ret1 + let! _ = send "z" + let! one3 = ret1 + return one1 + one2 + one3 } + Batches = + [ ["x"] + ["y"] + ["z"] + ] + Result = Good 3 + } - [] - member __.TestBindChainWithBatchedRequests() = - time <| - { - Task = fun () -> plan { - let! one1 = ret1 - let! _ = send "x", send "y", send "z" - let! one2 = ret1 - let! _ = send "q", send "r", send "s" - let! one3 = ret1 - return one1 + one2 + one3 - } - Batches = - [ - ["x";"y";"z"] - ["q";"r";"s"] - ] - Result = Good 3 - } \ No newline at end of file +[] +let ``bind chain with batched requests`` () = + time <| + { Task = fun () -> plan { + let! one1 = ret1 + let! _ = send "x", send "y", send "z" + let! one2 = ret1 + let! _ = send "q", send "r", send "s" + let! one3 = ret1 + return one1 + one2 + one3 + } + Batches = + [ ["x";"y";"z"] + ["q";"r";"s"] + ] + Result = Good 3 + } \ No newline at end of file diff --git a/Rezoom.Test/app.config b/Rezoom.Test/app.config new file mode 100644 index 0000000..c130c89 --- /dev/null +++ b/Rezoom.Test/app.config @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/Rezoom.Test/packages.config b/Rezoom.Test/packages.config new file mode 100644 index 0000000..02f2331 --- /dev/null +++ b/Rezoom.Test/packages.config @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Rezoom.sln b/Rezoom.sln index 90fbdf1..7f7ad1e 100644 --- a/Rezoom.sln +++ b/Rezoom.sln @@ -1,20 +1,12 @@  Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 14 -VisualStudioVersion = 14.0.25123.0 +VisualStudioVersion = 14.0.25420.1 MinimumVisualStudioVersion = 10.0.40219.1 -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Rezoom.Execution", "Rezoom.Execution\Rezoom.Execution.csproj", "{9DB721D3-DA97-4BE3-B60B-9B7A682E803E}" -EndProject Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom", "Rezoom\Rezoom.fsproj", "{D98ACBEB-A039-4340-A7C5-6ED2B677268B}" EndProject Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.Test", "Rezoom.Test\Rezoom.Test.fsproj", "{F5167029-2918-46B5-A3A6-AB4A91B231A1}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Rezoom.ADO", "Rezoom.ADO\Rezoom.ADO.csproj", "{13BB08A8-8135-4630-BEAB-1F35D660B52B}" -EndProject -Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.ADO.Test", "Rezoom.ADO.Test\Rezoom.ADO.Test.fsproj", "{39B23D3A-43D2-442F-9B83-B68ECE642FAD}" -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Rezoom.EF", "Rezoom.EF\Rezoom.EF.csproj", "{51023E89-6081-4BBF-8945-8F17E6C4D65C}" -EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Tests", "Tests", "{3CFAC282-AFAA-4B7A-879F-D2BDF9EC631B}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Integrations", "Integrations", "{B021A795-0151-4404-B592-0D375821890F}" @@ -23,7 +15,13 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Rezoom.IPGeo", "Rezoom.IPGe EndProject Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.IPGeo.Test", "Rezoom.IPGeo.Test\Rezoom.IPGeo.Test.fsproj", "{97D70AE4-41FB-47D0-B518-589799F093D5}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Rezoom.ADO.Test.Internals", "Rezoom.ADO.Test.Internals\Rezoom.ADO.Test.Internals.csproj", "{3EA0244A-E97C-47B6-96C3-C83315674CAA}" +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.SQL.Mapping", "Rezoom.SQL.Mapping\Rezoom.SQL.Mapping.fsproj", "{6B6A06C5-157A-4FE3-8B4C-2A1AE6A15333}" +EndProject +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.SQL.Test", "Rezoom.SQL.Test\Rezoom.SQL.Test.fsproj", "{AA699897-F692-4ED0-9865-98B6B4C713DB}" +EndProject +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.SQL.Provider", "Rezoom.SQL.Provider\Rezoom.SQL.Provider.fsproj", "{7B1765CB-23F8-419A-9CC6-3DA319ED066F}" +EndProject +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Rezoom.SQL.Compiler", "Rezoom.SQL.Compiler\Rezoom.SQL.Compiler.fsproj", "{87FCD04A-1F90-4D53-A428-CF5F5C532A22}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -31,10 +29,6 @@ Global Release|Any CPU = Release|Any CPU EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution - {9DB721D3-DA97-4BE3-B60B-9B7A682E803E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {9DB721D3-DA97-4BE3-B60B-9B7A682E803E}.Debug|Any CPU.Build.0 = Debug|Any CPU - {9DB721D3-DA97-4BE3-B60B-9B7A682E803E}.Release|Any CPU.ActiveCfg = Release|Any CPU - {9DB721D3-DA97-4BE3-B60B-9B7A682E803E}.Release|Any CPU.Build.0 = Release|Any CPU {D98ACBEB-A039-4340-A7C5-6ED2B677268B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {D98ACBEB-A039-4340-A7C5-6ED2B677268B}.Debug|Any CPU.Build.0 = Debug|Any CPU {D98ACBEB-A039-4340-A7C5-6ED2B677268B}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -43,18 +37,6 @@ Global {F5167029-2918-46B5-A3A6-AB4A91B231A1}.Debug|Any CPU.Build.0 = Debug|Any CPU {F5167029-2918-46B5-A3A6-AB4A91B231A1}.Release|Any CPU.ActiveCfg = Release|Any CPU {F5167029-2918-46B5-A3A6-AB4A91B231A1}.Release|Any CPU.Build.0 = Release|Any CPU - {13BB08A8-8135-4630-BEAB-1F35D660B52B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {13BB08A8-8135-4630-BEAB-1F35D660B52B}.Debug|Any CPU.Build.0 = Debug|Any CPU - {13BB08A8-8135-4630-BEAB-1F35D660B52B}.Release|Any CPU.ActiveCfg = Release|Any CPU - {13BB08A8-8135-4630-BEAB-1F35D660B52B}.Release|Any CPU.Build.0 = Release|Any CPU - {39B23D3A-43D2-442F-9B83-B68ECE642FAD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {39B23D3A-43D2-442F-9B83-B68ECE642FAD}.Debug|Any CPU.Build.0 = Debug|Any CPU - {39B23D3A-43D2-442F-9B83-B68ECE642FAD}.Release|Any CPU.ActiveCfg = Release|Any CPU - {39B23D3A-43D2-442F-9B83-B68ECE642FAD}.Release|Any CPU.Build.0 = Release|Any CPU - {51023E89-6081-4BBF-8945-8F17E6C4D65C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {51023E89-6081-4BBF-8945-8F17E6C4D65C}.Debug|Any CPU.Build.0 = Debug|Any CPU - {51023E89-6081-4BBF-8945-8F17E6C4D65C}.Release|Any CPU.ActiveCfg = Release|Any CPU - {51023E89-6081-4BBF-8945-8F17E6C4D65C}.Release|Any CPU.Build.0 = Release|Any CPU {CEB9E01B-71C6-468B-8C3E-A1617F036370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {CEB9E01B-71C6-468B-8C3E-A1617F036370}.Debug|Any CPU.Build.0 = Debug|Any CPU {CEB9E01B-71C6-468B-8C3E-A1617F036370}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -63,21 +45,33 @@ Global {97D70AE4-41FB-47D0-B518-589799F093D5}.Debug|Any CPU.Build.0 = Debug|Any CPU {97D70AE4-41FB-47D0-B518-589799F093D5}.Release|Any CPU.ActiveCfg = Release|Any CPU {97D70AE4-41FB-47D0-B518-589799F093D5}.Release|Any CPU.Build.0 = Release|Any CPU - {3EA0244A-E97C-47B6-96C3-C83315674CAA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {3EA0244A-E97C-47B6-96C3-C83315674CAA}.Debug|Any CPU.Build.0 = Debug|Any CPU - {3EA0244A-E97C-47B6-96C3-C83315674CAA}.Release|Any CPU.ActiveCfg = Release|Any CPU - {3EA0244A-E97C-47B6-96C3-C83315674CAA}.Release|Any CPU.Build.0 = Release|Any CPU + {6B6A06C5-157A-4FE3-8B4C-2A1AE6A15333}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6B6A06C5-157A-4FE3-8B4C-2A1AE6A15333}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6B6A06C5-157A-4FE3-8B4C-2A1AE6A15333}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6B6A06C5-157A-4FE3-8B4C-2A1AE6A15333}.Release|Any CPU.Build.0 = Release|Any CPU + {AA699897-F692-4ED0-9865-98B6B4C713DB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AA699897-F692-4ED0-9865-98B6B4C713DB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AA699897-F692-4ED0-9865-98B6B4C713DB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AA699897-F692-4ED0-9865-98B6B4C713DB}.Release|Any CPU.Build.0 = Release|Any CPU + {7B1765CB-23F8-419A-9CC6-3DA319ED066F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7B1765CB-23F8-419A-9CC6-3DA319ED066F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7B1765CB-23F8-419A-9CC6-3DA319ED066F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7B1765CB-23F8-419A-9CC6-3DA319ED066F}.Release|Any CPU.Build.0 = Release|Any CPU + {87FCD04A-1F90-4D53-A428-CF5F5C532A22}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {87FCD04A-1F90-4D53-A428-CF5F5C532A22}.Debug|Any CPU.Build.0 = Debug|Any CPU + {87FCD04A-1F90-4D53-A428-CF5F5C532A22}.Release|Any CPU.ActiveCfg = Release|Any CPU + {87FCD04A-1F90-4D53-A428-CF5F5C532A22}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution {F5167029-2918-46B5-A3A6-AB4A91B231A1} = {3CFAC282-AFAA-4B7A-879F-D2BDF9EC631B} - {13BB08A8-8135-4630-BEAB-1F35D660B52B} = {B021A795-0151-4404-B592-0D375821890F} - {39B23D3A-43D2-442F-9B83-B68ECE642FAD} = {3CFAC282-AFAA-4B7A-879F-D2BDF9EC631B} - {51023E89-6081-4BBF-8945-8F17E6C4D65C} = {B021A795-0151-4404-B592-0D375821890F} {CEB9E01B-71C6-468B-8C3E-A1617F036370} = {B021A795-0151-4404-B592-0D375821890F} {97D70AE4-41FB-47D0-B518-589799F093D5} = {3CFAC282-AFAA-4B7A-879F-D2BDF9EC631B} - {3EA0244A-E97C-47B6-96C3-C83315674CAA} = {3CFAC282-AFAA-4B7A-879F-D2BDF9EC631B} + {6B6A06C5-157A-4FE3-8B4C-2A1AE6A15333} = {B021A795-0151-4404-B592-0D375821890F} + {AA699897-F692-4ED0-9865-98B6B4C713DB} = {3CFAC282-AFAA-4B7A-879F-D2BDF9EC631B} + {7B1765CB-23F8-419A-9CC6-3DA319ED066F} = {B021A795-0151-4404-B592-0D375821890F} + {87FCD04A-1F90-4D53-A428-CF5F5C532A22} = {B021A795-0151-4404-B592-0D375821890F} EndGlobalSection EndGlobal diff --git a/Rezoom/CS.fs b/Rezoom/CS.fs index c264546..f2310b5 100644 --- a/Rezoom/CS.fs +++ b/Rezoom/CS.fs @@ -1,25 +1,33 @@ namespace Rezoom.CS open Rezoom open System +open System.Threading open System.Threading.Tasks +open System.Runtime.CompilerServices [] type AsynchronousErrand<'a>() = inherit Errand<'a>() static member private BoxResult(task : 'a Task) = box task.Result - abstract member Prepare : ServiceContext -> 'a Task Func - override this.InternalPrepare(cxt) : unit -> obj Task = + abstract member Prepare : ServiceContext -> Func + override this.PrepareUntyped(cxt) : CancellationToken -> obj Task = let typed = this.Prepare(cxt) - fun () -> - let t = typed.Invoke() + fun token -> + let t = typed.Invoke(token) t.ContinueWith(AsynchronousErrand<'a>.BoxResult, TaskContinuationOptions.ExecuteSynchronously) [] type SynchronousErrand<'a>() = inherit Errand<'a>() - abstract member Prepare : ServiceContext -> 'a Func - override this.InternalPrepare(cxt) : unit -> obj Task = + abstract member Prepare : ServiceContext -> Func<'a> + override this.PrepareUntyped(cxt) : CancellationToken -> obj Task = let sync = this.Prepare(cxt) - fun () -> - Task.FromResult(box (sync.Invoke())) \ No newline at end of file + fun _ -> + Task.FromResult(box (sync.Invoke())) + +[] +type CSExtensions = + [] + static member ToPlan(request : Errand<'a>) = + Plan.ofErrand request diff --git a/Rezoom/CSExtensions.fs b/Rezoom/CSExtensions.fs deleted file mode 100644 index 62b1011..0000000 --- a/Rezoom/CSExtensions.fs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Rezoom -open System.Runtime.CompilerServices - -[] -type CSExtensions = - [] - static member ToPlan(request : Errand<'a>) = - Plan.ofErrand request diff --git a/Rezoom/CacheInfo.fs b/Rezoom/CacheInfo.fs new file mode 100644 index 0000000..04ba565 --- /dev/null +++ b/Rezoom/CacheInfo.fs @@ -0,0 +1,82 @@ +namespace Rezoom +open System +open System.Reflection +open System.Collections +open System.Collections.Generic + +[] +[] +[] +type BitMask(high : uint64, low : uint64) = + new(low) = BitMask(0UL, low) + member inline private __.High = high + member inline private __.Low = low + + member __.HighBits = high + member __.LowBits = low + + static member Zero = BitMask(0UL, 0UL) + static member Full = BitMask(~~~0UL, ~~~0UL) + + static member BitLength = 128 + + member __.IsZero = 0UL = (high ||| low) + member __.IsFull = ~~~0UL = (high &&& low) + member __.WithBit(bit : int, set : bool) = + if bit < 32 then + BitMask(high, if set then low ||| (1UL <<< bit) else low &&& ~~~(1UL <<< bit)) + else + let bit = bit - 32 + BitMask((if set then high ||| (1UL <<< bit) else low &&& ~~~(1UL <<< bit)), low) + + static member (&&&) (left : BitMask, right : BitMask) = + BitMask(left.High &&& right.High, left.Low &&& right.Low) + static member (|||) (left : BitMask, right : BitMask) = + BitMask(left.High ||| right.High, left.Low ||| right.Low) + static member (^^^) (left : BitMask, right : BitMask) = + BitMask(left.High ^^^ right.High, left.Low ^^^ right.Low) + static member (~~~) (bits : BitMask) = + BitMask(~~~bits.High, ~~~bits.Low) + + override __.ToString() = + high.ToString("X16") + low.ToString("X16") + member __.Equals(other : BitMask) = high = other.High && low = other.Low + override this.Equals(other : obj) = + match other with + | :? BitMask as bm -> this.Equals(bm) + | _ -> false + override __.GetHashCode() = + int high + ^^^ int (high >>> 32) + ^^^ int low + ^^^ int (low >>> 32) + interface IEquatable with + member this.Equals(other) = this.Equals(other) + +[] +type CacheInfo() = + /// A non-null comparable object which identifies the cache that this errand should use. Each category gets its + /// own isolated cache, so results associated with it can't be interfered with by errands from other categories. + /// Typically all errands defined by a library will have the same category because their identities are known + /// not to collide. A good choice for overriding this is `typeof] -type private ServiceContextCache() = - let services = new Dictionary() - let disposalStack = new Stack() - - member __.TryGetService(ty, svc : obj byref) = - services.TryGetValue(ty, &svc) - member __.CacheService(ty, svc : obj) = - services.[ty] <- svc - match svc with - | :? IDisposable as disp -> - disposalStack.Push(disp) - | _ -> () - member __.Dispose() = - let exns = new ResizeArray() - while disposalStack.Count > 0 do - try - disposalStack.Pop().Dispose() - with - | exn -> exns.Add(exn) - if exns.Count > 0 then - raise (aggregate exns) - interface IDisposable with - member this.Dispose() = this.Dispose() - -type DefaultServiceContext(factory : ServiceFactory) = - inherit ServiceContext() - let factory = new DefaultServiceFactory(factory) - let execution = new ServiceContextCache() - let sync = new obj() - let mutable step : ServiceContextCache = null - - member __.BeginStep() = - lock sync <| fun () -> - if not (isNull step) then step.Dispose() - step <- new ServiceContextCache() - member __.EndStep() = - lock sync <| fun () -> - if not (isNull step) then step.Dispose() - step <- null - - member public this.Dispose() = - lock sync <| fun () -> - this.EndStep() - execution.Dispose() - - override this.GetService<'svc>() : 'svc = - lock sync <| fun () -> - let ty = typeof<'svc> - let mutable cached : obj = null - if (not (isNull step) && step.TryGetService(ty, &cached) - || execution.TryGetService(ty, &cached)) then Unchecked.unbox cached else - let living = factory.CreateService<'svc>(this) - if isNull living then - notSupported (sprintf "The service type %O is not supported by the service factory" ty) - else - match living.Lifetime with - | ServiceLifetime.ExecutionLocal -> - execution.CacheService(ty, living.Service) - | ServiceLifetime.StepLocal -> - if isNull step then logicFault "Can't get step-local service outside of a step" - step.CacheService(ty, living.Service) - | _ -> invalidArg "lifetime" (sprintf "Unknown lifetime %d" (int living.Lifetime)) - living.Service - - interface IDisposable with - member this.Dispose() = this.Dispose() - \ No newline at end of file diff --git a/Rezoom/DefaultServiceFactory.fs b/Rezoom/DefaultServiceFactory.fs deleted file mode 100644 index 6abdd44..0000000 --- a/Rezoom/DefaultServiceFactory.fs +++ /dev/null @@ -1,38 +0,0 @@ -namespace Rezoom -open System -open System.Reflection -open System.Reflection.Emit - -type private DefaultServiceConstructor = - static member GetConstructor(ty : Type) = - if not ty.IsConstructedGenericType then null else - let tyDef = ty.GetGenericTypeDefinition() - let stepLocal = tyDef = typedefof> - let execLocal = tyDef = typedefof> - if not stepLocal && not execLocal then null else - let tag = - if stepLocal then ServiceLifetime.StepLocal - else ServiceLifetime.ExecutionLocal - let svcTy = typedefof>.MakeGenericType(ty) - let funcTy = typedefof>.MakeGenericType(svcTy) - let cons = new DynamicMethod("DynamicConstructor", svcTy, Type.EmptyTypes, true) - let il = cons.GetILGenerator() - il.Emit(OpCodes.Ldc_I4, int tag) - il.Emit(OpCodes.Newobj, ty.GetConstructor(Type.EmptyTypes)) - il.Emit(OpCodes.Newobj, svcTy.GetConstructor([|typeof; ty|])) - il.Emit(OpCodes.Ret) - cons.CreateDelegate(funcTy) - -type private DefaultServiceConstructor<'a>() = - static let constr = - downcast DefaultServiceConstructor.GetConstructor(typeof<'a>) - : Func> - static member Constructor = constr - -type private DefaultServiceFactory(userFactory : ServiceFactory) = - inherit ServiceFactory() - override __.CreateService<'svc>(cxt) = - let cons = DefaultServiceConstructor<'svc>.Constructor - if isNull cons then userFactory.CreateService<'svc>(cxt) - else cons.Invoke() - diff --git a/Rezoom/Errand.fs b/Rezoom/Errand.fs index d0b951c..622498a 100644 --- a/Rezoom/Errand.fs +++ b/Rezoom/Errand.fs @@ -1,45 +1,63 @@ namespace Rezoom +open System +open System.Collections.Generic +open System.Threading open System.Threading.Tasks +/// Base class for all errands. [] type Errand() = - abstract member Identity : obj - default __.Identity = null - abstract member DataSource : obj - default __.DataSource = null + /// Specifies caching information for the errand. This can be used for purposes other than + /// caching as well, such as logging and replaying logged/serialized results. + abstract member CacheInfo : CacheInfo + /// A comparable object (or null) that represents the dynamic part of this errand's caching identity. + /// This is separate because typically the rest of the cache info can be static for a given function that produces + /// errands, and the argument can be the only thing that varies for each individual errand. + abstract member CacheArgument : obj + default __.CacheArgument = null + /// A comparable object (or null). + /// Errands with the same non-null sequence group will not be executed concurrently with one another. abstract member SequenceGroup : obj default __.SequenceGroup = null - abstract member Idempotent : bool - default __.Idempotent = false - abstract member Mutation : bool - default __.Mutation = true - abstract member Parallelizable : bool - default __.Parallelizable = false - abstract member InternalPrepare : ServiceContext -> (unit -> obj Task) + /// Given a `ServiceContext` with which to obtain execution-local or step-local shared services, + /// adds the work this errand needs to do to a shared batch, and returns a function that can be called to + /// force execution of the entire batch and return a task that will get this errand's result. + /// Untyped version intended for internal use only. + abstract member PrepareUntyped : ServiceContext -> (CancellationToken -> obj Task) +/// An errand implements an activity that might run in batches or have a cacheable result. +/// A SQL query, an HTTP request, or an FTP operation would all be good candidates to represent as errands. +/// An `Errand<'a>` returns data of type `'a`. [] type Errand<'a>() = inherit Errand() +/// Base class for errands that retrieve their data asynchronously using a `System.Threading.Task`. [] type AsynchronousErrand<'a>() = inherit Errand<'a>() static member private BoxResult(task : 'a Task) = box task.Result - abstract member Prepare : ServiceContext -> (unit -> 'a Task) - override this.InternalPrepare(cxt) : unit -> obj Task = + /// Given a `ServiceContext` with which to obtain execution-local or step-local shared services, + /// adds the work this errand needs to do to a shared batch, and returns a function that can be called to + /// force execution of the entire batch and return a task that will get this errand's result. + abstract member Prepare : ServiceContext -> (CancellationToken -> 'a Task) + override this.PrepareUntyped(cxt) : CancellationToken -> obj Task = let typed = this.Prepare(cxt) - fun () -> - let t = typed() - t.ContinueWith(AsynchronousErrand<'a>.BoxResult, TaskContinuationOptions.ExecuteSynchronously) + fun token -> + (typed token).ContinueWith(AsynchronousErrand<'a>.BoxResult, TaskContinuationOptions.ExecuteSynchronously) +/// Base class for errands that retrieve their data synchronously (i.e. with a plain old function call). [] type SynchronousErrand<'a>() = inherit Errand<'a>() + /// Given a `ServiceContext` with which to obtain execution-local or step-local shared services, + /// adds the work this errand needs to do to a shared batch, and returns a function that can be called to + /// force execution of the entire batch and return this errand's result. abstract member Prepare : ServiceContext -> (unit -> 'a) - override this.InternalPrepare(cxt) : unit -> obj Task = + override this.PrepareUntyped(cxt) : CancellationToken -> obj Task = let sync = this.Prepare(cxt) - fun () -> + fun _ -> Task.FromResult(box (sync())) diff --git a/Rezoom/Execution.fs b/Rezoom/Execution.fs new file mode 100644 index 0000000..28ffebf --- /dev/null +++ b/Rezoom/Execution.fs @@ -0,0 +1,327 @@ +module Rezoom.Execution +open System +open System.Collections +open System.Collections.Generic +open System.Runtime.InteropServices +open System.Threading +open System.Threading.Tasks +open FSharp.Control.Tasks.ContextInsensitive +open Rezoom + +type ExecutionLog() = + abstract member OnBeginStep : unit -> unit + default __.OnBeginStep() = () + abstract member OnEndStep : unit -> unit + default __.OnEndStep() = () + abstract member OnPreparingErrand : Errand -> unit + default __.OnPreparingErrand(_) = () + abstract member OnPreparedErrand : Errand -> unit + default __.OnPreparedErrand(_) = () + +type ConsoleExecutionLog() = + inherit ExecutionLog() + let write str = + Diagnostics.Debug.WriteLine(str) + Console.WriteLine(str) + override __.OnBeginStep() = write "Step {" + override __.OnEndStep() = write "} // end step" + override __.OnPreparingErrand(errand) = + write (" Preparing errand " + string errand) + override __.OnPreparedErrand(errand) = + write (" Prepared errand " + string errand) + +type ExecutionConfig = + { Log : ExecutionLog + ServiceConfig : IServiceConfig + } + static member Default = + { Log = ExecutionLog() + ServiceConfig = { new IServiceConfig with member __.TryGetConfig() = None } + } + +type Dictionary<'k, 'v> with + member this.GetValue(key : 'k, generate : 'k -> 'v) = + let succ, v = this.TryGetValue(key) + if succ then v else + let generated = generate key + this.Add(key, generated) + generated + +[] +[] +[] +type private CacheKey(identity : obj, argument : obj) = + member inline private __.Identity = identity + member inline private __.Argument = argument + member this.Equals(other : CacheKey) = + identity = other.Identity && argument = other.Argument + override this.Equals(other : obj) = + match other with + | :? CacheKey as k -> this.Equals(k) + | _ -> false + override __.GetHashCode() = + let h1 = identity.GetHashCode() + if isNull argument then h1 else + ((h1 <<< 5) + h1) ^^^ argument.GetHashCode() + interface IEquatable with + member this.Equals(other) = this.Equals(other) + +[] +[] +[] +type private CacheValue(generation : int, value : obj) = + member __.Generation = generation + member __.Value = value + +[] +type private CategoryCache(windowSize : int, category : obj) = + let cache = Dictionary() + /// Moving window of dependency bitmasks, indexed by generation % windowSize. + let history = Array.zeroCreate windowSize : BitMask array + /// Generation of invalidations we're on. + let mutable generation = 0 + /// Pending invalidation mask. + let mutable invalidationMask = BitMask.Full + + new(category) = CategoryCache(16, category) + + member __.Category = category + + member private __.Sweep() = + if invalidationMask.IsFull then () else + let mask = invalidationMask + let latest = generation % windowSize + let mutable i = latest + let mutable sweeping = true + // Go back in time invalidating bits. We can stop when we perform a mask that has no effect, since older + // entries will necessarily have only the same or a subset of the bits of newer entries. + while sweeping && i >= 0 do + let existing = history.[i] + let updated = existing &&& mask + sweeping <- not (existing.Equals(updated)) + history.[i] <- updated + i <- i - 1 + let anySwept = sweeping || i <> latest - 1 + i <- windowSize - 1 + while sweeping && i > latest do + let existing = history.[i] + let masked = existing &&& mask + sweeping <- not (existing.Equals(mask)) + history.[i] <- masked + i <- i - 1 + invalidationMask <- BitMask.Full + if anySwept then + generation <- generation + 1 + history.[generation % windowSize] <- history.[latest] + + member this.Store(info : CacheInfo, arg : obj, result : obj) = + this.Sweep() + cache.[CacheKey(info.Identity, arg)] <- CacheValue(generation, result) + let index = generation % windowSize + history.[index] <- history.[index] ||| info.DependencyMask + + member this.Retrieve(info : CacheInfo, arg : obj) = + this.Sweep() + let mask = info.DependencyMask + // If we're not valid in the current generation, we definitely won't be valid in any older ones. + // This might save us from doing a dictionary lookup with a complex object as the key. + if not <| mask.Equals(mask &&& history.[generation % windowSize]) then None else + let succ, cached = cache.TryGetValue(CacheKey(info.Identity, arg)) + if not succ then None else + if generation - cached.Generation >= windowSize then None else + // Check that all the dependency bits are still 1 + if mask.Equals(mask &&& history.[cached.Generation % windowSize]) then Some cached.Value + else None + + member __.Invalidate(info : CacheInfo) = + invalidationMask <- invalidationMask &&& ~~~info.InvalidationMask + +type private Cache() = + let byCategory = Dictionary() + let sync = obj() + // Remember the last one touched as a shortcut. + let mutable lastCategory = CategoryCache(null) + let getExistingCategory (category : obj) = + if lastCategory.Category = category then lastCategory else + let succ, found = byCategory.TryGetValue(category) + if succ then found else null + let getCategory (category : obj) = + let existing = getExistingCategory category + if isNull existing then + let newCategory = CategoryCache(category) + lastCategory <- newCategory + byCategory.[category] <- newCategory + newCategory + else existing + member __.Invalidate(info : CacheInfo) = + match getExistingCategory info.Category with + | null -> () // nothing to invalidate + | cat -> cat.Invalidate(info) + member __.Retrieve(info : CacheInfo, arg : obj) = + match getExistingCategory info.Category with + | null -> None + | cat -> cat.Retrieve(info, arg) + member __.Store(info : CacheInfo, arg : obj, result : obj) = + lock sync <| fun unit -> // only stores run asynchronously and might need to be thread-safe + let cat = getCategory info.Category + cat.Store(info, arg, result) + +type private Step(log : ExecutionLog, context : ServiceContext, cache : Cache) = + static let defaultGroup _ = ResizeArray() + static let retrievalDeferred () = RetrievalDeferred + let ungrouped = ResizeArray() + let grouped = Dictionary() + let deduped = Dictionary() + let anyCached = ref false + let pending = ResizeArray() + let run (errand : Errand) = + if !anyCached then retrievalDeferred else + try + let mutable result = Unchecked.defaultof<_> + log.OnPreparingErrand(errand) + let prepared = errand.PrepareUntyped(context) + log.OnPreparedErrand(errand) + let retrieve token = + cache.Invalidate(errand.CacheInfo) + task { + try + let! obj = prepared token + result <- RetrievalSuccess obj + cache.Store(errand.CacheInfo, errand.CacheArgument, obj) + with + | exn -> + result <- RetrievalException exn + } + let sequenceGroup = errand.SequenceGroup + match errand.SequenceGroup with + | null -> + ungrouped.Add(retrieve) + | sequenceGroup -> + let group = grouped.GetValue(sequenceGroup, defaultGroup) + group.Add(retrieve) + fun () -> result + with + | exn -> fun () -> RetrievalException exn + + let addToRun (errand : Errand) = + let ran = lazy run errand + pending.Add(ran) + fun () -> ran.Value() + + let addWithDedup (errand : Errand) = + let cacheInfo = errand.CacheInfo + let dedupKey = cacheInfo.Category, cacheInfo.Identity, errand.CacheArgument + let succ, already = deduped.TryGetValue(dedupKey) + if succ then already else + let added = addToRun errand + deduped.Add(dedupKey, added) + added + + member __.AddRequest(errand : Errand) = + let cacheInfo = errand.CacheInfo + if cacheInfo.Cacheable then + match cache.Retrieve(cacheInfo, errand.CacheArgument) with + | None -> + addWithDedup errand + | Some cached -> + anyCached := true + fun () -> RetrievalSuccess cached + else + addToRun errand + + member __.Execute(token) = + for i = 0 to pending.Count - 1 do ignore <| pending.[i].Force() + let taskCount = ungrouped.Count + grouped.Count + if taskCount <= 0 then Task.CompletedTask else + let all = Array.zeroCreate taskCount + let mutable i = 0 + for group in grouped.Values do + all.[i] <- + task { + for sub in group do + do! sub token + } :> Task + i <- i + 1 + for ungrouped in ungrouped do + all.[i] <- upcast ungrouped token + i <- i + 1 + Task.WhenAll(all) + +type private ExecutionServiceContext(config : IServiceConfig) = + inherit ServiceContext() + let services = Dictionary() + let locals = Stack<_>() + let globals = Stack<_>() + let mutable totalSuccess = false + override __.Configuration = config + override this.GetService<'f, 'a when 'f :> ServiceFactory<'a> and 'f : (new : unit -> 'f)>() = + let ty = typeof<'f> + let succ, service = services.TryGetValue(ty) + if succ then Unchecked.unbox service else + let factory = new 'f() + let service = factory.CreateService(this) + let stack = + match factory.ServiceLifetime with + | ServiceLifetime.ExecutionLocal -> globals + | ServiceLifetime.StepLocal -> locals + | other -> failwithf "Unknown service lifetime: %O" other + services.Add(ty, box service) + stack.Push(fun state -> + factory.DisposeService(state, service) + ignore <| services.Remove(ty)) + service + static member private ClearStack(stack : _ Stack, state) = + let mutable exn = null + while stack.Count > 0 do + let disposer = stack.Pop() + try + disposer state + with + | e -> + if isNull exn then exn <- e + else exn <- AggregateException(exn, e) + if not (isNull exn) then raise exn + member __.ClearLocals(state) = ExecutionServiceContext.ClearStack(locals, state) + member __.SetSuccessful() = totalSuccess <- true + member this.Dispose() = + let state = if totalSuccess then ExecutionSuccess else ExecutionFault + try + this.ClearLocals(state) + finally + ExecutionServiceContext.ClearStack(globals, state) + interface IDisposable with + member this.Dispose() = this.Dispose() + +let executeWithCancellation (token : CancellationToken) (config : ExecutionConfig) (plan : 'a Plan) = + task { + let log = config.Log + let cache = Cache() + use context = new ExecutionServiceContext(config.ServiceConfig) + let mutable planState = plan() + let mutable looping = true + let mutable returned = Unchecked.defaultof<_> + while looping do + match planState with + | Result r -> + looping <- false + returned <- r + | Step (requests, resume) -> + log.OnBeginStep() + let mutable stepState = ExecutionFault + try + let step = Step(log, context, cache) + let retrievals = requests.Map(step.AddRequest) + do! step.Execute(token).ConfigureAwait(continueOnCapturedContext = true) + planState <- resume <| retrievals.Map((|>) ()) + stepState <- ExecutionSuccess + finally + context.ClearLocals(stepState) + log.OnEndStep() + context.SetSuccessful() // if we got this far we can dispose with success (commit) + return returned + } + +let execute (config : ExecutionConfig) (plan : 'a Plan) = + let token = CancellationToken() + executeWithCancellation token config plan + \ No newline at end of file diff --git a/Rezoom/Plan.fs b/Rezoom/Plan.fs index c76bfbe..66555f5 100644 --- a/Rezoom/Plan.fs +++ b/Rezoom/Plan.fs @@ -1,20 +1,18 @@ namespace Rezoom type DataResponse = - | RetrievalSuccess of obj - | RetrievalException of exn + /// The errand ran and produced a result. + | RetrievalSuccess of result : obj + /// The errand failed with an exception. + | RetrievalException of exn : exn + /// The errand has not yet been run. + | RetrievalDeferred type Batch<'a> = | BatchLeaf of 'a | BatchPair of ('a Batch * 'a Batch) | BatchMany of ('a Batch array) | BatchAbort - member this.MapCS(f : System.Func<'a, 'b>) = - match this with - | BatchLeaf x -> BatchLeaf (f.Invoke(x)) - | BatchPair (l, r) -> BatchPair (l.MapCS(f), r.MapCS(f)) - | BatchMany arr -> BatchMany (arr |> Array.map (fun b -> b.MapCS(f))) - | BatchAbort -> BatchAbort member this.Map(f : 'a -> 'b) = match this with | BatchLeaf x -> BatchLeaf (f x) @@ -25,11 +23,13 @@ type Batch<'a> = type Requests = Errand Batch type Responses = DataResponse Batch -type Step<'result> = Requests * (Responses -> Plan<'result>) -and Plan<'result> = +type Step<'result> = Requests * (Responses -> PlanState<'result>) +and PlanState<'result> = | Result of 'result | Step of Step<'result> +type Plan<'result> = unit -> PlanState<'result> + /// Hint that it is OK to batch the given sequence or task type BatchHint<'a> = internal | BatchHint of 'a diff --git a/Rezoom/PlanBuilder.fs b/Rezoom/PlanBuilder.fs index 24b8498..ecb5e04 100644 --- a/Rezoom/PlanBuilder.fs +++ b/Rezoom/PlanBuilder.fs @@ -8,36 +8,33 @@ open System.Threading.Tasks type PlanBuilder() = member inline __.Zero() : unit Plan = zero - member inline __.Return(value) = ret value + member inline __.Return(value : 'a) : 'a Plan = ret value - member inline __.ReturnFrom(task : _ Plan) = task - member inline __.ReturnFrom(task : _ Step) = task + member inline __.ReturnFrom(plan : 'a Plan) : 'a Plan = plan - member inline __.Bind(task, cont) = bind task cont + member inline __.Combine(plan : unit Plan, cont : 'b Plan) : 'b Plan = combine plan cont + member inline __.Bind(plan : 'a Plan, cont : 'a -> 'b Plan) : 'b Plan = bind plan cont member inline __.Bind((a, b), cont) = bind (tuple2 a b) cont member inline __.Bind((a, b, c), cont) = bind (tuple3 a b c) cont member inline __.Bind((a, b, c, d), cont) = bind (tuple4 a b c d) cont - member inline __.Delay(task : unit -> _ Plan) = task + member inline __.Delay(delayed : unit -> 'a Plan) : 'a Plan = fun () -> delayed () () + member inline __.Run(plan : 'a Plan) : 'a Plan = plan - member inline __.Run(task : unit -> _ Plan) = task() - - member inline __.For(sequence : #seq<'a>, iteration : 'a -> unit Plan) = + member inline __.For(sequence : #seq<'a>, iteration : 'a -> unit Plan) : unit Plan = forM sequence iteration - member __.For(BatchHint (sequence : #seq<'a>), iteration : 'a -> unit Plan) = + member __.For(BatchHint (sequence : #seq<'a>), iteration : 'a -> unit Plan) : unit Plan = forA sequence iteration - member inline __.Using(disposable : #IDisposable, body) = + member inline __.Using(disposable : #IDisposable, body : #IDisposable -> 'a Plan) : 'a Plan = let dispose () = match disposable with | null -> () | d -> d.Dispose() - tryFinally (fun () -> body disposable) dispose - - member inline __.TryFinally(task, onExit) = tryFinally task onExit - member inline __.TryWith(task, onExit) = tryCatch task onExit + tryFinally (fun () -> body disposable ()) dispose - member inline __.Combine(task, cont) = bind task (fun _ -> cont()) + member inline __.TryFinally(body : 'a Plan, onExit : unit -> unit) : 'a Plan = tryFinally body onExit + member inline __.TryWith(body : 'a Plan, onExn : exn -> 'a Plan) : 'a Plan = tryCatch body onExn let plan = new PlanBuilder() diff --git a/Rezoom/PlanModule.fs b/Rezoom/PlanModule.fs index 6181d54..15a5dca 100644 --- a/Rezoom/PlanModule.fs +++ b/Rezoom/PlanModule.fs @@ -21,8 +21,8 @@ let internal abortSteps (steps : 'a Step seq) (reason : exn) : 'b = if exns.Count > 1 then raise (aggregate exns) else dispatchRaise reason -let internal abortTask (task : 'a Plan) (reason : exn) : 'b = - match task with +let internal abortTask (state : 'a PlanState) (reason : exn) : 'b = + match state with | Step (_, resume) -> try ignore <| resume BatchAbort @@ -40,7 +40,7 @@ let internal abortTask (task : 'a Plan) (reason : exn) : 'b = /// Monadic return for `Plan`s. /// Creates a `Plan` with no steps, whose immediate result is `result`. -let inline ret (result : 'a) = Result result +let ret (result : 'a) : Plan<'a> = fun () -> Result result /// Monoidal identity for `Plan`. /// Equivalent to `ret ()`. @@ -48,14 +48,17 @@ let zero = ret () /// Convert an `Errand<'a>` to a `Plan<'a>`. let ofErrand (request : Errand<'a>) : Plan<'a> = - let onResponse = + let rec onResponse = function - | BatchLeaf (RetrievalSuccess suc) -> ret (Unchecked.unbox suc : 'a) + | BatchLeaf RetrievalDeferred -> step + | BatchLeaf (RetrievalSuccess suc) -> Result (Unchecked.unbox suc : 'a) | BatchLeaf (RetrievalException exn) -> dispatchRaise exn | BatchAbort -> abort() | BatchPair _ | BatchMany _ -> logicFault "Incorrect response shape for data request" - Step (BatchLeaf (request :> Errand), onResponse) + and step : PlanState<'a> = + Step (BatchLeaf (request :> Errand), onResponse) + fun () -> step //////////////////////////////////////////////////////////// // Mapping of plain-old functions over `Plan`s. @@ -63,15 +66,15 @@ let ofErrand (request : Errand<'a>) : Plan<'a> = // This lets you transform the eventual values produced by the task. //////////////////////////////////////////////////////////// -let inline _mapInline map f task = - match task with +let inline _mapInline map f plan = + match plan with | Result r -> Result (f r) | Step s -> Step (map f s) let rec _mapRecursive (f : 'a -> 'b) ((pending, resume) : 'a Step) : 'b Step = pending, fun responses -> _mapInline _mapRecursive f (resume responses) /// Map a function over the result of a `Plan<'a>`, producing a new `Plan<'b>`. -and inline map (f : 'a -> 'b) (task : 'a Plan): 'b Plan = - _mapInline _mapRecursive f task +and inline map (f : 'a -> 'b) (plan : 'a Plan): 'b Plan = + fun () -> _mapInline _mapRecursive f (plan()) //////////////////////////////////////////////////////////// // Monadic `bind`. @@ -80,18 +83,31 @@ and inline map (f : 'a -> 'b) (task : 'a Plan): 'b Plan = // first task is necessary to decide what to do as the next task. //////////////////////////////////////////////////////////// -let inline _bindInline bind task cont = - match task with - | Result r -> cont r +let inline _bindInline bind plan cont = + match plan with + | Result r -> cont r () | Step (pending, resume) -> Step (pending, fun responses -> bind (resume responses) cont) -let rec _bindRecursive task cont = - _bindInline _bindRecursive task cont +let rec _bindRecursive plan cont = + _bindInline _bindRecursive plan cont /// Chain a continuation `Plan` onto an existing `Plan` to /// get a new `Plan`. /// The continuation can be dependent on the result of the first task. -let inline bind (task : 'a Plan) (cont : 'a -> 'b Plan) : 'b Plan = - _bindInline (_bindInline _bindRecursive) task cont +let inline bind (plan : 'a Plan) (cont : 'a -> 'b Plan) : 'b Plan = + fun () -> _bindInline (_bindInline _bindRecursive) (plan()) cont + +let inline _combineInline bind plan cont = + match plan with + | Result _ -> cont () + | Step (pending, resume) -> + Step (pending, fun responses -> bind (resume responses) cont) +let rec _combineRecursive plan cont = + _combineInline _combineRecursive plan cont +/// Chain a continuation `Plan` onto an existing `Plan` to +/// get a new `Plan`. +/// The continuation can be dependent on the result of the first task. +let inline combine (plan : 'a Plan) (cont : 'b Plan) : 'b Plan = + fun () -> _combineInline (_combineInline _combineRecursive) (plan()) cont //////////////////////////////////////////////////////////// // Applicative functor `apply`. @@ -101,49 +117,56 @@ let inline bind (task : 'a Plan) (cont : 'a -> 'b Plan) : 'b Plan = // concurrently and share batchable resources. //////////////////////////////////////////////////////////// -/// Create a task that will eventually apply the function produced by -/// `taskF` to the value produced by `taskA` to obtain its result. -/// The two tasks are independent, so they will execute concurrently and -/// share batchable resources. -let rec apply (taskF : Plan<'a -> 'b>) (taskA : Plan<'a>) : Plan<'b> = +let inline private next2 taskF taskA proceed = + let mutable exnF : exn = null + let mutable exnA : exn = null + let mutable resF : PlanState<'a -> 'b> = Unchecked.defaultof<_> + let mutable resA : PlanState<'a> = Unchecked.defaultof<_> + try + resF <- taskF() + with + | exn -> + exnF <- exn + try + resA <- taskA() + with + | exn -> exnA <- exn + if isNull exnF && isNull exnA then + proceed resF resA + else if not (isNull exnF) && not (isNull exnA) then + raise (new AggregateException(exnF, exnA)) + else if isNull exnF then + abortTask resF exnA + else + abortTask resA exnF + +let rec private applyState (taskF : PlanState<'a -> 'b>) (taskA : PlanState<'a>) : PlanState<'b> = match taskF, taskA with | Result f, Result a -> Result (f a) | Result f, step -> - map ((<|) f) step + _mapInline _mapRecursive ((<|) f) step | step, Result a -> - map ((|>) a) step + _mapInline _mapRecursive ((|>) a) step | Step (pendingF, resumeF), Step (pendingA, resumeA) -> + let pending = BatchPair (pendingF, pendingA) let onResponses = function | BatchPair (rspF, rspA) -> - let mutable exnF : exn = null - let mutable exnA : exn = null - let mutable resF : Plan<'a -> 'b> = Unchecked.defaultof<_> - let mutable resA : Plan<'a> = Unchecked.defaultof<_> - try - resF <- resumeF rspF - with - | exn -> - exnF <- exn - try - resA <- resumeA rspA - with - | exn -> exnA <- exn - if isNull exnF && isNull exnA then - apply resF resA - else if not (isNull exnF) && not (isNull exnA) then - raise (new AggregateException(exnF, exnA)) - else if isNull exnF then - abortTask resF exnA - else - abortTask resA exnF + next2 (fun () -> resumeF rspF) (fun () -> resumeA rspA) applyState | BatchAbort -> abort() | BatchLeaf _ | BatchMany _ -> logicFault "Incorrect response shape for applied pair" Step (pending, onResponses) +/// Create a task that will eventually apply the function produced by +/// `taskF` to the value produced by `taskA` to obtain its result. +/// The two tasks are independent, so they will execute concurrently and +/// share batchable resources. +let apply (taskF : Plan<'a -> 'b>) (taskA : Plan<'a>) : Plan<'b> = + fun () -> next2 taskF taskA applyState + /// Create a task that runs `taskA` and `taskB` concurrently and combines their results into a tuple. let tuple2 (taskA : 'a Plan) (taskB : 'b Plan) : ('a * 'b) Plan = apply @@ -183,46 +206,48 @@ let tuple4 /// during execution of `wrapped`, whether it's in creating the `Plan` /// to be run or in executing any step of the resulting task. /// The exception handler may rethrow the exception. -let rec tryCatch (wrapped : unit -> 'a Plan) (catcher : exn -> 'a Plan) = - try - match wrapped() with - | Result _ as result -> result - | Step (pending, resume) -> - let onResponses (responses : Responses) = - tryCatch (fun () -> resume responses) catcher - Step (pending, onResponses) - with - | PlanAbortException _ -> reraise() // don't let them catch these - | ex -> catcher(ex) - -/// Wrap a `Plan<'a>` with a block that must execute. -/// When the task is executed, the function `onExit` will be called -/// after `wrapped` completes, regardless of whether the task -/// succeeded, failed to be created, or failed while partially executed. -let rec tryFinally (wrapped : unit -> 'a Plan) (onExit : unit -> unit) = - let mutable cleanExit = false - let task = +let rec tryCatch (wrapped : 'a Plan) (catcher : exn -> 'a Plan) : 'a Plan = + fun () -> try match wrapped() with - | Result _ as result -> - cleanExit <- true - result + | Result _ as result -> result | Step (pending, resume) -> let onResponses (responses : Responses) = - tryFinally (fun () -> resume responses) onExit + tryCatch (fun () -> resume responses) catcher () Step (pending, onResponses) with - | ex -> + | PlanAbortException _ -> reraise() // don't let them catch these + | ex -> catcher ex () + +/// Wrap a `Plan<'a>` with a block that must execute. +/// When the task is executed, the function `onExit` will be called +/// after `wrapped` completes, regardless of whether the task +/// succeeded, failed to be created, or failed while partially executed. +let rec tryFinally (wrapped : 'a Plan) (onExit : unit -> unit) : 'a Plan = + fun () -> + let mutable cleanExit = false + let task = try - onExit() + match wrapped() with + | Result _ as result -> + cleanExit <- true + result + | Step (pending, resume) -> + let onResponses (responses : Responses) = + tryFinally (fun () -> resume responses) onExit () + Step (pending, onResponses) with - | inner -> - raise (aggregate [|ex; inner|]) - reraise() - if cleanExit then - // run outside of the try/catch so we don't risk recursion - onExit() - task + | ex -> + try + onExit() + with + | inner -> + raise (aggregate [|ex; inner|]) + reraise() + if cleanExit then + // run outside of the try/catch so we don't risk recursion + onExit() + task //////////////////////////////////////////////////////////// // Looping. @@ -242,45 +267,48 @@ let rec private forIterator (enumerator : 'a IEnumerator) (iteration : 'a -> uni /// Monadic iteration. /// Create a task that lazily iterates a sequence, executing `iteration` for each element. -let forM (sequence : 'a seq) (iteration : 'a -> unit Plan) = - let enumerator = sequence.GetEnumerator() - tryFinally - (fun () -> forIterator enumerator iteration) - (fun () -> enumerator.Dispose()) +let forM (sequence : 'a seq) (iteration : 'a -> unit Plan) : unit Plan = + fun () -> + let enumerator = sequence.GetEnumerator() + tryFinally + (fun () -> forIterator enumerator iteration ()) + (fun () -> enumerator.Dispose()) + () -let rec private forAs (tasks : (unit -> unit Plan) seq) : unit Plan = - let steps = - let steps = new ResizeArray<_>() - let exns = new ResizeArray<_>() - for task in tasks do - try - match task() with - | Step step -> steps.Add(step) - | Result _ -> () - with - | exn -> exns.Add(exn) - if exns.Count > 0 then abortSteps steps (aggregate exns) - else steps - if steps.Count <= 0 then zero - else - let pending = - let arr = Array.zeroCreate steps.Count - for i = 0 to steps.Count - 1 do - arr.[i] <- fst steps.[i] - BatchMany arr - let onResponses = - function - | BatchMany responses -> - responses - |> Seq.mapi (fun i rsp () -> snd steps.[i] rsp) - |> forAs - | BatchAbort -> abort() - | BatchPair _ - | BatchLeaf _ -> logicFault "Incorrect response shape for applicative batch" - Step (pending, onResponses) +let rec private forAs (tasks : (unit Plan) seq) : unit Plan = + fun () -> + let steps = + let steps = new ResizeArray<_>() + let exns = new ResizeArray<_>() + for task in tasks do + try + match task() with + | Step step -> steps.Add(step) + | Result _ -> () + with + | exn -> exns.Add(exn) + if exns.Count > 0 then abortSteps steps (aggregate exns) + else steps + if steps.Count <= 0 then Result () + else + let pending = + let arr = Array.zeroCreate steps.Count + for i = 0 to steps.Count - 1 do + arr.[i] <- fst steps.[i] + BatchMany arr + let onResponses = + function + | BatchMany responses -> + responses + |> Seq.mapi (fun i rsp () -> snd steps.[i] rsp) + |> forAs <| () + | BatchAbort -> abort() + | BatchPair _ + | BatchLeaf _ -> logicFault "Incorrect response shape for applicative batch" + Step (pending, onResponses) /// Applicative iteration. /// Create a task that strictly iterates a sequence, creating a `Plan` for each element /// using the given `iteration` function, then runs those tasks concurrently. -let forA (sequence : 'a seq) (iteration : 'a -> unit Plan) = - forAs (sequence |> Seq.map (fun element () -> iteration element)) +let forA (sequence : 'a seq) (iteration : 'a -> unit Plan) : unit Plan = + forAs (sequence |> Seq.map (fun element -> iteration element)) diff --git a/Rezoom/ResponseCache.fs b/Rezoom/ResponseCache.fs deleted file mode 100644 index ab33d36..0000000 --- a/Rezoom/ResponseCache.fs +++ /dev/null @@ -1,33 +0,0 @@ -namespace Rezoom -open System.Collections.Generic - -type private Response = obj -type private DataSource = obj -type private Identity = obj - -type ResponseCache() = - let nullDataSource = new Dictionary() - let byDataSource = new Dictionary>() - - member __.Invalidate(dataSource : DataSource) = - if isNull dataSource then nullDataSource.Clear() - else ignore <| byDataSource.Remove(dataSource) - - member __.Store(dataSource : DataSource, identity : Identity, value : Response) = - if isNull dataSource then - nullDataSource.[identity] <- value - else - let mutable subCache : Dictionary = null - if not <| byDataSource.TryGetValue(dataSource, &subCache) then - subCache <- new _() - byDataSource.[dataSource] <- subCache - subCache.[identity] <- value - - member __.TryGetValue(dataSource : DataSource, identity : Identity, value : Response byref) = - if isNull dataSource then - nullDataSource.TryGetValue(identity, &value) - else - let mutable subCache : Dictionary = null - if byDataSource.TryGetValue(dataSource, &subCache) then - subCache.TryGetValue(identity, &value) - else false \ No newline at end of file diff --git a/Rezoom/Rezoom.fsproj b/Rezoom/Rezoom.fsproj index 652b550..fb92816 100644 --- a/Rezoom/Rezoom.fsproj +++ b/Rezoom/Rezoom.fsproj @@ -46,17 +46,17 @@ - - - - - + + + + TaskBuilder.fs + + - 11 diff --git a/Rezoom/ServiceFactories.fs b/Rezoom/ServiceFactories.fs deleted file mode 100644 index af99181..0000000 --- a/Rezoom/ServiceFactories.fs +++ /dev/null @@ -1,27 +0,0 @@ -namespace Rezoom -open System -open System.Reflection -open System.Reflection.Emit - -type ZeroServiceFactory() = - inherit ServiceFactory() - override __.CreateService(_) = null - -[] -type SingleServiceFactory<'a>() = - inherit ServiceFactory() - abstract Create : ServiceContext -> 'a LivingService - override this.CreateService<'svc>(cxt) = - if obj.ReferenceEquals(typeof<'a>, typeof<'svc>) then - this.Create(cxt) |> box |> Unchecked.unbox - : 'svc LivingService - else null - -type CoalescingServiceFactory - (main : ServiceFactory, fallback : ServiceFactory) = - inherit ServiceFactory() - override __.CreateService<'svc>(cxt) = - let main = main.CreateService<'svc>(cxt) - if isNull main then - fallback.CreateService<'svc>(cxt) - else main diff --git a/Rezoom/ServiceFactory.fs b/Rezoom/ServiceFactory.fs deleted file mode 100644 index 65f07c9..0000000 --- a/Rezoom/ServiceFactory.fs +++ /dev/null @@ -1,29 +0,0 @@ -namespace Rezoom - -type ServiceLifetime = - | ExecutionLocal = 1 - | StepLocal = 2 - -[] -type LivingService<'svc> = - class - val public Lifetime : ServiceLifetime - val public Service : 'svc - new (life, svc) = { Lifetime = life; Service = svc } - end - -[] -type ServiceContext() = - abstract member GetService<'svc> : unit -> 'svc - -[] -type ServiceFactory() = - abstract member CreateService : ServiceContext -> LivingService<'svc> - -type StepLocal<'a when 'a : (new : unit -> 'a)>() = - let a = new 'a() - member __.Service = a - -type ExecutionLocal<'a when 'a : (new : unit -> 'a)>() = - let a = new 'a() - member __.Service = a diff --git a/Rezoom/Services.fs b/Rezoom/Services.fs new file mode 100644 index 0000000..3837a7d --- /dev/null +++ b/Rezoom/Services.fs @@ -0,0 +1,54 @@ +namespace Rezoom +open System +open System.Collections.Generic + +type ServiceLifetime = + | ExecutionLocal = 1 + | StepLocal = 2 + +type ExecutionState = + | ExecutionFault + | ExecutionSuccess + +type IServiceConfig = + abstract member TryGetConfig<'cfg> : unit -> 'cfg option + +type ServiceConfig() = + let configs = Dictionary() + member this.SetConfiguration(cfg : 'cfg) = + let ty = typeof<'cfg> + configs.[ty] <- box cfg + this + interface IServiceConfig with + member __.TryGetConfig<'cfg>() = + let ty = typeof<'cfg> + let succ, config = configs.TryGetValue(ty) + if succ then Some (Unchecked.unbox config : 'cfg) + else None + +[] +type ServiceFactory<'a>() = + abstract member CreateService : ServiceContext -> 'a + abstract member DisposeService : ExecutionState * 'a -> unit + abstract member ServiceLifetime : ServiceLifetime +and [] ServiceContext() = + abstract member Configuration : IServiceConfig + abstract member GetService<'f, 'a when 'f :> ServiceFactory<'a> and 'f : (new : unit -> 'f)> : unit -> 'a + +type StepLocal<'a when 'a : (new : unit -> 'a)>() = + inherit ServiceFactory<'a>() + override __.ServiceLifetime = ServiceLifetime.StepLocal + override __.CreateService(_) = new 'a() + override __.DisposeService(_, s) = + match box s with + | :? IDisposable as d -> d.Dispose() + | _ -> () + +type ExecutionLocal<'a when 'a : (new : unit -> 'a)>() = + inherit ServiceFactory<'a>() + override __.ServiceLifetime = ServiceLifetime.ExecutionLocal + override __.CreateService(_) = new 'a() + override __.DisposeService(_, s) = + match box s with + | :? IDisposable as d -> d.Dispose() + | _ -> () \ No newline at end of file diff --git a/Shared/TaskBuilder.fs b/Shared/TaskBuilder.fs new file mode 100644 index 0000000..c8eef2e --- /dev/null +++ b/Shared/TaskBuilder.fs @@ -0,0 +1,258 @@ +// TaskBuilder.fs - TPL task computation expressions for F# +// +// Written in 2016 by Robert Peele (humbobst@gmail.com) +// +// To the extent possible under law, the author(s) have dedicated all copyright and related and neighboring rights +// to this software to the public domain worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication along with this software. +// If not, see . + +namespace FSharp.Control.Tasks +open System +open System.Threading.Tasks +open System.Runtime.CompilerServices + +// This module is not really obsolete, but it's not intended to be referenced directly from user code. +// However, it can't be private because it is used within inline functions that *are* user-visible. +// Marking it as obsolete is a workaround to hide it from auto-completion tools. +[] +module TaskBuilder = + /// Represents the state of a computation: + /// either awaiting something with a continuation, + /// or completed with a return value. + type Step<'a> = + | Await of ICriticalNotifyCompletion * (unit -> Step<'a>) + | Return of 'a + /// Implements the machinery of running a `Step<'m, 'm>` as a `Task<'m>`. + and StepStateMachine<'a>(firstStep) as this = + let methodBuilder = AsyncTaskMethodBuilder<'a>() + /// The continuation we left off awaiting on our last MoveNext(). + let mutable continuation = fun () -> firstStep + /// Return true if we should call `AwaitOnCompleted` on the current awaitable. + let nextAwaitable() = + try + match continuation() with + | Return r -> + methodBuilder.SetResult(r) + null + | Await (await, next) -> + continuation <- next + await + with + | exn -> + methodBuilder.SetException(exn) + null + let mutable self = this + + /// Start execution as a `Task<'m>`. + member __.Run() = + methodBuilder.Start(&self) + methodBuilder.Task + + interface IAsyncStateMachine with + /// Proceed to one of three states: result, failure, or awaiting. + /// If awaiting, MoveNext() will be called again when the awaitable completes. + member __.MoveNext() = + let mutable await = nextAwaitable() + if not (isNull await) then + // Tell the builder to call us again when this thing is done. + methodBuilder.AwaitUnsafeOnCompleted(&await, &self) + member __.SetStateMachine(_) = () // Doesn't really apply since we're a reference type. + + /// Used to represent no-ops like the implicit empty "else" branch of an "if" expression. + let inline zero() = Return () + + /// Used to return a value. + let inline ret (x : 'a) = Return x + + // The following flavors of `bind` are for sequencing tasks with the continuations + // that should run following them. They all follow pretty much the same formula. + + let inline bindTask (task : 'a Task) (continuation : 'a -> Step<'b>) = + let taskAwaiter = task.GetAwaiter() + if taskAwaiter.IsCompleted then // Proceed to the next step based on the result we already have. + taskAwaiter.GetResult() |> continuation + else // Await and continue later when a result is available. + Await (taskAwaiter, (fun () -> taskAwaiter.GetResult() |> continuation)) + + let inline bindVoidTask (task : Task) (continuation : unit -> Step<'b>) = + let taskAwaiter = task.GetAwaiter() + if taskAwaiter.IsCompleted then continuation() else + Await (taskAwaiter, continuation) + + let inline bindConfiguredTask (task : 'a ConfiguredTaskAwaitable) (continuation : 'a -> Step<'b>) = + let taskAwaiter = task.GetAwaiter() + if taskAwaiter.IsCompleted then + taskAwaiter.GetResult() |> continuation + else + Await (taskAwaiter, (fun () -> taskAwaiter.GetResult() |> continuation)) + + let inline bindVoidConfiguredTask (task : ConfiguredTaskAwaitable) (continuation : unit -> Step<'b>) = + let taskAwaiter = task.GetAwaiter() + if taskAwaiter.IsCompleted then continuation() else + Await (taskAwaiter, continuation) + + let inline + bindGenericAwaitable< ^a, ^b, ^c when ^a : (member GetAwaiter : unit -> ^b) and ^b :> ICriticalNotifyCompletion > + (awt : ^a) (continuation : unit -> Step< ^c >) = + let taskAwaiter = (^a : (member GetAwaiter : unit -> ^b)(awt)) + Await (taskAwaiter, continuation) + + /// Chains together a step with its following step. + /// Note that this requires that the first step has no result. + /// This prevents constructs like `task { return 1; return 2; }`. + let rec combine (step : Step) (continuation : unit -> Step<'b>) = + match step with + | Return _ -> continuation() + | Await (awaitable, next) -> + Await(awaitable, fun () -> combine (next()) continuation) + + /// Builds a step that executes the body while the condition predicate is true. + let inline whileLoop (cond : unit -> bool) (body : unit -> Step) = + if cond() then + // Create a self-referencing closure to test whether to repeat the loop on future iterations. + let rec repeat () = + if cond() then combine (body()) repeat + else zero() + // Run the body the first time and chain it to the repeat logic. + combine (body()) repeat + else zero() + + /// Wraps a step in a try/with. This catches exceptions both in the evaluation of the function + /// to retrieve the step, and in the continuation of the step (if any). + let rec tryWith(step : unit -> Step<'a>) (catch : exn -> Step<'a>) = + try + match step() with + | Return _ as i -> i + | Await (awaitable, next) -> Await (awaitable, fun () -> tryWith next catch) + with + | exn -> catch exn + + /// Wraps a step in a try/finally. This catches exceptions both in the evaluation of the function + /// to retrieve the step, and in the continuation of the step (if any). + let rec tryFinally (step : unit -> Step<'a>) fin = + let step = + try step() + // Important point: we use a try/with, not a try/finally, to implement tryFinally. + // The reason for this is that if we're just building a continuation, we definitely *shouldn't* + // execute the `fin()` part yet -- the actual execution of the asynchronous code hasn't completed! + with + | _ -> + fin() + reraise() + match step with + | Return _ as i -> + fin() + i + | Await (awaitable, next) -> + Await (awaitable, fun () -> tryFinally next fin) + + /// Implements a using statement that disposes `disp` after `body` has completed. + let inline using (disp : #IDisposable) (body : _ -> Step<'a>) = + // A using statement is just a try/finally with the finally block disposing if non-null. + tryFinally + (fun () -> body disp) + (fun () -> if not (isNull (box disp)) then disp.Dispose()) + + /// Implements a loop that runs `body` for each element in `sequence`. + let forLoop (sequence : 'a seq) (body : 'a -> Step) = + // A for loop is just a using statement on the sequence's enumerator... + using (sequence.GetEnumerator()) + // ... and its body is a while loop that advances the enumerator and runs the body on each element. + (fun e -> whileLoop e.MoveNext (fun () -> body e.Current)) + + /// Runs a step as a task -- with a short-circuit for immediately completed steps. + let inline run (firstStep : unit -> Step<'a>) = + try + match firstStep() with + | Return x -> Task.FromResult(x) + | Await _ as step -> + StepStateMachine<'a>(step).Run() + // Any exceptions should go on the task, rather than being thrown from this call. + // This matches C# behavior where you won't see an exception until awaiting the task, + // even if it failed before reaching the first "await". + with + | exn -> Task.FromException<_>(exn) + + type UnitTask = + struct + val public Task : Task + new(task) = { Task = task } + end + + type TaskBuilder() = + // These methods are consistent between the two builders. + // Unfortunately, inline members do not work with inheritance. + member inline __.Delay(f : unit -> Step<_>) = f + member inline __.Run(f : unit -> Step<'m>) = run f + member inline __.Zero() = zero() + member inline __.Return(x) = ret x + member inline __.ReturnFrom(task) = bindConfiguredTask task ret + member inline __.ReturnFrom(task) = bindVoidConfiguredTask task ret + member inline __.ReturnFrom(yld : YieldAwaitable) = bindGenericAwaitable yld ret + member inline __.Combine(step, continuation) = combine step continuation + member inline __.Bind(task, continuation) = bindConfiguredTask task continuation + member inline __.Bind(task, continuation) = bindVoidConfiguredTask task continuation + member inline __.Bind(yld : YieldAwaitable, continuation) = bindGenericAwaitable yld continuation + member inline __.While(condition, body) = whileLoop condition body + member inline __.For(sequence, body) = forLoop sequence body + member inline __.TryWith(body, catch) = tryWith body catch + member inline __.TryFinally(body, fin) = tryFinally body fin + member inline __.Using(disp, body) = using disp body + // End of consistent methods -- the following methods are different between + // `TaskBuilder` and `ContextInsensitiveTaskBuilder`! + + member inline __.ReturnFrom(task : _ Task) = + bindTask task ret + member inline __.ReturnFrom(task : UnitTask) = + bindVoidTask task.Task ret + member inline __.Bind(task : _ Task, continuation) = + bindTask task continuation + member inline __.Bind(task : UnitTask, continuation) = + bindVoidTask task.Task continuation + + type ContextInsensitiveTaskBuilder() = + // These methods are consistent between the two builders. + // Unfortunately, inline members do not work with inheritance. + member inline __.Delay(f : unit -> Step<_>) = f + member inline __.Run(f : unit -> Step<'m>) = run f + member inline __.Zero() = zero() + member inline __.Return(x) = ret x + member inline __.ReturnFrom(task) = bindConfiguredTask task ret + member inline __.ReturnFrom(task) = bindVoidConfiguredTask task ret + member inline __.ReturnFrom(yld : YieldAwaitable) = bindGenericAwaitable yld ret + member inline __.Combine(step, continuation) = combine step continuation + member inline __.Bind(task, continuation) = bindConfiguredTask task continuation + member inline __.Bind(task, continuation) = bindVoidConfiguredTask task continuation + member inline __.Bind(yld : YieldAwaitable, continuation) = bindGenericAwaitable yld continuation + member inline __.While(condition, body) = whileLoop condition body + member inline __.For(sequence, body) = forLoop sequence body + member inline __.TryWith(body, catch) = tryWith body catch + member inline __.TryFinally(body, fin) = tryFinally body fin + member inline __.Using(disp, body) = using disp body + // End of consistent methods -- the following methods are different between + // `TaskBuilder` and `ContextInsensitiveTaskBuilder`! + + member inline __.ReturnFrom(task : _ Task) = + bindConfiguredTask (task.ConfigureAwait(continueOnCapturedContext = false)) ret + member inline __.Bind(task : _ Task, continuation) = + bindConfiguredTask (task.ConfigureAwait(continueOnCapturedContext = false)) continuation + +// Don't warn about our use of the "obsolete" module we just defined (see notes at start of file). +#nowarn "44" + +[] +module ContextSensitive = + /// Builds a `System.Threading.Tasks.Task<'a>` similarly to a C# async/await method. + /// Use this like `task { let! taskResult = someTask(); return taskResult.ToString(); }`. + let task = TaskBuilder.TaskBuilder() + let inline unitTask t = TaskBuilder.UnitTask(t) + +module ContextInsensitive = + /// Builds a `System.Threading.Tasks.Task<'a>` similarly to a C# async/await method, but with + /// all awaited tasks automatically configured *not* to resume on the captured context. + /// This is often preferable when writing library code that is not context-aware, but undesirable when writing + /// e.g. code that must interact with user interface controls on the same thread as its caller. + let task = TaskBuilder.ContextInsensitiveTaskBuilder() + let inline unitTask (t : Task) = t.ConfigureAwait(false)