diff --git a/docs/api-reference/expression-estimator.md b/docs/api-reference/expression-estimator.md
new file mode 100644
index 0000000000..34b79b1349
--- /dev/null
+++ b/docs/api-reference/expression-estimator.md
@@ -0,0 +1,181 @@
+### The Expression Language
+
+The language for the expression estimator should be comfortable to a broad range of users. It shares many similarities with some popular languages.
+It is case sensitive, supports multiple types and has a rich set of operators and functions. It is pure functional, in the sense that there are no
+mutable values or mutating operations in the language. It does not have, nor need, any exception mechanism, instead producing NA values when a normal
+value is not appropriate. It is statically typed, but all types are inferred by the compiler.
+
+## Syntax
+
+Syntax for the lambda consists of a parameter list followed by either colon (:) or arrow (=>) followed by an expression.
+The parameter list can be either a single identifier or a comma-separated list of one or more identifiers surrounded by parentheses.
+
+_lambda:_
+
+- _parameter-list **:** expression_
+- _parameter-list **=>** expression_
+
+_parameter-list:_
+
+- _identifier_
+- **(** _parameter-names_ **)**
+
+_parameter-names:_
+
+- _identifier_
+- _identifier **,** parameter-names_
+
+The expression can use parameters, literals, operators, with-expressions, and functions.
+
+## Literals
+
+- The boolean literals are true and false.
+- Integer literals may be decimal or hexadecimal (e.g., 0x1234ABCD). They can be suffixed with u or U, indicating unsigned,
+as well as l or L, indicating long (Int64). The use of u or U is rare and only affects promotion of certain 32 bit hexadecimal values,
+determining whether the constant is considered a negative Int32 value or a positive Int64 value.
+- Floating point literals use the standard syntax, including exponential notation (123.45e-37). They can be suffixed with
+f or F, indicating single precision, or d or D, indicating double precision. Unlike in C#, the default precision of a
+floating point literal is single precision. To specify double precision, append d or D.
+- Text literals are enclosed in double-quotation marks and support the standard escape mechanisms.
+
+## Operators
+
+The operators of the expression language are listed in the following table, in precendence order. Unless otherwise noted,
+binary operators are left associative and propagate NA values (if either operand value is NA, the result is NA). Generally,
+overflow of integer values produces NA, while overflow of floating point values produces infinity.
+
+| **Operator** | **Meaning** | **Arity**| **Comments** |
+| --- | --- | ---| --- |
+| ? : | conditional | Ternary | The expression condition ? value1 : value2 resolves to value1 if condition is true and to value2 if condition is false. The condition must be boolean, while value1 and value2 must be of compatible type. |
+| ?? | coalesce | Binary | The expression x ?? y resolves to x if x is not NA, and resolves to y otherwise. The operands must be both Singles, or both Doubles. This operator is right associative. |
+| \| \| or | logical or | Binary | The operands and result are boolean. If one operand is true, the result is true, otherwise it is false. |
+| && and | logical and | Binary | The operands and result are boolean. If one operand is false, the result is false, otherwise, it is true. |
+| ==, =
!=, <>
<, <=
>, >= | equals
not equals
less than or equal to
greater than or equal to | Multiple |- The comparison operators are multi-arity, meaning they can be applied to two or more operands. For example, a == b == c results in true if a, b, and c have the same value. The not equal operator requires that all of the operands be distinct, so 1 != 2 != 1 is false. To test whether x is non-negative but less than 10, use 0 <= x < 10. There is no need to write 0 <= x && x < 10, and doing so will not perform as well. Operators listed on the same line can be combined in a single expression, so a > b >= c is legal, but a < b >= c is not.
- Equals and not equals apply to any operand type, while the ordered operators require numeric operands. |
+| + - | addition and subtraction | Binary | Numeric addition and subtraction with NA propagation. |
+| \* / % | multiplication, division, and modulus | Binary | Numeric multiplication, division, and modulus with NA propagation. |
+| - ! not | numeric negation and logical not | Unary | These are unary prefix operators, negation (-) requiring a numeric operand, and not (!) requiring a boolean operand. |
+| ^ | power | Binary | This is right associative exponentiation. It requires numeric operands. For integer operands, 0^0 produce 1.|
+| ( ) | parenthetical grouping | Unary | Standard meaning. |
+
+## The With Expression
+
+The syntax for the with-expression is:
+
+_with-expression:_
+
+- **with** **(** assignment-list **;** expression **)**
+
+_assignment-list:_
+
+- assignment
+- assignment **,** assignment-list
+
+_assignment:_
+
+- identifier **=** expression
+
+The with-expression introduces one or more named values. For example, the following expression converts a celcius temperature to fahrenheit, then produces a message based on whether the fahrenheit is too low or high.
+```
+c => with(f = c * 9 / 5 + 32 ; f < 60 ? "Too Cold!" : f > 90 ? "Too Hot!" : "Just Right!")
+```
+The expression for one assignment may reference the identifiers introduced by previous assignments, as in this example that returns -1, 0, or 1 instead of the messages:
+```
+c : with(f = c * 9 / 5 + 32, cold = f < 60, hot = f > 90 ; -float(cold) + float(hot))
+```
+As demonstrated above, the with-expression is useful when an expression value is needed multiple times in a larger expression. It is also useful when dealing with complicated or significant constants:
+```
+ ticks => with(
+ ticksPerSecond = 10000000L,
+ ticksPerHour = ticksPerSecond \* 3600,
+ ticksPerDay = ticksPerHour \* 24,
+ day = ticks / ticksPerDay,
+ dayEpoch = 1 ;
+ (day + dayEpoch) % 7)
+```
+This computes the day of the week from the number of ticks (as an Int64) since the standard .Net DateTime epoch (01/01/0001
+in the idealized Gregorian calendar). Assignments are used for number of ticks in a second, number of ticks in an hour,
+number of ticks in a year, and the day of the week for the epoch. For this example, we want to map Sunday to zero, so,
+since the epoch is a Monday, we set dayEpoch to 1. If the epoch were changed or we wanted to map a different day of the week to zero,
+we'd simply change dayEpoch. Note that ticksPerSecond is defined as 10000000L, to make it an Int64 value (8 byte integer).
+Without the L suffix, ticksPerDay will overflow Int32's range.
+
+## Functions
+
+The expression transform supports many useful functions.
+
+General unary functions that can accept an operand of any type are listed in the following table.
+
+| **Name** | **Meaning** | **Comments** |
+| --- | --- | --- |
+| isna | test for na | Returns a boolean value indicating whether the operand is an NA value. |
+| na | the na value | Returns the NA value of the same type as the operand (either float or double). Note that this does not evaluate the operand, it only uses the operand to determine the type of NA to return, and that determination happens at compile time. |
+| default | the default value | Returns the default value of the same type as the operand. For example, to map NA values to default values, use x ?? default(x). Note that this does not evaluate the operand, it only uses the operand to determine the type of default value to return, and that determination happens at compile time. For numeric types, the default is zero. For boolean, the default is false. For text, the default is empty. |
+
+The unary conversion functions are listed in the following table. An NA operand produces an NA, or throws if the type does not support it.
+A conversion that doesn't succeed, or overflow also result in NA or an exception. The most common case of this is when converting from text,
+which uses the standard conversion parsing. When converting from a floating point value (float or double) to an integer value
+(Int32 or Int64), the conversion does a truncate operation (round toward zero).
+
+| **Name** | **Meaning** | **Comments** |
+| --- | --- | --- |
+| bool | convert to BL | The operand must be text or boolean. |
+| int | convert to I4 | The input may be of any type. |
+| long | convert to I8 | The input may be of any type. |
+| single, float | convert to R4 | The input may be of any type. |
+| double | convert to R8 | The input may be of any type. |
+| text | convert to TX | The input may be of any type. This produces a default text representation. |
+
+The unary functions that require a numeric operand are listed in the following table. The result type is the same as the operand type. An NA operand value produces NA.
+
+| **Name** | **Meaning** | **Comments** |
+| --- | --- | --- |
+| abs | absolute value | Produces the absolute value of the operand. |
+| sign | sign (-1, 0, 1) | Produces -1, 0, or 1 depending on whether the operand is negative, zero, or positive. |
+
+The binary functions that require numeric operands are listed in the following table. When the operand types aren't the same,
+the operands are promoted to an appropriate type. The result type is the same as the promoted operand type. An NA operand value produces NA.
+
+| **Name** | **Meaning** | **Comments** |
+| --- | --- | --- |
+| min | minimum | Produces the minimum of the operands. |
+| max | maximum | Produces the maximum of the operands. |
+
+The unary functions that require a floating point operand are listed in the following table. The result type is the same as the operand type. Overflow produces infinity. Invalid input values produce NA.
+
+| **Name** | **Meaning** | **Comments** |
+| --- | --- | --- |
+| sqrt | square root | Negative operands produce NA. |
+| trunc, truncate | truncate to an integer | Rounds toward zero to the nearest integer value. |
+| floor | floor | Rounds toward negative infinity to the nearest integer value. |
+| ceil, ceiling | ceiling | Rounds toward positive infinity to the nearest integer value. |
+| round | unbiased rounding | Rounds to the nearest integer value. When the operand is half way between two integer values, this produces the even integer. |
+| exp | exponential | Raises e to the operand. |
+| ln, log | logarithm | Produces the natural (base e) logarithm. There is also a two operand version of log for using a different base. |
+| deg, degrees | radians to degrees | Maps from radians to degrees. |
+| rad, radians | degrees to radians | Maps from degrees to radians. |
+| sin, sind | sine | Takes the sine of an angle. The sin function assumes the operand is in radians, while the sind function assumes that the operand is in degrees. |
+| cos, cosd | cosine | Takes the cosine of an angle. The cos function assumes the operand is in radians, while the cosd function assumes that the operand is in degrees. |
+| tan, tand | tangent | Takes the tangent of an angle. The tan function assumes the operand is in radians, while the tand function assumes that the operand is in degrees. |
+| sinh | hyperbolic sine | Takes the hyperbolic sine of its operand. |
+| cosh | hyperbolic cosine | Takes the hyperbolic cosine of its operand. |
+| tanh | hyperbolic tangent | Takes the hyperbolic tangent of its operand. |
+| asin | inverse sine | Takes the inverse sine of its operand. |
+| acos | inverse cosine | Takes the inverse cosine of its operand. |
+| atan | inverse tangent | Takes the inverse tangent of its operand. |
+
+The binary functions that require floating point operands are listed in the following table. When the operand types aren't the same, the operands are promoted to an appropriate type. The result type is the same as the promoted operand type. An NA operand value produces NA.
+
+| **Name** | **Meaning** | **Comments** |
+| --- | --- | --- |
+| log | logarithm with given base | The second operand is the base. The first is the value to take the logarithm of. |
+| atan2, atanyx | determine angle | Determines the angle between -pi and pi from the given y and x values. Note that y is the first operand. |
+
+The text functions are listed in the following table.
+
+| **Name** | **Meaning** | **Comments** |
+| --- | --- | --- |
+| len(x) | length of text | The operand must be text. The result is an I4 indicating the length of the operand. If the operand is NA, the result is NA. |
+| lower(x), upper(x) | lower or upper case | Maps the text to lower or upper case. |
+| left(x, k), right(x, k) | substring | The first operand must be text and the second operand must be Int32. If the second operand is negative it is treated as an offset from the end of the text. This adjusted index is then clamped to 0 to len(x). The result is the characters to the left or right of the resulting position. |
+| mid(x, a, b) | substring | The first operand must be text and the other two operands must be Int32. The indices are transformed in the same way as for the left and right functions: negative values are treated as offsets from the end of the text; these adjusted indices are clamped to 0 to len(x). The second clamped index is also clamped below to the first clamped index. The result is the characters between these two clamped indices. |
+| concat(x1, x2, ..., xn) | concatenation | This accepts an arbitrary number of operands (including zero). All operands must be text. The result is the concatenation of all the operands, in order. |
\ No newline at end of file
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Expression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Expression.cs
new file mode 100644
index 0000000000..193a65e86e
--- /dev/null
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Expression.cs
@@ -0,0 +1,104 @@
+using System;
+using System.Collections.Generic;
+using Microsoft.ML;
+using Microsoft.ML.Data;
+
+namespace Samples.Dynamic.Transforms
+{
+ public static class Expression
+ {
+ public static void Example()
+ {
+ // Create a new ML context, for ML.NET operations. It can be used for
+ // exception tracking and logging, as well as the source of randomness.
+ var mlContext = new MLContext();
+
+ // Create a small dataset as an IEnumerable.
+ var samples = new List()
+ {
+ new InputData(0.5f, new[] { 1f, 0.2f }, 3, "hi", true, new[] { "zero", "one" }),
+ new InputData(-2.7f, new[] { 3.5f, -0.1f }, 2, "bye", false, new[] { "a", "b" }),
+ new InputData(1.3f, new[] { 1.9f, 3.3f }, 39, "hi", false, new[] { "0", "1" }),
+ new InputData(3, new[] { 3f, 3f }, 4, "hello", true, new[] { "c", "d" }),
+ new InputData(0, new[] { 1f, 1f }, 1, "hi", true, new[] { "zero", "one" }),
+ new InputData(30.4f, new[] { 10f, 4f }, 9, "bye", true, new[] { "e", "f" }),
+ new InputData(5.6f, new[] { 1.1f, 2.2f }, 0, "hey", false, new[] { "g", "h" }),
+ };
+
+ // Convert training data to IDataView.
+ var dataview = mlContext.Data.LoadFromEnumerable(samples);
+
+ // A pipeline that applies various expressions to the input columns.
+ var pipeline = mlContext.Transforms.Expression("Expr1", "(x,y)=>log(y)+x",
+ nameof(InputData.FloatColumn), nameof(InputData.FloatVectorColumn))
+ .Append(mlContext.Transforms.Expression("Expr2", "(b,s,i)=>b ? len(s) : i",
+ nameof(InputData.BooleanColumn), nameof(InputData.StringVectorColumn), nameof(InputData.IntColumn)))
+ .Append(mlContext.Transforms.Expression("Expr3", "(s,f1,f2,i)=>len(concat(s,\"a\"))+f1+f2+i",
+ nameof(InputData.StringColumn), nameof(InputData.FloatVectorColumn), nameof(InputData.FloatColumn), nameof(InputData.IntColumn)))
+ .Append(mlContext.Transforms.Expression("Expr4", "(x,y)=>cos(x+pi())*y",
+ nameof(InputData.FloatColumn), nameof(InputData.IntColumn)));
+
+ // The transformed data.
+ var transformedData = pipeline.Fit(dataview).Transform(dataview);
+
+ // Now let's take a look at what this concatenation did.
+ // We can extract the newly created column as an IEnumerable of
+ // TransformedData.
+ var featuresColumn = mlContext.Data.CreateEnumerable(
+ transformedData, reuseRowObject: false);
+
+ // And we can write out a few rows
+ Console.WriteLine($"Features column obtained post-transformation.");
+ foreach (var featureRow in featuresColumn)
+ {
+ Console.Write(string.Join(" ", featureRow.Expr1));
+ Console.Write(" ");
+ Console.Write(string.Join(" ", featureRow.Expr2));
+ Console.Write(" ");
+ Console.Write(string.Join(" ", featureRow.Expr3));
+ Console.Write(" ");
+ Console.WriteLine(featureRow.Expr4);
+ }
+
+ // Expected output:
+ // Features column obtained post-transformation.
+ // 0.5 - 1.109438 4 3 7.5 6.7 - 2.63274768567112
+ // - 1.447237 NaN 2 2 6.8 3.2 1.80814432479224
+ // 1.941854 2.493922 39 39 45.2 46.6 - 10.4324561082543
+ // 4.098612 4.098612 1 1 16 16 3.95996998640178
+ // 0 0 4 3 5 5 - 1
+ // 32.70258 31.78629 1 1 53.4 47.4 - 4.74149076052604
+ // 5.69531 6.388457 0 0 10.7 11.8 0
+ }
+
+ private class InputData
+ {
+ public float FloatColumn;
+ [VectorType(3)]
+ public float[] FloatVectorColumn;
+ public int IntColumn;
+ public string StringColumn;
+ public bool BooleanColumn;
+ [VectorType(2)]
+ public string[] StringVectorColumn;
+
+ public InputData(float f, float[] fv, int i, string s, bool b, string[] sv)
+ {
+ FloatColumn = f;
+ FloatVectorColumn = fv;
+ IntColumn = i;
+ StringColumn = s;
+ BooleanColumn = b;
+ StringVectorColumn = sv;
+ }
+ }
+
+ private sealed class TransformedData
+ {
+ public float[] Expr1 { get; set; }
+ public int[] Expr2 { get; set; }
+ public float[] Expr3 { get; set; }
+ public double Expr4 { get; set; }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/Expression/BuiltinFunctions.cs b/src/Microsoft.ML.Transforms/Expression/BuiltinFunctions.cs
new file mode 100644
index 0000000000..02935718b3
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/BuiltinFunctions.cs
@@ -0,0 +1,1095 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#pragma warning disable 420 // volatile with Interlocked.CompareExchange
+
+using System;
+using System.Globalization;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using Microsoft.ML.Data;
+using Microsoft.ML.Data.Conversion;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ using BL = System.Boolean;
+ using I4 = System.Int32;
+ using I8 = System.Int64;
+ using R4 = Single;
+ using R8 = Double;
+ using TX = ReadOnlyMemory;
+
+ internal static class FunctionProviderUtils
+ {
+ ///
+ /// Returns whether the given object is non-null and an NA value for one of the standard types.
+ ///
+ public static bool IsNA(object v)
+ {
+ if (v == null)
+ return false;
+ Type type = v.GetType();
+ if (type == typeof(R4))
+ return R4.IsNaN((R4)v);
+ if (type == typeof(R8))
+ return R8.IsNaN((R8)v);
+ Contracts.Assert(type == typeof(BL) || type == typeof(I4) || type == typeof(I8) || type == typeof(TX),
+ "Unexpected constant value type!");
+ return false;
+ }
+
+ ///
+ /// Returns the standard NA value for the given standard type.
+ ///
+ public static object GetNA(Type type)
+ {
+ if (type == typeof(R4))
+ return R4.NaN;
+ if (type == typeof(R8))
+ return R8.NaN;
+ Contracts.Assert(false, "Unexpected constant value type!");
+ return null;
+ }
+
+ ///
+ /// Helper method to bundle one or more MethodInfos into an array.
+ ///
+ public static MethodInfo[] Ret(params MethodInfo[] funcs)
+ {
+ Contracts.AssertValue(funcs);
+ return funcs;
+ }
+
+ ///
+ /// Returns the MethodInfo for the given delegate.
+ ///
+ public static MethodInfo Fn(Func fn)
+ {
+ Contracts.AssertValue(fn);
+ Contracts.Assert(fn.Target == null);
+ return fn.GetMethodInfo();
+ }
+
+ ///
+ /// Returns the MethodInfo for the given delegate.
+ ///
+ public static MethodInfo Fn(Func fn)
+ {
+ Contracts.AssertValue(fn);
+ Contracts.Assert(fn.Target == null);
+ return fn.GetMethodInfo();
+ }
+
+ ///
+ /// Returns the MethodInfo for the given delegate.
+ ///
+ public static MethodInfo Fn(Func fn)
+ {
+ Contracts.AssertValue(fn);
+ Contracts.Assert(fn.Target == null);
+ return fn.GetMethodInfo();
+ }
+
+ ///
+ /// Returns the MethodInfo for the given delegate.
+ ///
+ public static MethodInfo Fn(Func fn)
+ {
+ Contracts.AssertValue(fn);
+ Contracts.Assert(fn.Target == null);
+ return fn.GetMethodInfo();
+ }
+ }
+
+ ///
+ /// The standard builtin functions for ExprTransform.
+ ///
+ internal sealed class BuiltinFunctions : IFunctionProvider
+ {
+ private static volatile BuiltinFunctions _instance;
+ public static BuiltinFunctions Instance
+ {
+ get
+ {
+ if (_instance == null)
+ Interlocked.CompareExchange(ref _instance, new BuiltinFunctions(), null);
+ return _instance;
+ }
+ }
+
+ public string NameSpace { get { return "global"; } }
+
+ ///
+ /// Returns the MethodInfo for
+ ///
+ private static MethodInfo Id()
+ {
+ Action fn = Id;
+ Contracts.Assert(fn.Target == null);
+ return fn.GetMethodInfo();
+ }
+
+ // This is an "identity" function.
+ private static void Id(T src) { }
+
+ public MethodInfo[] Lookup(string name)
+ {
+ switch (name)
+ {
+ case "pi":
+ return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn(Pi));
+
+ case "na":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(NA),
+ FunctionProviderUtils.Fn(NA));
+ case "default":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Default),
+ FunctionProviderUtils.Fn(Default),
+ FunctionProviderUtils.Fn(Default),
+ FunctionProviderUtils.Fn(Default),
+ FunctionProviderUtils.Fn(Default),
+ FunctionProviderUtils.Fn(Default));
+
+ case "abs":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Math.Abs),
+ FunctionProviderUtils.Fn(Math.Abs),
+ FunctionProviderUtils.Fn(Math.Abs),
+ FunctionProviderUtils.Fn(Math.Abs));
+ case "sign":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Sign),
+ FunctionProviderUtils.Fn(Sign),
+ FunctionProviderUtils.Fn(Sign),
+ FunctionProviderUtils.Fn(Sign));
+ case "exp":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Exp),
+ FunctionProviderUtils.Fn(Math.Exp));
+ case "ln":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Log),
+ FunctionProviderUtils.Fn(Math.Log));
+ case "log":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Log),
+ FunctionProviderUtils.Fn(Math.Log),
+ FunctionProviderUtils.Fn(Log),
+ FunctionProviderUtils.Fn(Math.Log));
+
+ case "deg":
+ case "degrees":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Deg),
+ FunctionProviderUtils.Fn(Deg));
+ case "rad":
+ case "radians":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Rad),
+ FunctionProviderUtils.Fn(Rad));
+
+ case "sin":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Sin),
+ FunctionProviderUtils.Fn(Sin));
+ case "sind":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(SinD),
+ FunctionProviderUtils.Fn(SinD));
+ case "cos":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Cos),
+ FunctionProviderUtils.Fn(Cos));
+ case "cosd":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(CosD),
+ FunctionProviderUtils.Fn(CosD));
+ case "tan":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Tan),
+ FunctionProviderUtils.Fn(Math.Tan));
+ case "tand":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(TanD),
+ FunctionProviderUtils.Fn(TanD));
+
+ case "asin":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Asin),
+ FunctionProviderUtils.Fn(Math.Asin));
+ case "acos":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Acos),
+ FunctionProviderUtils.Fn(Math.Acos));
+ case "atan":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Atan),
+ FunctionProviderUtils.Fn(Math.Atan));
+ case "atan2":
+ case "atanyx":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Atan2),
+ FunctionProviderUtils.Fn(Atan2));
+
+ case "sinh":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Sinh),
+ FunctionProviderUtils.Fn(Math.Sinh));
+ case "cosh":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Cosh),
+ FunctionProviderUtils.Fn(Math.Cosh));
+ case "tanh":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Tanh),
+ FunctionProviderUtils.Fn(Math.Tanh));
+
+ case "sqrt":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Sqrt),
+ FunctionProviderUtils.Fn(Math.Sqrt));
+
+ case "trunc":
+ case "truncate":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Truncate),
+ FunctionProviderUtils.Fn(Math.Truncate));
+ case "floor":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Floor),
+ FunctionProviderUtils.Fn(Math.Floor));
+ case "ceil":
+ case "ceiling":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Ceiling),
+ FunctionProviderUtils.Fn(Math.Ceiling));
+ case "round":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Round),
+ FunctionProviderUtils.Fn(Math.Round));
+
+ case "min":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Math.Min),
+ FunctionProviderUtils.Fn(Math.Min),
+ FunctionProviderUtils.Fn(Math.Min),
+ FunctionProviderUtils.Fn(Math.Min));
+ case "max":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Math.Max),
+ FunctionProviderUtils.Fn(Math.Max),
+ FunctionProviderUtils.Fn(Math.Max),
+ FunctionProviderUtils.Fn(Math.Max));
+
+ case "len":
+ return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn(Len));
+ case "lower":
+ return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn(Lower));
+ case "upper":
+ return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn(Upper));
+ case "right":
+ return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn(Right));
+ case "left":
+ return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn(Left));
+ case "mid":
+ return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn(Mid));
+
+ case "concat":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Empty),
+ Id(),
+ FunctionProviderUtils.Fn(Concat),
+ FunctionProviderUtils.Fn(Concat),
+ FunctionProviderUtils.Fn(Concat));
+
+ case "isna":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(IsNA),
+ FunctionProviderUtils.Fn(IsNA));
+
+ case "bool":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(ToBL),
+ Id());
+ case "int":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Convert.ToInt32),
+ FunctionProviderUtils.Fn(Convert.ToInt32),
+ FunctionProviderUtils.Fn(Convert.ToInt32),
+ FunctionProviderUtils.Fn(Convert.ToInt32),
+ FunctionProviderUtils.Fn(ToI4),
+ Id());
+ case "long":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Convert.ToInt64),
+ FunctionProviderUtils.Fn(Convert.ToInt64),
+ FunctionProviderUtils.Fn(Convert.ToInt64),
+ FunctionProviderUtils.Fn(Convert.ToInt64),
+ FunctionProviderUtils.Fn(ToI8),
+ Id());
+ case "float":
+ case "single":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Convert.ToSingle),
+ FunctionProviderUtils.Fn(Convert.ToSingle),
+ FunctionProviderUtils.Fn(ToR4),
+ FunctionProviderUtils.Fn(ToR4),
+ FunctionProviderUtils.Fn(Convert.ToSingle),
+ FunctionProviderUtils.Fn(ToR4));
+ case "double":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(Convert.ToDouble),
+ FunctionProviderUtils.Fn(Convert.ToDouble),
+ FunctionProviderUtils.Fn(ToR8),
+ FunctionProviderUtils.Fn(ToR8),
+ FunctionProviderUtils.Fn(Convert.ToDouble),
+ FunctionProviderUtils.Fn(ToR8));
+ case "text":
+ return FunctionProviderUtils.Ret(
+ FunctionProviderUtils.Fn(ToTX),
+ FunctionProviderUtils.Fn(ToTX),
+ FunctionProviderUtils.Fn(ToTX),
+ FunctionProviderUtils.Fn(ToTX),
+ FunctionProviderUtils.Fn(ToTX),
+ Id());
+ }
+
+ return null;
+ }
+
+ public object ResolveToConstant(string name, MethodInfo fn, object[] values)
+ {
+ Contracts.CheckNonEmpty(name, nameof(name));
+ Contracts.CheckValue(fn, nameof(fn));
+ Contracts.CheckParam(Utils.Size(values) > 0, nameof(values), "Expected values to have positive length");
+ Contracts.CheckParam(!values.All(x => x != null), nameof(values), "Expected values to contain at least one null");
+
+ switch (name)
+ {
+ case "na":
+ {
+ Contracts.Assert(values.Length == 1);
+
+ Type type = fn.ReturnType;
+ if (type == typeof(R4))
+ return R4.NaN;
+ if (type == typeof(R8))
+ return R8.NaN;
+ return null;
+ }
+ case "default":
+ {
+ Contracts.Assert(values.Length == 1);
+
+ Type type = fn.ReturnType;
+ if (type == typeof(I4))
+ return default(I4);
+ if (type == typeof(I8))
+ return default(I8);
+ if (type == typeof(R4))
+ return default(R4);
+ if (type == typeof(R8))
+ return default(R8);
+ if (type == typeof(BL))
+ return default(BL);
+ if (type == typeof(TX))
+ return default(TX);
+ Contracts.Assert(false, "Unexpected return type!");
+ return null;
+ }
+ }
+
+ // By default, constant NA arguments produce an NA result. Note that this is not true for isna,
+ // but those functions will get here only if values contains a single null, not an NA.
+ for (int i = 0; i < values.Length; i++)
+ {
+ if (FunctionProviderUtils.IsNA(values[i]))
+ {
+ Contracts.Assert(values.Length > 1);
+ return FunctionProviderUtils.GetNA(fn.ReturnType);
+ }
+ }
+
+ return null;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Pi()
+ {
+ return Math.PI;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 NA(R4 a)
+ {
+ return R4.NaN;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 NA(R8 a)
+ {
+ return R8.NaN;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static I4 Default(I4 a)
+ {
+ return default(I4);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static I8 Default(I8 a)
+ {
+ return default(I8);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Default(R4 a)
+ {
+ return default(R4);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Default(R8 a)
+ {
+ return default(R8);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static BL Default(BL a)
+ {
+ return default(BL);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Default(TX a)
+ {
+ return default(TX);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Sign(R4 a)
+ {
+ // Preserves NaN. Unfortunately, it also preserves negative zero,
+ // but perhaps that is a good thing?
+ return a > 0 ? +1 : a < 0 ? -1 : a;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Sign(R8 a)
+ {
+ // Preserves NaN. Unfortunately, it also preserves negative zero,
+ // but perhaps that is a good thing?
+ return a > 0 ? +1 : a < 0 ? -1 : a;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static I4 Sign(I4 a)
+ {
+ return a > 0 ? +1 : a < 0 ? -1 : a;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static I8 Sign(I8 a)
+ {
+ return a > 0 ? +1 : a < 0 ? -1 : a;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Pow(R4 a, R4 b)
+ {
+ return (R4)Math.Pow(a, b);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Exp(R4 a)
+ {
+ return (R4)Math.Exp(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Log(R4 a)
+ {
+ return (R4)Math.Log(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Log(R4 a, R4 b)
+ {
+ return (R4)Math.Log(a, b);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Deg(R4 a)
+ {
+ return (R4)(a * (180 / Math.PI));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Deg(R8 a)
+ {
+ return a * (180 / Math.PI);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Rad(R4 a)
+ {
+ return (R4)(a * (Math.PI / 180));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Rad(R8 a)
+ {
+ return a * (Math.PI / 180);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Sin(R4 a)
+ {
+ return (R4)Math.Sin(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Sin(R8 a)
+ {
+ return MathUtils.Sin(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 SinD(R4 a)
+ {
+ return (R4)Math.Sin(a * (Math.PI / 180));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 SinD(R8 a)
+ {
+ return MathUtils.Sin(a * (Math.PI / 180));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Cos(R4 a)
+ {
+ return (R4)Math.Cos(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Cos(R8 a)
+ {
+ return MathUtils.Cos(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 CosD(R4 a)
+ {
+ return (R4)Math.Cos(a * (Math.PI / 180));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 CosD(R8 a)
+ {
+ return MathUtils.Cos(a * (Math.PI / 180));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Tan(R4 a)
+ {
+ return (R4)Math.Tan(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 TanD(R4 a)
+ {
+ return (R4)Math.Tan(a * (Math.PI / 180));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 TanD(R8 a)
+ {
+ return Math.Tan(a * (Math.PI / 180));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Asin(R4 a)
+ {
+ return (R4)Math.Asin(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Acos(R4 a)
+ {
+ return (R4)Math.Acos(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Atan(R4 a)
+ {
+ return (R4)Math.Atan(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Atan2(R4 a, R4 b)
+ {
+ // According to the documentation of Math.Atan2: if x and y are either System.Double.PositiveInfinity
+ // or System.Double.NegativeInfinity, the method returns System.Double.NaN, but this seems to not be the case.
+ if (R4.IsInfinity(a) && R4.IsInfinity(b))
+ return R4.NaN;
+ return (R4)Math.Atan2(a, b);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 Atan2(R8 a, R8 b)
+ {
+ // According to the documentation of Math.Atan2: if x and y are either System.Double.PositiveInfinity
+ // or System.Double.NegativeInfinity, the method returns System.Double.NaN, but this seems to not be the case.
+ if (R8.IsInfinity(a) && R8.IsInfinity(b))
+ return R8.NaN;
+ return Math.Atan2(a, b);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Sinh(R4 a)
+ {
+ return (R4)Math.Sinh(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Cosh(R4 a)
+ {
+ return (R4)Math.Cosh(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Tanh(R4 a)
+ {
+ return (R4)Math.Tanh(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Sqrt(R4 a)
+ {
+ return (R4)Math.Sqrt(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Truncate(R4 a)
+ {
+ return (R4)Math.Truncate(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Floor(R4 a)
+ {
+ return (R4)Math.Floor(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Ceiling(R4 a)
+ {
+ return (R4)Math.Ceiling(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 Round(R4 a)
+ {
+ return (R4)Math.Round(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Lower(TX a)
+ {
+ if (a.IsEmpty)
+ return a;
+ var sb = new StringBuilder();
+ ReadOnlyMemoryUtils.AddLowerCaseToStringBuilder(a.Span, sb);
+ return sb.ToString().AsMemory();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Upper(TX a)
+ {
+ if (a.IsEmpty)
+ return a;
+ var dst = new char[a.Length];
+ a.Span.ToUpperInvariant(dst);
+ return new TX(dst);
+ }
+
+ // Special case some common Concat sizes, for better efficiency.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Empty()
+ {
+ return TX.Empty;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Concat(TX a, TX b)
+ {
+ if (a.IsEmpty)
+ return b;
+ if (b.IsEmpty)
+ return a;
+ var dst = new char[a.Length + b.Length];
+ a.Span.CopyTo(dst);
+ b.Span.CopyTo(new Span(dst, a.Length, b.Length));
+ return new TX(dst);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Concat(TX a, TX b, TX c)
+ {
+ var dst = new char[a.Length + b.Length + c.Length];
+ a.Span.CopyTo(dst);
+ b.Span.CopyTo(new Span(dst, a.Length, b.Length));
+ c.Span.CopyTo(new Span(dst, a.Length + b.Length, c.Length));
+ return new TX(dst);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Concat(TX[] a)
+ {
+ Contracts.AssertValue(a);
+
+ int len = 0;
+ for (int i = 0; i < a.Length; i++)
+ len += a[i].Length;
+ if (len == 0)
+ return TX.Empty;
+
+ var sb = new StringBuilder(len);
+ for (int i = 0; i < a.Length; i++)
+ sb.AppendSpan(a[i].Span);
+ return sb.ToString().AsMemory();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static I4 Len(TX a)
+ {
+ return a.Length;
+ }
+
+ ///
+ /// Given an index meant to index into a given sequence normalize it according to
+ /// these rules: negative indices get added to them, and
+ /// then the index is clamped the range 0 to inclusive,
+ /// and that result is returned. (For those familiar with Python, this is the same
+ /// as the logic for slice normalization.)
+ ///
+ /// The index to normalize
+ /// The length of the sequence
+ /// The normalized version of the index, a non-positive value no greater
+ /// than .
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static int NormalizeIndex(int i, int len)
+ {
+ Contracts.Assert(0 <= len);
+ if (i < 0)
+ {
+ if ((i += len) < 0)
+ return 0;
+ }
+ else if (i > len)
+ return len;
+ return i;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Right(TX a, I4 min)
+ {
+ if (a.IsEmpty)
+ return a;
+ return a.Slice(NormalizeIndex(min, a.Length));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Left(TX a, I4 lim)
+ {
+ if (a.IsEmpty)
+ return a;
+ return a.Slice(0, NormalizeIndex(lim, a.Length));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX Mid(TX a, I4 min, I4 lim)
+ {
+ if (a.IsEmpty)
+ return a;
+ int im = NormalizeIndex(min, a.Length);
+ int il = NormalizeIndex(lim, a.Length);
+ if (im >= il)
+ return TX.Empty;
+ return a.Slice(im, il - im);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static BL IsNA(R4 a)
+ {
+ return R4.IsNaN(a);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static BL IsNA(R8 a)
+ {
+ return R8.IsNaN(a);
+ }
+
+ public static BL ToBL(TX a)
+ {
+ BL res = default(BL);
+ Conversions.Instance.Convert(in a, ref res);
+ return res;
+ }
+
+ public static I4 ToI4(TX a)
+ {
+ I4 res = default(I4);
+ Conversions.Instance.Convert(in a, ref res);
+ return res;
+ }
+
+ public static I8 ToI8(TX a)
+ {
+ I8 res = default(I8);
+ Conversions.Instance.Convert(in a, ref res);
+ return res;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 ToR4(R4 a)
+ {
+ // Note that the cast is intentional and NOT a no-op. It forces the JIT
+ // to narrow to R4 when it might be tempted to keep intermediate
+ // computations in larger precision.
+ return (R4)a;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R4 ToR4(R8 a)
+ {
+ return (R4)a;
+ }
+
+ public static R4 ToR4(TX a)
+ {
+ R4 res = default(R4);
+ Conversions.Instance.Convert(in a, ref res);
+ return res;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 ToR8(R4 a)
+ {
+ return (R8)a;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static R8 ToR8(R8 a)
+ {
+ // Note that the cast is intentional and NOT a no-op. It forces the JIT
+ // to narrow to R4 when it might be tempted to keep intermediate
+ // computations in larger precision.
+ return a;
+ }
+
+ public static R8 ToR8(TX a)
+ {
+ R8 res = default(R8);
+ Conversions.Instance.Convert(in a, ref res);
+ return res;
+ }
+
+ public static TX ToTX(I4 src) => src.ToString().AsMemory();
+ public static TX ToTX(I8 src) => src.ToString().AsMemory();
+ public static TX ToTX(R4 src) => src.ToString("R", CultureInfo.InvariantCulture).AsMemory();
+ public static TX ToTX(R8 src) => src.ToString("G17", CultureInfo.InvariantCulture).AsMemory();
+ public static TX ToTX(BL src)
+ {
+ if (!src)
+ return "0".AsMemory();
+ else
+ return "1".AsMemory();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static BL Equals(TX first, TX second)
+ {
+ return first.Span.SequenceEqual(second.Span);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static BL NotEquals(TX first, TX second)
+ {
+ return !Equals(first, second);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static byte Not(bool b)
+ {
+ return !b ? (byte)1 : default;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static BL False()
+ {
+ return false;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static BL True()
+ {
+ return true;
+ }
+
+ ///
+ /// Raise a to the b power. Special cases:
+ /// * 1^NA => 1
+ /// * NA^0 => 1
+ ///
+ public static I4 Pow(I4 a, I4 b)
+ {
+ if (a == 1)
+ return 1;
+ switch (b)
+ {
+ case 0:
+ return 1;
+ case 1:
+ return a;
+ case 2:
+ return a * a;
+ }
+ if (a == -1)
+ return (b & 1) == 0 ? 1 : -1;
+ if (b < 0)
+ throw Contracts.Except("Cannot raise an integer to a negative power");
+
+ // Since the abs of the base is at least two, the exponent must be less than 31.
+ if (b >= 31)
+ throw Contracts.Except("Cannot raise an integer to a power greater than 30");
+
+ if (a == 0)
+ {
+ if (b == 0)
+ return 1;
+ return 0;
+ }
+
+ bool neg = false;
+ if (a < 0)
+ {
+ a = -a;
+ neg = (b & 1) != 0;
+ }
+ Contracts.Assert(a >= 2);
+
+ // Since the exponent is at least three, the base must be <= 1290.
+ Contracts.Assert(b >= 3);
+ if (a > 1290)
+ throw Contracts.Except($"Base must be at most 1290 when raising to the power of {b}");
+
+ // REVIEW: Should we use a checked context and exception catching like I8 does?
+ ulong u = (ulong)(uint)a;
+ ulong result = 1;
+ for (; ; )
+ {
+ if ((b & 1) != 0 && (result *= u) > I4.MaxValue)
+ throw Contracts.Except("Overflow");
+ b >>= 1;
+ if (b == 0)
+ break;
+ if ((u *= u) > I4.MaxValue)
+ throw Contracts.Except("Overflow");
+ }
+ Contracts.Assert(result <= I4.MaxValue);
+
+ var res = (I4)result;
+ if (neg)
+ res = -res;
+ return res;
+ }
+
+ ///
+ /// Raise a to the b power. Special cases:
+ /// * 1^NA => 1
+ /// * NA^0 => 1
+ ///
+ public static I8 Pow(I8 a, I8 b)
+ {
+ if (a == 1)
+ return 1;
+ switch (b)
+ {
+ case 0:
+ return 1;
+ case 1:
+ return a;
+ case 2:
+ return a * a;
+ }
+ if (a == -1)
+ return (b & 1) == 0 ? 1 : -1;
+ if (b < 0)
+ throw Contracts.Except("Cannot raise an integer to a negative power");
+
+ // Since the abs of the base is at least two, the exponent must be less than 63.
+ if (b >= 63)
+ throw Contracts.Except("Cannot raise an integer to a power greater than 62");
+
+ if (a == 0)
+ {
+ if (b == 0)
+ return 1;
+ return 0;
+ }
+
+ bool neg = false;
+ if (a < 0)
+ {
+ a = -a;
+ neg = (b & 1) != 0;
+ }
+ Contracts.Assert(a >= 2);
+
+ // Since the exponent is at least three, the base must be < 2^21.
+ Contracts.Assert(b >= 3);
+ if (a >= (1L << 21))
+ throw Contracts.Except($"Base must be less than 2^21 when raising to the power of {b}");
+
+ long res = 1;
+ long x = a;
+ // REVIEW: Is the catch too slow in the overflow case?
+ try
+ {
+ checked
+ {
+ for (; ; )
+ {
+ if ((b & 1) != 0)
+ res *= x;
+ b >>= 1;
+ if (b == 0)
+ break;
+ x *= x;
+ }
+ }
+ }
+ catch (OverflowException)
+ {
+ throw Contracts.Except("Overflow");
+ }
+ Contracts.Assert(res > 0);
+
+ if (neg)
+ res = -res;
+ return res;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/Expression/CharCursor.cs b/src/Microsoft.ML.Transforms/Expression/CharCursor.cs
new file mode 100644
index 0000000000..79a696b643
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/CharCursor.cs
@@ -0,0 +1,111 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ internal sealed class CharCursor
+ {
+ private readonly char[] _buffer;
+
+ // The base index for the beginning of the buffer.
+ private int _ichBase;
+
+ // Position within the buffer.
+ private int _ichNext;
+ private int _ichLim;
+ private bool _fNoMore;
+
+ public bool Eof { get; private set; }
+
+ public int IchCur => _ichBase + _ichNext - 1;
+
+ public char ChCur { get; private set; }
+
+ public CharCursor(string text)
+ : this(Contracts.CheckRef(text, nameof(text)).ToCharArray(), text.Length)
+ {
+ }
+
+ public CharCursor(string text, int ichMin, int ichLim)
+ : this(text.ToCharArray(ichMin, ichLim - ichMin), ichLim - ichMin)
+ {
+ }
+
+ private CharCursor(char[] buffer, int ichLimInit)
+ {
+ Contracts.AssertValue(buffer);
+ Contracts.Assert(0 <= ichLimInit && ichLimInit <= buffer.Length);
+
+ _buffer = buffer;
+ _ichBase = 0;
+ _ichNext = 0;
+ _ichLim = ichLimInit;
+ _fNoMore = false;
+ Eof = false;
+ ChNext();
+ }
+
+ // Fetch the next character into _chCur and return it.
+ public char ChNext()
+ {
+ if (Eof)
+ return ChCur;
+
+ if (_ichNext < _ichLim || EnsureMore())
+ {
+ Contracts.Assert(_ichNext < _ichLim);
+ return ChCur = _buffer[_ichNext++];
+ }
+
+ Contracts.Assert(_fNoMore);
+ Eof = true;
+ _ichNext++; // This is so the final IchCur is reported correctly.
+ return ChCur = '\x00';
+ }
+
+ public char ChPeek(int dich)
+ {
+ // If someone is peeking at ich, they should have peeked everything up to ich.
+ Contracts.Assert(0 < dich && dich <= _ichLim - _ichNext + 1);
+
+ int ich = dich + _ichNext - 1;
+ if (ich < _ichLim)
+ return _buffer[ich];
+ if (EnsureMore())
+ {
+ ich = dich + _ichNext - 1;
+ Contracts.Assert(ich < _ichLim);
+ return _buffer[ich];
+ }
+
+ Contracts.Assert(_fNoMore);
+ return '\x00';
+ }
+
+ private bool EnsureMore()
+ {
+ if (_fNoMore)
+ return false;
+
+ if (_ichNext > 0)
+ {
+ int ichDst = 0;
+ int ichSrc = _ichNext;
+ while (ichSrc < _ichLim)
+ _buffer[ichDst++] = _buffer[ichSrc++];
+ _ichBase += _ichNext;
+ _ichNext = 0;
+ _ichLim = ichDst;
+ }
+
+ int ichLim = _ichLim;
+
+ Contracts.Assert(ichLim == _ichLim);
+ _fNoMore = true;
+ return false;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/Expression/CodeGen.cs b/src/Microsoft.ML.Transforms/Expression/CodeGen.cs
new file mode 100644
index 0000000000..f5ebef5488
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/CodeGen.cs
@@ -0,0 +1,1434 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Reflection.Emit;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ using BL = System.Boolean;
+ using I4 = System.Int32;
+ using I8 = System.Int64;
+ using R4 = Single;
+ using R8 = Double;
+ using TX = ReadOnlyMemory;
+
+ internal sealed partial class LambdaCompiler : IDisposable
+ {
+ public const int MaxParams = 16;
+
+ private LambdaNode _top;
+ private Type _delType;
+ private MethodGenerator _meth;
+
+ public static Delegate Compile(out List errors, LambdaNode node)
+ {
+ Contracts.CheckValue(node, nameof(node));
+
+ using (var cmp = new LambdaCompiler(node))
+ return cmp.Do(out errors);
+ }
+
+ private LambdaCompiler(LambdaNode node)
+ {
+ Contracts.AssertValue(node);
+ Contracts.Assert(1 <= node.Vars.Length & node.Vars.Length <= MaxParams);
+ Contracts.Assert(MaxParams <= 16);
+ Contracts.AssertValue(node.ResultType);
+
+ _top = node;
+
+ Type typeFn;
+ switch (node.Vars.Length)
+ {
+ case 1:
+ { typeFn = typeof(Func<,>); break; }
+ case 2:
+ { typeFn = typeof(Func<,,>); break; }
+ case 3:
+ { typeFn = typeof(Func<,,,>); break; }
+ case 4:
+ { typeFn = typeof(Func<,,,,>); break; }
+ case 5:
+ { typeFn = typeof(Func<,,,,,>); break; }
+ default:
+ throw Contracts.Except("Internal error in LambdaCompiler: Maximum number of inputs exceeded.");
+ }
+
+ var types = new Type[node.Vars.Length + 1];
+ foreach (var v in node.Vars)
+ {
+ Contracts.Assert(0 <= v.Index & v.Index < node.Vars.Length);
+ Contracts.Assert(types[v.Index] == null);
+ types[v.Index] = v.Type.RawType;
+ }
+ types[node.Vars.Length] = node.ResultType.RawType;
+ _delType = typeFn.MakeGenericType(types);
+
+ Array.Copy(types, 0, types, 1, types.Length - 1);
+ types[0] = typeof(object);
+
+ _meth = new MethodGenerator("lambda", typeof(Exec), node.ResultType.RawType, types);
+ }
+
+ private Delegate Do(out List errors)
+ {
+ var visitor = new Visitor(_meth);
+ _top.Expr.Accept(visitor);
+
+ errors = visitor.GetErrors();
+ if (errors != null)
+ return null;
+
+ _meth.Il.Ret();
+ return _meth.CreateDelegate(_delType);
+ }
+
+ public void Dispose()
+ {
+ _meth.Dispose();
+ }
+ }
+
+ internal sealed partial class LambdaCompiler
+ {
+ private sealed class Visitor : ExprVisitor
+ {
+ private MethodGenerator _meth;
+ private ILGenerator _gen;
+ private List _errors;
+ private readonly MethodInfo _methGetFalseBL;
+ private readonly MethodInfo _methGetTrueBL;
+
+ private sealed class CachedWithLocal
+ {
+ public readonly WithLocalNode Node;
+ ///
+ /// The IL local containing the computed value.
+ ///
+ public readonly LocalBuilder Value;
+ ///
+ /// The boolean local indicating whether the value has been computed yet.
+ /// If the value is pre-computed, this is null.
+ ///
+ public readonly LocalBuilder Flag;
+
+ public CachedWithLocal(WithLocalNode node, LocalBuilder value, LocalBuilder flag)
+ {
+ Contracts.AssertValue(node);
+ Contracts.AssertValue(value);
+ Contracts.AssertValueOrNull(flag);
+ Node = node;
+ Value = value;
+ Flag = flag;
+ }
+ }
+
+ // The active cached "with" locals. For "with" local values that are aggressively computed, the
+ // corresponding flag local is null. For lazy computed values, the flag indicates whether
+ // the value has been computed and stored yet. Lazy computed values avoid potentially
+ // expensive computation that might not be needed, but result in code bloat since each
+ // use tests the flag, and if false, computes and stores the value.
+ private List _cacheWith;
+
+ public Visitor(MethodGenerator meth)
+ {
+ _meth = meth;
+ _gen = meth.Il;
+
+ Func f = BuiltinFunctions.False;
+ Contracts.Assert(f.Target == null);
+ _methGetFalseBL = f.GetMethodInfo();
+
+ Func t = BuiltinFunctions.True;
+ Contracts.Assert(t.Target == null);
+ _methGetTrueBL = t.GetMethodInfo();
+
+ _cacheWith = new List();
+ }
+
+ public List GetErrors()
+ {
+ return _errors;
+ }
+
+ private void DoConvert(ExprNode node)
+ {
+ Contracts.AssertValue(node);
+ if (!node.NeedsConversion)
+ {
+ // When dealing with floating point, always emit the conversion so
+ // the result is stable.
+ if (node.IsR4)
+ _gen.Conv_R4();
+ else if (node.IsR8)
+ _gen.Conv_R8();
+ return;
+ }
+
+ switch (node.SrcKind)
+ {
+ default:
+ Contracts.Assert(false, "Unexpected src kind in DoConvert");
+ PostError(node, "Internal error in implicit conversion");
+ break;
+
+ case ExprTypeKind.R4:
+ // R4 will only implicitly convert to R8.
+ if (node.IsR8)
+ {
+ _gen.Conv_R8();
+ return;
+ }
+ break;
+ case ExprTypeKind.I4:
+ // I4 can convert to I8, R4, or R8.
+ switch (node.ExprType)
+ {
+ case ExprTypeKind.I8:
+ _gen.Conv_I8();
+ return;
+ case ExprTypeKind.R4:
+ _gen.Conv_R4();
+ return;
+ case ExprTypeKind.R8:
+ _gen.Conv_R8();
+ return;
+ }
+ break;
+ case ExprTypeKind.I8:
+ // I8 will only implicitly convert to R8.
+ if (node.IsR8)
+ {
+ _gen.Conv_R8();
+ return;
+ }
+ break;
+ }
+
+ Contracts.Assert(false, "Unexpected dst kind in DoConvert");
+ PostError(node, "Internal error(2) in implicit conversion");
+ }
+
+ private void PostError(Node node)
+ {
+ Utils.Add(ref _errors, new Error(node.Token, "Code generation error"));
+ }
+
+ private void PostError(Node node, string msg)
+ {
+ Utils.Add(ref _errors, new Error(node.Token, msg));
+ }
+
+ private void PostError(Node node, string msg, params object[] args)
+ {
+ Utils.Add(ref _errors, new Error(node.Token, msg, args));
+ }
+
+ private bool TryUseValue(ExprNode node)
+ {
+ var value = node.ExprValue;
+ if (value == null)
+ return false;
+
+ switch (node.ExprType)
+ {
+ case ExprTypeKind.BL:
+ Contracts.Assert(value is BL);
+ GenBL((BL)value);
+ break;
+ case ExprTypeKind.I4:
+ Contracts.Assert(value is I4);
+ _gen.Ldc_I4((I4)value);
+ break;
+ case ExprTypeKind.I8:
+ Contracts.Assert(value is I8);
+ _gen.Ldc_I8((I8)value);
+ break;
+ case ExprTypeKind.R4:
+ Contracts.Assert(value is R4);
+ _gen.Ldc_R4((R4)value);
+ break;
+ case ExprTypeKind.R8:
+ Contracts.Assert(value is R8);
+ _gen.Ldc_R8((R8)value);
+ break;
+ case ExprTypeKind.TX:
+ {
+ Contracts.Assert(value is TX);
+ TX text = (TX)value;
+ _gen.Ldstr(text.ToString());
+ CallFnc(Exec.ToTX);
+ }
+ break;
+
+ case ExprTypeKind.Error:
+ PostError(node);
+ break;
+
+ default:
+ PostError(node, "Bad ExprType");
+ break;
+ }
+
+ return true;
+ }
+
+ public override void Visit(BoolLitNode node)
+ {
+ Contracts.Assert(node.IsBool);
+ Contracts.Assert(node.ExprValue is BL);
+ GenBL((BL)node.ExprValue);
+ }
+
+ public override void Visit(StrLitNode node)
+ {
+ Contracts.Assert(node.IsTX);
+ Contracts.Assert(node.ExprValue is TX);
+ TX text = (TX)node.ExprValue;
+
+ _gen.Ldstr(text.ToString());
+ CallFnc(Exec.ToTX);
+ }
+
+ public override void Visit(NumLitNode node)
+ {
+ Contracts.Assert(node.IsNumber);
+ var value = node.ExprValue;
+ Contracts.Assert(value != null);
+ switch (node.ExprType)
+ {
+ case ExprTypeKind.I4:
+ Contracts.Assert(value is I4);
+ _gen.Ldc_I4((I4)value);
+ break;
+ case ExprTypeKind.I8:
+ Contracts.Assert(value is I8);
+ _gen.Ldc_I8((I8)value);
+ break;
+ case ExprTypeKind.R4:
+ Contracts.Assert(value is R4);
+ _gen.Ldc_R4((R4)value);
+ break;
+ case ExprTypeKind.R8:
+ Contracts.Assert(value is R8);
+ _gen.Ldc_R8((R8)value);
+ break;
+ default:
+ Contracts.Assert(false, "Bad NumLitNode");
+ PostError(node, "Internal error in numeric literal");
+ break;
+ }
+ }
+
+ public override void Visit(IdentNode node)
+ {
+ if (TryUseValue(node))
+ return;
+
+ Node referent = node.Referent;
+ if (node.Referent == null)
+ {
+ PostError(node, "Unbound name!");
+ return;
+ }
+
+ switch (referent.Kind)
+ {
+ default:
+ PostError(node, "Unbound name!");
+ return;
+
+ case NodeKind.Param:
+ _gen.Ldarg(referent.AsParam.Index + 1);
+ break;
+
+ case NodeKind.WithLocal:
+ var loc = referent.AsWithLocal;
+ Contracts.Assert(loc.Value.ExprValue == null);
+ Contracts.Assert(loc.GenCount >= 0);
+
+ if (loc.UseCount <= 1)
+ {
+ Contracts.Assert(loc.UseCount == 1);
+ Contracts.Assert(loc.Index == -1);
+ loc.GenCount++;
+ loc.Value.Accept(this);
+ }
+ else
+ {
+ Contracts.Assert(0 <= loc.Index & loc.Index < _cacheWith.Count);
+ var cache = _cacheWith[loc.Index];
+ Contracts.Assert(cache.Value != null);
+ if (cache.Flag != null)
+ {
+ // This is a lazy computed value. If we've already computed the value, skip the code
+ // that generates it. If this is the first place that generates it, we don't need to
+ // test the bool - we know it hasn't been computed yet (since we never jump backwards).
+ bool needTest = loc.GenCount > 0;
+ Label labHave = default(Label);
+ if (needTest)
+ {
+ labHave = _gen.DefineLabel();
+ _gen
+ .Ldloc(cache.Flag)
+ .Brtrue(labHave);
+ }
+
+ // Generate the code for the value.
+ loc.GenCount++;
+ loc.Value.Accept(this);
+
+ // Store the value and set the flag indicating that we have it.
+ _gen
+ .Stloc(cache.Value)
+ .Ldc_I4(1)
+ .Stloc(cache.Flag);
+ if (needTest)
+ _gen.MarkLabel(labHave);
+ }
+
+ // Load the value.
+ _gen.Ldloc(cache.Value);
+ }
+ break;
+ }
+ DoConvert(node);
+ }
+
+ public override bool PreVisit(UnaryOpNode node)
+ {
+ if (TryUseValue(node))
+ return false;
+ return true;
+ }
+
+ public override void PostVisit(UnaryOpNode node)
+ {
+ Contracts.AssertValue(node);
+
+ switch (node.Op)
+ {
+ default:
+ Contracts.Assert(false, "Bad unary op");
+ PostError(node, "Internal error in unary operator");
+ break;
+
+ case UnaryOp.Minus:
+ Contracts.Assert(node.IsNumber);
+ Contracts.Assert(node.Arg.ExprType == node.SrcKind);
+ switch (node.SrcKind)
+ {
+ case ExprTypeKind.I4:
+ _gen.Neg();
+ break;
+ case ExprTypeKind.I8:
+ _gen.Neg();
+ break;
+ case ExprTypeKind.R4:
+ case ExprTypeKind.R8:
+ _gen.Neg();
+ break;
+
+ default:
+ Contracts.Assert(false, "Bad operand type in unary minus");
+ PostError(node, "Internal error in unary minus");
+ break;
+ }
+ break;
+
+ case UnaryOp.Not:
+ CallFnc(BuiltinFunctions.Not);
+ break;
+ }
+
+ DoConvert(node);
+ }
+
+ public override bool PreVisit(BinaryOpNode node)
+ {
+ Contracts.AssertValue(node);
+
+ if (TryUseValue(node))
+ return false;
+
+ if (node.ReduceToLeft)
+ {
+ node.Left.Accept(this);
+ DoConvert(node);
+ return false;
+ }
+
+ if (node.ReduceToRight)
+ {
+ node.Right.Accept(this);
+ DoConvert(node);
+ return false;
+ }
+
+ switch (node.Op)
+ {
+ default:
+ Contracts.Assert(false, "Bad binary op");
+ PostError(node, "Internal error in binary operator");
+ break;
+
+ case BinaryOp.Coalesce:
+ GenCoalesce(node);
+ break;
+
+ case BinaryOp.Or:
+ case BinaryOp.And:
+ GenBoolBinOp(node);
+ break;
+
+ case BinaryOp.Add:
+ case BinaryOp.Sub:
+ case BinaryOp.Mul:
+ case BinaryOp.Div:
+ case BinaryOp.Mod:
+ case BinaryOp.Power:
+ GenNumBinOp(node);
+ break;
+
+ case BinaryOp.Error:
+ PostError(node);
+ break;
+ }
+
+ DoConvert(node);
+ return false;
+ }
+
+ private void GenBoolBinOp(BinaryOpNode node)
+ {
+ Contracts.AssertValue(node);
+ Contracts.Assert(node.Op == BinaryOp.Or || node.Op == BinaryOp.And);
+ Contracts.Assert(node.SrcKind == ExprTypeKind.BL);
+ Contracts.Assert(node.Left.IsBool);
+ Contracts.Assert(node.Right.IsBool);
+
+ // This does naive code gen for short-circuiting binary bool operators.
+ // Ideally, this would cooperate with comparisons to produce better code gen.
+ // However, this is merely an optimization issue, not correctness, so possibly
+ // not worth the additional complexity.
+
+ Label labEnd = _gen.DefineLabel();
+
+ node.Left.Accept(this);
+ _gen.Dup();
+
+ if (node.Op == BinaryOp.Or)
+ _gen.Brtrue(labEnd);
+ else
+ _gen.Brfalse(labEnd);
+
+ _gen.Pop();
+ node.Right.Accept(this);
+
+ _gen.Br(labEnd);
+ _gen.MarkLabel(labEnd);
+ }
+
+ private void GenNumBinOp(BinaryOpNode node)
+ {
+ Contracts.AssertValue(node);
+
+ // Note that checking for special known values like NA and identity values
+ // is done in the binder and handled in PreVisit(BinaryOpNode).
+ node.Left.Accept(this);
+ node.Right.Accept(this);
+
+ if (node.SrcKind == ExprTypeKind.I4)
+ {
+ Contracts.Assert(node.Left.IsI4);
+ Contracts.Assert(node.Right.IsI4);
+ switch (node.Op)
+ {
+ default:
+ Contracts.Assert(false, "Bad numeric bin op");
+ PostError(node, "Internal error in numeric binary operator");
+ break;
+
+ case BinaryOp.Add:
+ _gen.Add();
+ break;
+ case BinaryOp.Sub:
+ _gen.Sub();
+ break;
+ case BinaryOp.Mul:
+ _gen.Mul_Ovf();
+ break;
+ case BinaryOp.Div:
+ _gen.Div();
+ break;
+ case BinaryOp.Mod:
+ _gen.Rem();
+ break;
+ case BinaryOp.Power:
+ CallBin(BuiltinFunctions.Pow);
+ break;
+ }
+ }
+ else if (node.SrcKind == ExprTypeKind.I8)
+ {
+ Contracts.Assert(node.Left.IsI8);
+ Contracts.Assert(node.Right.IsI8);
+ switch (node.Op)
+ {
+ default:
+ Contracts.Assert(false, "Bad numeric bin op");
+ PostError(node, "Internal error in numeric binary operator");
+ break;
+
+ case BinaryOp.Add:
+ _gen.Add();
+ break;
+ case BinaryOp.Sub:
+ _gen.Sub();
+ break;
+ case BinaryOp.Mul:
+ _gen.Mul();
+ break;
+ case BinaryOp.Div:
+ _gen.Div();
+ break;
+ case BinaryOp.Mod:
+ _gen.Rem();
+ break;
+ case BinaryOp.Power:
+ CallBin(BuiltinFunctions.Pow);
+ break;
+ }
+ }
+ else if (node.SrcKind == ExprTypeKind.R4)
+ {
+ Contracts.Assert(node.Left.IsR4);
+ Contracts.Assert(node.Right.IsR4);
+ switch (node.Op)
+ {
+ default:
+ Contracts.Assert(false, "Bad numeric bin op");
+ PostError(node, "Internal error in numeric binary operator");
+ break;
+
+ case BinaryOp.Add:
+ _gen.Add();
+ break;
+ case BinaryOp.Sub:
+ _gen.Sub();
+ break;
+ case BinaryOp.Mul:
+ _gen.Mul();
+ break;
+ case BinaryOp.Div:
+ _gen.Div();
+ break;
+ case BinaryOp.Mod:
+ _gen.Rem();
+ break;
+ case BinaryOp.Power:
+ CallBin(BuiltinFunctions.Pow);
+ break;
+ }
+ }
+ else
+ {
+ Contracts.Assert(node.SrcKind == ExprTypeKind.R8);
+ Contracts.Assert(node.Left.IsR8);
+ Contracts.Assert(node.Right.IsR8);
+ switch (node.Op)
+ {
+ default:
+ Contracts.Assert(false, "Bad numeric bin op");
+ PostError(node, "Internal error in numeric binary operator");
+ break;
+
+ case BinaryOp.Add:
+ _gen.Add();
+ break;
+ case BinaryOp.Sub:
+ _gen.Sub();
+ break;
+ case BinaryOp.Mul:
+ _gen.Mul();
+ break;
+ case BinaryOp.Div:
+ _gen.Div();
+ break;
+ case BinaryOp.Mod:
+ _gen.Rem();
+ break;
+ case BinaryOp.Power:
+ CallBin(Math.Pow);
+ break;
+ }
+ }
+ }
+
+ private void GenBL(BL value)
+ {
+ MethodInfo meth;
+ if (!value)
+ meth = _methGetFalseBL;
+ else
+ meth = _methGetTrueBL;
+ _gen.Call(meth);
+ }
+
+ private void CallFnc(Func fn)
+ {
+ _gen.Call(fn.GetMethodInfo());
+ }
+
+ private void CallFnc(Func fn)
+ {
+ _gen.Call(fn.GetMethodInfo());
+ }
+
+ private void CallBin(Func fn)
+ {
+ _gen.Call(fn.GetMethodInfo());
+ }
+
+ private void GenCoalesce(BinaryOpNode node)
+ {
+ Contracts.AssertValue(node);
+ Contracts.Assert(node.Op == BinaryOp.Coalesce);
+
+ // If left is a constant, then the binder should have dealt with it!
+ Contracts.Assert(node.Left.ExprValue == null);
+
+ Label labEnd = _gen.DefineLabel();
+
+ // Branch to end if the left operand is NOT NA.
+ node.Left.Accept(this);
+ GenBrNa(node.Left, labEnd);
+
+ _gen.Pop();
+ node.Right.Accept(this);
+ _gen.MarkLabel(labEnd);
+ }
+
+ public override void PostVisit(BinaryOpNode node)
+ {
+ Contracts.Assert(false);
+ }
+
+ public override bool PreVisit(ConditionalNode node)
+ {
+ Contracts.AssertValue(node);
+
+ if (TryUseValue(node))
+ return false;
+
+ var cond = (BL?)node.Cond.ExprValue;
+ if (cond != null)
+ {
+ if (cond.Value)
+ node.Left.Accept(this);
+ else
+ node.Right.Accept(this);
+ goto LDone;
+ }
+
+ Label labEnd = _gen.DefineLabel();
+ Label labFalse = _gen.DefineLabel();
+
+ node.Cond.Accept(this);
+ _gen.Brfalse(labFalse);
+
+ // Left is the "true" branch.
+ node.Left.Accept(this);
+ _gen.Br(labEnd)
+ .MarkLabel(labFalse);
+
+ node.Right.Accept(this);
+ _gen.Br(labEnd);
+ _gen.MarkLabel(labEnd);
+
+ LDone:
+ DoConvert(node);
+ return false;
+ }
+
+ public override void PostVisit(ConditionalNode node)
+ {
+ Contracts.Assert(false);
+ }
+
+ public override bool PreVisit(CompareNode node)
+ {
+ Contracts.AssertValue(node);
+ Contracts.Assert(node.Operands.Items.Length >= 2);
+
+ if (TryUseValue(node))
+ return false;
+
+ ExprTypeKind kind = node.ArgTypeKind;
+ Node[] items = node.Operands.Items;
+ if (kind == ExprTypeKind.TX && items.Length == 2)
+ {
+ // Two value text comparison is handled by methods.
+ items[0].Accept(this);
+ items[1].Accept(this);
+ switch (node.Op)
+ {
+ default:
+ Contracts.Assert(false, "Bad bool compare op");
+ break;
+
+ case CompareOp.Equal:
+ CallFnc(BuiltinFunctions.Equals);
+ break;
+ case CompareOp.NotEqual:
+ CallFnc(BuiltinFunctions.NotEquals);
+ break;
+ }
+
+ DoConvert(node);
+ return false;
+ }
+
+ Label labEnd = _gen.DefineLabel();
+ if (items.Length == 2)
+ {
+ // Common case of two operands. Note that the binder should have handled the case when
+ // one or both is a constant NA.
+
+ ExprNode arg;
+ GenRaw(arg = items[0].AsExpr);
+ Contracts.Assert(arg.ExprType == kind);
+ GenRaw(arg = items[1].AsExpr);
+ Contracts.Assert(arg.ExprType == kind);
+
+ TokKind tid = node.Operands.Delimiters[0].Kind;
+ Contracts.Assert(tid == node.TidLax || tid == node.TidStrict);
+ var isStrict = tid == node.TidStrict;
+ switch (kind)
+ {
+ case ExprTypeKind.BL:
+ GenCmpBool(node.Op, isStrict);
+ break;
+ case ExprTypeKind.I4:
+ case ExprTypeKind.I8:
+ GenCmpInt(node.Op, isStrict);
+ break;
+ case ExprTypeKind.R4:
+ case ExprTypeKind.R8:
+ GenCmpFloat(node.Op, isStrict);
+ break;
+
+ default:
+ PostError(node, "Compare codegen for this comparison is NYI");
+ return false;
+ }
+ }
+ else
+ {
+ // For more than two items, we use branching instructions instead of ceq, clt, cgt, etc.
+ Contracts.Assert(items.Length > 2);
+
+ // Get the comparison generation function and the (raw) local type.
+ Action fnc;
+ Type typeLoc;
+ switch (kind)
+ {
+ case ExprTypeKind.BL:
+ fnc = GenCmpBool;
+ typeLoc = typeof(byte);
+ break;
+ case ExprTypeKind.TX:
+ fnc = GenCmpText;
+ typeLoc = typeof(TX);
+ break;
+ case ExprTypeKind.I4:
+ fnc = GenCmpInt;
+ typeLoc = typeof(int);
+ break;
+ case ExprTypeKind.I8:
+ fnc = GenCmpInt;
+ typeLoc = typeof(long);
+ break;
+ case ExprTypeKind.R4:
+ fnc = GenCmpFloat;
+ typeLoc = typeof(R4);
+ break;
+ case ExprTypeKind.R8:
+ fnc = GenCmpFloat;
+ typeLoc = typeof(R8);
+ break;
+
+ default:
+ PostError(node, "Compare codegen for this comparison is NYI");
+ return false;
+ }
+
+ Label labFalse = _gen.DefineLabel();
+ if (node.Op != CompareOp.NotEqual)
+ {
+ // Note: this loop doesn't work for != so it is handled separately below.
+ ExprNode arg = items[0].AsExpr;
+ Contracts.Assert(arg.ExprType == kind);
+
+ GenRaw(arg = items[0].AsExpr);
+ Contracts.Assert(arg.ExprType == kind);
+
+ for (int i = 1; ; i++)
+ {
+ TokKind tid = node.Operands.Delimiters[i - 1].Kind;
+ Contracts.Assert(tid == node.TidLax || tid == node.TidStrict);
+ var isStrict = tid == node.TidStrict;
+
+ arg = items[i].AsExpr;
+ Contracts.Assert(arg.ExprType == kind);
+ GenRaw(arg);
+
+ if (i == items.Length - 1)
+ {
+ // Last one.
+ fnc(node.Op, isStrict, labFalse);
+ break;
+ }
+
+ // We'll need this value again, so stash it in a local.
+ _gen.Dup();
+ using (var local = _meth.AcquireTemporary(typeLoc))
+ {
+ _gen.Stloc(local.Local);
+ fnc(node.Op, isStrict, labFalse);
+ _gen.Ldloc(local.Local);
+ }
+ }
+ }
+ else
+ {
+ // NotEqual is special - it means that the values are all distinct, so comparing adjacent
+ // items is not enough.
+ Contracts.Assert(node.Op == CompareOp.NotEqual & items.Length > 2);
+
+ // We need a local for each item.
+ var locals = new MethodGenerator.Temporary[items.Length];
+ for (int i = 0; i < locals.Length; i++)
+ locals[i] = _meth.AcquireTemporary(typeLoc);
+ try
+ {
+ ExprNode arg = items[0].AsExpr;
+ Contracts.Assert(arg.ExprType == kind);
+
+ GenRaw(arg);
+ _gen.Stloc(locals[0].Local);
+
+ for (int i = 1; i < items.Length; i++)
+ {
+ // Need to evaluate the expression and store it in the local.
+ arg = items[i].AsExpr;
+ Contracts.Assert(arg.ExprType == kind);
+ GenRaw(arg);
+ _gen.Stloc(locals[i].Local);
+
+ for (int j = 0; j < i; j++)
+ {
+ _gen.Ldloc(locals[j].Local)
+ .Ldloc(locals[i].Local);
+ fnc(node.Op, true, labFalse);
+ }
+ }
+ }
+ finally
+ {
+ for (int i = locals.Length; --i >= 0;)
+ locals[i].Dispose();
+ }
+ }
+
+ _gen.Call(_methGetTrueBL)
+ .Br(labEnd);
+
+ _gen.MarkLabel(labFalse);
+ _gen.Call(_methGetFalseBL)
+ .Br(labEnd);
+ }
+
+ _gen.MarkLabel(labEnd);
+
+ DoConvert(node);
+ return false;
+ }
+
+ ///
+ /// Get the raw bits from an expression node. If the node is constant, this avoids the
+ /// silly "convert to dv type" followed by "extract raw bits". Returns whether the expression
+ /// is a constant NA, with null meaning "don't know".
+ ///
+ private void GenRaw(ExprNode node)
+ {
+ Contracts.AssertValue(node);
+
+ var val = node.ExprValue;
+ if (val != null)
+ {
+ switch (node.ExprType)
+ {
+ case ExprTypeKind.BL:
+ {
+ var x = (BL)val;
+ _gen.Ldc_I4(x ? 1 : 0);
+ return;
+ }
+ case ExprTypeKind.I4:
+ {
+ var x = (I4)val;
+ _gen.Ldc_I4(x);
+ return;
+ }
+ case ExprTypeKind.I8:
+ {
+ var x = (I8)val;
+ _gen.Ldc_I8(x);
+ return;
+ }
+ case ExprTypeKind.R4:
+ {
+ var x = (R4)val;
+ _gen.Ldc_R4(x);
+ return;
+ }
+ case ExprTypeKind.R8:
+ {
+ var x = (R8)val;
+ _gen.Ldc_R8(x);
+ return;
+ }
+ case ExprTypeKind.TX:
+ {
+ var x = (TX)val;
+ _gen.Ldstr(x.ToString());
+ CallFnc(Exec.ToTX);
+ return;
+ }
+ }
+ }
+
+ node.Accept(this);
+ }
+
+ ///
+ /// Generate code to branch to labNa if the top stack element is NA.
+ /// Note that this leaves the element on the stack (duplicates before comparing).
+ /// If rev is true, this branches when NOT NA.
+ ///
+ private void GenBrNa(ExprNode node, Label labNa, bool dup = true)
+ {
+ GenBrNaCore(node, node.ExprType, labNa, dup);
+ }
+
+ ///
+ /// Generate code to branch to labNa if the top stack element is NA.
+ /// If dup is true, this leaves the element on the stack (duplicates before comparing).
+ /// If rev is true, this branches when NOT NA.
+ ///
+ private void GenBrNaCore(ExprNode node, ExprTypeKind kind, Label labNa, bool dup)
+ {
+ if (dup)
+ _gen.Dup();
+
+ switch (kind)
+ {
+ case ExprTypeKind.R4:
+ case ExprTypeKind.R8:
+ // Any value that is not equal to itself is an NA.
+ _gen.Dup();
+ _gen.Beq(labNa);
+ break;
+ case ExprTypeKind.Error:
+ case ExprTypeKind.None:
+ Contracts.Assert(false, "Bad expr kind in GenBrNa");
+ PostError(node, "Internal error in GenBrNa");
+ break;
+ }
+ }
+
+ ///
+ /// Generate a bool from comparing the raw bits. The values are guaranteed to not be NA.
+ ///
+ private void GenCmpBool(CompareOp op, bool isStrict)
+ {
+ switch (op)
+ {
+ default:
+ Contracts.Assert(false, "Bad bool compare op");
+ break;
+
+ case CompareOp.Equal:
+ _gen.Ceq();
+ break;
+ case CompareOp.NotEqual:
+ _gen.Xor();
+ break;
+ }
+ }
+
+ ///
+ /// Generate a bool from comparing the raw bits. The values are guaranteed to not be NA.
+ ///
+ private void GenCmpInt(CompareOp op, bool isStrict)
+ {
+ switch (op)
+ {
+ default:
+ Contracts.Assert(false, "Bad compare op");
+ break;
+
+ case CompareOp.Equal:
+ _gen.Ceq();
+ break;
+ case CompareOp.NotEqual:
+ _gen.Ceq().Ldc_I4(0).Ceq();
+ break;
+ case CompareOp.DecrChain:
+ if (isStrict)
+ _gen.Cgt();
+ else
+ _gen.Clt().Ldc_I4(0).Ceq();
+ break;
+ case CompareOp.IncrChain:
+ if (isStrict)
+ _gen.Clt();
+ else
+ _gen.Cgt().Ldc_I4(0).Ceq();
+ break;
+ }
+ }
+
+ ///
+ /// Generate a bool from comparing the raw bits. The values are guaranteed to not be NA.
+ ///
+ private void GenCmpFloat(CompareOp op, bool isStrict)
+ {
+ switch (op)
+ {
+ default:
+ Contracts.Assert(false, "Bad compare op");
+ break;
+
+ case CompareOp.Equal:
+ _gen.Ceq();
+ break;
+ case CompareOp.NotEqual:
+ _gen.Ceq().Ldc_I4(0).Ceq();
+ break;
+ case CompareOp.DecrChain:
+ if (isStrict)
+ _gen.Cgt();
+ else
+ _gen.Clt_Un().Ldc_I4(0).Ceq();
+ break;
+ case CompareOp.IncrChain:
+ if (isStrict)
+ _gen.Clt();
+ else
+ _gen.Cgt_Un().Ldc_I4(0).Ceq();
+ break;
+ }
+ }
+
+ private void GenCmpBool(CompareOp op, bool isStrict, Label labFalse)
+ {
+ switch (op)
+ {
+ default:
+ Contracts.Assert(false, "Bad bool compare op");
+ break;
+
+ case CompareOp.Equal:
+ _gen.Bne_Un(labFalse);
+ break;
+ case CompareOp.NotEqual:
+ _gen.Beq(labFalse);
+ break;
+ }
+ }
+
+ private void GenCmpText(CompareOp op, bool isStrict, Label labFalse)
+ {
+ // Note that NA values don't come through here, so we don't need NA propagating equality comparison.
+ switch (op)
+ {
+ default:
+ Contracts.Assert(false, "Bad bool compare op");
+ break;
+
+ case CompareOp.Equal:
+ CallFnc(BuiltinFunctions.Equals);
+ _gen.Brfalse(labFalse);
+ break;
+ case CompareOp.NotEqual:
+ CallFnc(BuiltinFunctions.Equals);
+ _gen.Brtrue(labFalse);
+ break;
+ }
+ }
+
+ private void GenCmpInt(CompareOp op, bool isStrict, Label labFalse)
+ {
+ switch (op)
+ {
+ default:
+ Contracts.Assert(false, "Bad compare op");
+ break;
+
+ case CompareOp.Equal:
+ _gen.Bne_Un(labFalse);
+ break;
+ case CompareOp.NotEqual:
+ _gen.Beq(labFalse);
+ break;
+ case CompareOp.DecrChain:
+ if (isStrict)
+ _gen.Ble(labFalse);
+ else
+ _gen.Blt(labFalse);
+ break;
+ case CompareOp.IncrChain:
+ if (isStrict)
+ _gen.Bge(labFalse);
+ else
+ _gen.Bgt(labFalse);
+ break;
+ }
+ }
+
+ private void GenCmpFloat(CompareOp op, bool isStrict, Label labFalse)
+ {
+ switch (op)
+ {
+ default:
+ Contracts.Assert(false, "Bad compare op");
+ break;
+
+ case CompareOp.Equal:
+ _gen.Bne_Un(labFalse);
+ break;
+ case CompareOp.NotEqual:
+ _gen.Beq(labFalse);
+ break;
+ case CompareOp.DecrChain:
+ if (isStrict)
+ _gen.Ble_Un(labFalse);
+ else
+ _gen.Blt_Un(labFalse);
+ break;
+ case CompareOp.IncrChain:
+ if (isStrict)
+ _gen.Bge_Un(labFalse);
+ else
+ _gen.Bgt_Un(labFalse);
+ break;
+ }
+ }
+
+ public override void PostVisit(CompareNode node)
+ {
+ Contracts.Assert(false);
+ }
+
+ public override bool PreVisit(CallNode node)
+ {
+ Contracts.AssertValue(node);
+
+ if (TryUseValue(node))
+ return false;
+
+ if (node.Method == null)
+ {
+ Contracts.Assert(false, "Bad function");
+ PostError(node, "Internal error: unknown function: '{0}'", node.Head.Value);
+ return false;
+ }
+
+ var meth = node.Method;
+ var ps = meth.GetParameters();
+ Type type;
+ if (Utils.Size(ps) > 0 && (type = ps[ps.Length - 1].ParameterType).IsArray)
+ {
+ // Variable case, so can't be identity.
+ Contracts.Assert(node.Method.ReturnType != typeof(void));
+
+ // Get the item type of the array.
+ type = type.GetElementType();
+
+ var args = node.Args.Items;
+ int head = ps.Length - 1;
+ int tail = node.Args.Items.Length - head;
+ Contracts.Assert(tail >= 0);
+
+ // Generate the "head" args.
+ for (int i = 0; i < head; i++)
+ args[i].Accept(this);
+
+ // Bundle the "tail" args into an array.
+ _gen.Ldc_I4(tail)
+ .Newarr(type);
+ for (int i = 0; i < tail; i++)
+ {
+ _gen.Dup()
+ .Ldc_I4(i);
+ args[head + i].Accept(this);
+ _gen.Stelem(type);
+ }
+
+ // Make the call.
+ _gen.Call(node.Method);
+ }
+ else
+ {
+ Contracts.Assert(Utils.Size(ps) == node.Args.Items.Length);
+ node.Args.Accept(this);
+
+ // An identity function is marked with a void return type.
+ if (node.Method.ReturnType != typeof(void))
+ _gen.Call(node.Method);
+ else
+ Contracts.Assert(node.Args.Items.Length == 1);
+ }
+
+ DoConvert(node);
+ return false;
+ }
+
+ public override void PostVisit(CallNode node)
+ {
+ Contracts.Assert(false);
+ }
+
+ public override void PostVisit(ListNode node)
+ {
+ Contracts.AssertValue(node);
+ }
+
+ public override bool PreVisit(WithNode node)
+ {
+ Contracts.AssertValue(node);
+
+ var local = node.Local;
+ Contracts.Assert(local.Index == -1);
+ Contracts.Assert(local.UseCount >= 0);
+
+ if (local.Value.ExprValue != null || local.UseCount <= 1)
+ {
+ // In this case, simply inline the code generation, no need
+ // to cache the value in an IL local.
+ node.Body.Accept(this);
+ Contracts.Assert(local.Index == -1);
+ }
+ else
+ {
+ // REVIEW: What's a reasonable value? This allows binary uses of 7 locals.
+ // This should cover most cases, but allows a rather large bloat factor.
+ const int maxTotalUse = 128;
+
+ // This case uses a cache value. When lazy, it also keeps a bool flag indicating
+ // whether the value has been computed and stored in the cache yet.
+ int index = _cacheWith.Count;
+
+ // Lazy can bloat code gen exponentially. This test decides whether to be lazy for this
+ // particular local, based on its use count and nesting. This assumes the worst case,
+ // that each lazy value is used by the next lazy value the full UseCount number of times.
+ // REVIEW: We should try to do better at some point.... Strictness analysis would
+ // solve this, but is non-trivial to implement.
+ bool lazy = true;
+ long totalUse = local.UseCount;
+ if (totalUse > maxTotalUse)
+ lazy = false;
+ else
+ {
+ for (int i = index; --i >= 0;)
+ {
+ var item = _cacheWith[i];
+ Contracts.Assert(item.Node.UseCount >= 2);
+ if (item.Flag == null)
+ continue;
+ totalUse *= item.Node.UseCount;
+ if (totalUse > maxTotalUse)
+ {
+ lazy = false;
+ break;
+ }
+ }
+ }
+
+ // This risks code gen
+ // bloat but avoids unnecessary computation. Perhaps we should determine whether the
+ // value is always needed. However, this can be quite complicated, requiring flow
+ // analysis through all expression kinds.
+
+ // REVIEW: Should we always make the code generation lazy? This risks code gen
+ // bloat but avoids unnecessary computation. Perhaps we should determine whether the
+ // value is always needed. However, this can be quite complicated, requiring flow
+ // analysis through all expression kinds.
+ using (var value = _meth.AcquireTemporary(ExprNode.ToSysType(local.Value.ExprType)))
+ using (var flag = lazy ? _meth.AcquireTemporary(typeof(bool)) : default(MethodGenerator.Temporary))
+ {
+ LocalBuilder flagBldr = flag.Local;
+ Contracts.Assert((flagBldr != null) == lazy);
+
+ if (lazy)
+ {
+ _gen
+ .Ldc_I4(0)
+ .Stloc(flagBldr);
+ }
+ else
+ {
+ local.Value.Accept(this);
+ _gen.Stloc(value.Local);
+ }
+
+ // Activate the cache item.
+ var cache = new CachedWithLocal(local, value.Local, flag.Local);
+ _cacheWith.Add(cache);
+ Contracts.Assert(_cacheWith.Count == index + 1);
+
+ // Generate the code for the body.
+ local.Index = index;
+ node.Body.Accept(this);
+ Contracts.Assert(local.Index == index);
+ local.Index = -1;
+
+ // Remove the cache locals.
+ Contracts.Assert(_cacheWith.Count == index + 1);
+ Contracts.Assert(_cacheWith[index] == cache);
+ _cacheWith.RemoveAt(index);
+ }
+ }
+
+#if DEBUG
+ System.Diagnostics.Debug.WriteLine("Generated code '{0}' times for '{1}'", local.GenCount, local);
+#endif
+ return false;
+ }
+
+ public override void PostVisit(WithNode node)
+ {
+ Contracts.Assert(false);
+ }
+
+ public override bool PreVisit(WithLocalNode node)
+ {
+ Contracts.Assert(false);
+ return false;
+ }
+
+ public override void PostVisit(WithLocalNode node)
+ {
+ Contracts.Assert(false);
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Transforms/Expression/Error.cs b/src/Microsoft.ML.Transforms/Expression/Error.cs
new file mode 100644
index 0000000000..e6dc80fea5
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/Error.cs
@@ -0,0 +1,46 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ internal sealed class Error
+ {
+ public readonly Token Token;
+ public readonly string Message;
+ // Args may be null.
+ public readonly object[] Args;
+
+ public Error(Token tok, string msg)
+ {
+ Contracts.AssertValue(tok);
+ Contracts.AssertNonEmpty(msg);
+ Token = tok;
+ Message = msg;
+ Args = null;
+ }
+
+ public Error(Token tok, string msg, params object[] args)
+ {
+ Contracts.AssertValue(tok);
+ Contracts.AssertNonEmpty(msg);
+ Contracts.AssertValue(args);
+ Token = tok;
+ Message = msg;
+ Args = args;
+ }
+
+ public string GetMessage()
+ {
+ var msg = Message;
+ if (Utils.Size(Args) > 0)
+ msg = string.Format(msg, Args);
+ if (Token != null)
+ msg = string.Format("at '{0}': {1}", Token, msg);
+ return msg;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Transforms/Expression/Exec.cs b/src/Microsoft.ML.Transforms/Expression/Exec.cs
new file mode 100644
index 0000000000..cba51e5e18
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/Exec.cs
@@ -0,0 +1,35 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.CompilerServices;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ using TX = ReadOnlyMemory;
+
+ ///
+ /// This class contains static helper methods needed for execution.
+ ///
+ internal sealed class Exec
+ {
+ ///
+ /// Currently this class is not intended to be instantiated. However the methods generated
+ /// by ExprTransform need to be associated with some public type. This one serves that
+ /// purpose as well as containing static helpers.
+ ///
+ private Exec()
+ {
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static TX ToTX(string str)
+ {
+ // We shouldn't allow a null in here.
+ Contracts.AssertValue(str);
+ return str.AsMemory();
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/Expression/FunctionProvider.cs b/src/Microsoft.ML.Transforms/Expression/FunctionProvider.cs
new file mode 100644
index 0000000000..aa898f63ef
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/FunctionProvider.cs
@@ -0,0 +1,44 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Reflection;
+
+namespace Microsoft.ML.Transforms
+{
+ public delegate void SignatureFunctionProvider();
+
+ ///
+ /// This interface enables extending the ExprTransform language with additional functions.
+ ///
+ public interface IFunctionProvider
+ {
+ ///
+ /// The namespace for this provider. This should be a legal identifier in the expression language.
+ /// Multiple providers may contribute to the same namespace.
+ ///
+ string NameSpace { get; }
+
+ ///
+ /// Returns an array of overloads for the given function name. This may return null instead of an
+ /// empty array. The returned MethodInfos should be public static methods that can be freely invoked
+ /// by IL in a different assembly. They should also be "pure" functions - with the output only
+ /// depending on the inputs and NOT on any global state.
+ ///
+ MethodInfo[] Lookup(string name);
+
+ ///
+ /// If the function's value can be determined by the given subset of its arguments, this should
+ /// return the resulting value. Note that this should only be called if values is non-empty and
+ /// contains at least one null. If all the arguments are non-null, then the MethodInfo will be
+ /// invoked to produce the value.
+ ///
+ /// The name of the function.
+ /// The MethodInfo provided by Lookup. When there are multiple overloads of
+ /// a function with a given name, this can be used to determine which overload is being used.
+ /// The values of the input arguments, with null for the non-constant arguments. This should
+ /// only be called if there is at least one null.
+ /// The constant value, when it can be determined; null otherwise.
+ object ResolveToConstant(string name, MethodInfo meth, object[] values);
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/Expression/IlGeneratorExtensions.cs b/src/Microsoft.ML.Transforms/Expression/IlGeneratorExtensions.cs
new file mode 100644
index 0000000000..d765e7f9fb
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/IlGeneratorExtensions.cs
@@ -0,0 +1,405 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Reflection;
+using System.Reflection.Emit;
+using Microsoft.ML.Runtime;
+
+#pragma warning disable MSML_GeneralName // The names are derived from .NET OpCode names. These do not adhere to .NET naming standards.
+namespace Microsoft.ML.Transforms
+{
+ ///
+ /// Helper extension methods for using ILGenerator.
+ /// Rather than typing out something like:
+ /// il.Emit(OpCodes.Ldarg_0);
+ /// il.Emit(OpCodes.Ldarg_1);
+ /// il.Emit(OpCodes.Ldc_I4, i);
+ /// il.Emit(OpCodes.Ldelem_Ref);
+ /// il.Emit(OpCodes.Stfld, literalFields[i]);
+ /// You can do:
+ /// il
+ /// .Ldarg(0)
+ /// .Ldarg(1)
+ /// .Ldc_I4(i)
+ /// .Ldelem_Ref()
+ /// .Stfld(literalFields[i]);
+ /// It also provides some type safety over the Emit methods by ensuring
+ /// that you don't pass any args when using Add or that you only
+ /// pass a Label when using Br.
+ ///
+ internal static class ILGeneratorExtensions
+ {
+ public static ILGenerator Add(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Add);
+ return il;
+ }
+
+ public static ILGenerator Beq(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Beq, label);
+ return il;
+ }
+
+ public static ILGenerator Bge(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Bge, label);
+ return il;
+ }
+
+ public static ILGenerator Bge_Un(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Bge_Un, label);
+ return il;
+ }
+
+ public static ILGenerator Bgt(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Bgt, label);
+ return il;
+ }
+
+ public static ILGenerator Bgt_Un(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Bgt_Un, label);
+ return il;
+ }
+
+ public static ILGenerator Ble(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Ble, label);
+ return il;
+ }
+
+ public static ILGenerator Ble_Un(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Ble_Un, label);
+ return il;
+ }
+
+ public static ILGenerator Blt(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Blt, label);
+ return il;
+ }
+
+ public static ILGenerator Blt_Un(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Blt_Un, label);
+ return il;
+ }
+
+ public static ILGenerator Bne_Un(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Bne_Un, label);
+ return il;
+ }
+
+ public static ILGenerator Br(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Br, label);
+ return il;
+ }
+
+ public static ILGenerator Brfalse(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Brfalse, label);
+ return il;
+ }
+
+ public static ILGenerator Brtrue(this ILGenerator il, Label label)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Brtrue, label);
+ return il;
+ }
+
+ public static ILGenerator Call(this ILGenerator il, MethodInfo info)
+ {
+ Contracts.AssertValue(il);
+ Contracts.AssertValue(info);
+ Contracts.Assert(!info.IsVirtual);
+ il.Emit(OpCodes.Call, info);
+ return il;
+ }
+
+ public static ILGenerator Ceq(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Ceq);
+ return il;
+ }
+
+ public static ILGenerator Cgt(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Cgt);
+ return il;
+ }
+
+ public static ILGenerator Cgt_Un(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Cgt_Un);
+ return il;
+ }
+
+ public static ILGenerator Clt(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Clt);
+ return il;
+ }
+
+ public static ILGenerator Clt_Un(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Clt_Un);
+ return il;
+ }
+
+ public static ILGenerator Conv_I8(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Conv_I8);
+ return il;
+ }
+
+ public static ILGenerator Conv_R4(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Conv_R4);
+ return il;
+ }
+
+ public static ILGenerator Conv_R8(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Conv_R8);
+ return il;
+ }
+
+ public static ILGenerator Div(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Div);
+ return il;
+ }
+
+ public static ILGenerator Dup(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Dup);
+ return il;
+ }
+
+ public static ILGenerator Ldarg(this ILGenerator il, int arg)
+ {
+ Contracts.AssertValue(il);
+ Contracts.Assert(0 <= arg && arg <= short.MaxValue);
+
+ switch (arg)
+ {
+ case 0:
+ il.Emit(OpCodes.Ldarg_0);
+ break;
+ case 1:
+ il.Emit(OpCodes.Ldarg_1);
+ break;
+ case 2:
+ il.Emit(OpCodes.Ldarg_2);
+ break;
+ case 3:
+ il.Emit(OpCodes.Ldarg_3);
+ break;
+ default:
+ if (arg <= byte.MaxValue)
+ il.Emit(OpCodes.Ldarg_S, (byte)arg);
+ else
+ il.Emit(OpCodes.Ldarg, (short)arg);
+ break;
+ }
+ return il;
+ }
+
+ public static ILGenerator Ldc_I4(this ILGenerator il, int arg)
+ {
+ Contracts.AssertValue(il);
+
+ switch (arg)
+ {
+ case -1:
+ il.Emit(OpCodes.Ldc_I4_M1);
+ break;
+ case 0:
+ il.Emit(OpCodes.Ldc_I4_0);
+ break;
+ case 1:
+ il.Emit(OpCodes.Ldc_I4_1);
+ break;
+ case 2:
+ il.Emit(OpCodes.Ldc_I4_2);
+ break;
+ case 3:
+ il.Emit(OpCodes.Ldc_I4_3);
+ break;
+ case 4:
+ il.Emit(OpCodes.Ldc_I4_4);
+ break;
+ case 5:
+ il.Emit(OpCodes.Ldc_I4_5);
+ break;
+ case 6:
+ il.Emit(OpCodes.Ldc_I4_6);
+ break;
+ case 7:
+ il.Emit(OpCodes.Ldc_I4_7);
+ break;
+ case 8:
+ il.Emit(OpCodes.Ldc_I4_8);
+ break;
+ default:
+ // REVIEW: Docs say use ILGenerator.Emit(OpCode, byte) even though the value is signed
+ if (sbyte.MinValue <= arg && arg <= sbyte.MaxValue)
+ il.Emit(OpCodes.Ldc_I4_S, (byte)arg);
+ else
+ il.Emit(OpCodes.Ldc_I4, arg);
+ break;
+ }
+ return il;
+ }
+
+ public static ILGenerator Ldc_I8(this ILGenerator il, long arg)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Ldc_I8, arg);
+ return il;
+ }
+
+ public static ILGenerator Ldc_R4(this ILGenerator il, float arg)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Ldc_R4, arg);
+ return il;
+ }
+
+ public static ILGenerator Ldc_R8(this ILGenerator il, double arg)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Ldc_R8, arg);
+ return il;
+ }
+
+ public static ILGenerator Ldloc(this ILGenerator il, LocalBuilder builder)
+ {
+ Contracts.AssertValue(il);
+ Contracts.AssertValue(builder);
+ il.Emit(OpCodes.Ldloc, builder);
+ return il;
+ }
+
+ public static ILGenerator Ldstr(this ILGenerator il, string str)
+ {
+ Contracts.AssertValue(il);
+ Contracts.AssertValue(str);
+ il.Emit(OpCodes.Ldstr, str);
+ return il;
+ }
+
+ public static ILGenerator Mul(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Mul);
+ return il;
+ }
+
+ public static ILGenerator Mul_Ovf(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Mul_Ovf);
+ return il;
+ }
+
+ public static ILGenerator Neg(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Neg);
+ return il;
+ }
+
+ public static ILGenerator Newarr(this ILGenerator il, Type type)
+ {
+ Contracts.AssertValue(il);
+ Contracts.AssertValue(type);
+ il.Emit(OpCodes.Newarr, type);
+ return il;
+ }
+
+ public static ILGenerator Pop(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Pop);
+ return il;
+ }
+
+ public static ILGenerator Rem(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Rem);
+ return il;
+ }
+
+ public static ILGenerator Ret(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Ret);
+ return il;
+ }
+
+ public static ILGenerator Stelem(this ILGenerator il, Type type)
+ {
+ Contracts.AssertValue(il);
+ Contracts.AssertValue(type);
+ il.Emit(OpCodes.Stelem, type);
+ return il;
+ }
+
+ public static ILGenerator Stloc(this ILGenerator il, LocalBuilder builder)
+ {
+ Contracts.AssertValue(il);
+ Contracts.AssertValue(builder);
+ il.Emit(OpCodes.Stloc, builder);
+ return il;
+ }
+
+ public static ILGenerator Sub(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Sub);
+ return il;
+ }
+
+ public static ILGenerator Xor(this ILGenerator il)
+ {
+ Contracts.AssertValue(il);
+ il.Emit(OpCodes.Xor);
+ return il;
+ }
+ }
+}
+#pragma warning restore MSML_GeneralName
\ No newline at end of file
diff --git a/src/Microsoft.ML.Transforms/Expression/KeyWordTable.cs b/src/Microsoft.ML.Transforms/Expression/KeyWordTable.cs
new file mode 100644
index 0000000000..822dbcd50d
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/KeyWordTable.cs
@@ -0,0 +1,109 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ ///
+ /// Maps from normalized string to keyword token kind. A lexer must be provided with one of these.
+ ///
+ internal partial class KeyWordTable
+ {
+ public struct KeyWordKind
+ {
+ public readonly TokKind Kind;
+ public readonly bool IsContextKeyWord;
+
+ public KeyWordKind(TokKind kind, bool isContextKeyWord)
+ {
+ Kind = kind;
+ IsContextKeyWord = isContextKeyWord;
+ }
+ }
+
+ private readonly NormStr.Pool _pool;
+ private Dictionary _mpnstrtidWord;
+ private Dictionary _mpnstrtidPunc;
+
+ public KeyWordTable(NormStr.Pool pool)
+ {
+ Contracts.AssertValue(pool);
+
+ _pool = pool;
+ _mpnstrtidWord = new Dictionary();
+ _mpnstrtidPunc = new Dictionary();
+ }
+
+ public void AddKeyWord(string str, TokKind tid)
+ {
+ Contracts.AssertNonEmpty(str);
+ _mpnstrtidWord.Add(_pool.Add(str), new KeyWordKind(tid, false));
+ }
+
+ public bool TryAddPunctuator(string str, TokKind tid)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ // Note: this assumes that once a prefix is found, that all shorter
+ // prefixes are mapped to something (TokKind.None to indicate that
+ // it is only a prefix and not itself a token).
+
+ TokKind tidCur;
+ NormStr nstr = _pool.Add(str);
+ if (_mpnstrtidPunc.TryGetValue(_pool.Add(str), out tidCur))
+ {
+ if (tidCur == tid)
+ return true;
+ if (tidCur != TokKind.None)
+ return false;
+ }
+ else
+ {
+ // Map all prefixes (that aren't already mapped) to TokKind.None.
+ for (int cch = str.Length; --cch > 0;)
+ {
+ NormStr nstrTmp = _pool.Add(str.Substring(0, cch));
+ TokKind tidTmp;
+ if (_mpnstrtidPunc.TryGetValue(_pool.Add(nstrTmp.Value), out tidTmp))
+ break;
+ _mpnstrtidPunc.Add(nstrTmp, TokKind.None);
+ }
+ }
+ _mpnstrtidPunc[nstr] = tid;
+ return true;
+ }
+
+ public void AddPunctuator(string str, TokKind tid)
+ {
+ Contracts.AssertNonEmpty(str);
+ if (!TryAddPunctuator(str, tid))
+ Contracts.Assert(false, "duplicate punctuator!");
+ }
+
+ public bool IsKeyWord(NormStr nstr, out KeyWordKind kind)
+ {
+ Contracts.Assert(!nstr.Value.IsEmpty);
+ return _mpnstrtidWord.TryGetValue(nstr, out kind);
+ }
+
+ public bool IsPunctuator(NormStr nstr, out TokKind tid)
+ {
+ Contracts.Assert(!nstr.Value.IsEmpty);
+ return _mpnstrtidPunc.TryGetValue(nstr, out tid);
+ }
+
+ public IEnumerable> Punctuators
+ {
+ get { return _mpnstrtidPunc; }
+ }
+
+ public IEnumerable> KeyWords
+ {
+ get { return _mpnstrtidWord; }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Transforms/Expression/LambdaBinder.cs b/src/Microsoft.ML.Transforms/Expression/LambdaBinder.cs
new file mode 100644
index 0000000000..8a756fec08
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/LambdaBinder.cs
@@ -0,0 +1,1858 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ using BL = System.Boolean;
+ using I4 = System.Int32;
+ using I8 = System.Int64;
+ using R4 = Single;
+ using R8 = Double;
+ using TX = ReadOnlyMemory;
+
+ internal sealed partial class LambdaBinder : NodeVisitor
+ {
+ private readonly IHost _host;
+ // The stack of active with nodes.
+ private List _rgwith;
+
+ private List _errors;
+ private LambdaNode _lambda;
+
+ private readonly IFunctionProvider[] _providers;
+ private readonly Action _printError;
+
+ private LambdaBinder(IHostEnvironment env, Action printError)
+ {
+ _host = env.Register("LambdaBinder");
+ _printError = printError;
+ _rgwith = new List();
+ _providers = env.ComponentCatalog.GetAllDerivedClasses(typeof(IFunctionProvider), typeof(SignatureFunctionProvider))
+ .Select(info => info.CreateInstance(_host))
+ .Prepend(BuiltinFunctions.Instance)
+ .ToArray();
+ }
+
+ ///
+ /// Run Lambda binder on LambdaNode and populate Expr values.
+ /// The errors contain list of user errors that occurred during binding.
+ /// The printError delegate is only used for reporting issues with function provider implementations, which are programmer errors.
+ /// In particular, it is NOT used to report user errors in the lambda expression.
+ ///
+ public static void Run(IHostEnvironment env, ref List errors, LambdaNode node, Action printError)
+ {
+ Contracts.AssertValue(env);
+ env.AssertValueOrNull(errors);
+ env.AssertValue(node);
+ env.AssertValue(printError);
+
+ var binder = new LambdaBinder(env, printError);
+ binder._errors = errors;
+ node.Accept(binder);
+ env.Assert(binder._rgwith.Count == 0);
+
+ var expr = node.Expr;
+ switch (expr.ExprType)
+ {
+ case ExprTypeKind.BL:
+ node.ResultType = BooleanDataViewType.Instance;
+ break;
+ case ExprTypeKind.I4:
+ node.ResultType = NumberDataViewType.Int32;
+ break;
+ case ExprTypeKind.I8:
+ node.ResultType = NumberDataViewType.Int64;
+ break;
+ case ExprTypeKind.R4:
+ node.ResultType = NumberDataViewType.Single;
+ break;
+ case ExprTypeKind.R8:
+ node.ResultType = NumberDataViewType.Double;
+ break;
+ case ExprTypeKind.TX:
+ node.ResultType = TextDataViewType.Instance;
+ break;
+ default:
+ if (!binder.HasErrors)
+ binder.PostError(expr, "Invalid result type");
+ break;
+ }
+
+ errors = binder._errors;
+ }
+
+ private bool HasErrors
+ {
+ get { return Utils.Size(_errors) > 0; }
+ }
+
+ private void PostError(Node node, string msg)
+ {
+ Utils.Add(ref _errors, new Error(node.Token, msg));
+ }
+
+ private void PostError(Node node, string msg, params object[] args)
+ {
+ Utils.Add(ref _errors, new Error(node.Token, string.Format(msg, args)));
+ }
+
+ public override void Visit(BoolLitNode node)
+ {
+ _host.AssertValue(node);
+ _host.Assert(node.IsBool);
+ _host.AssertValue(node.ExprValue);
+ }
+
+ public override void Visit(StrLitNode node)
+ {
+ _host.AssertValue(node);
+ _host.Assert(node.IsTX);
+ _host.AssertValue(node.ExprValue);
+ }
+
+ public override void Visit(NumLitNode node)
+ {
+ _host.AssertValue(node);
+ _host.Assert(node.IsNumber || node.IsError);
+ _host.Assert((node.ExprValue == null) == node.IsError);
+
+ if (node.IsError)
+ PostError(node, "Overflow");
+ }
+
+ public override void Visit(NameNode node)
+ {
+ }
+
+ public override void Visit(IdentNode node)
+ {
+ _host.AssertValue(node);
+
+ // If the IdentNode didn't actually have an IdentToken, just bag out.
+ if (node.IsMissing)
+ {
+ _host.Assert(HasErrors);
+ node.SetType(ExprTypeKind.Error);
+ return;
+ }
+
+ // Look for "with" locals.
+ string name = node.Value;
+ for (int i = _rgwith.Count; --i >= 0;)
+ {
+ var with = _rgwith[i];
+ if (name == with.Local.Name)
+ {
+ node.Referent = with.Local;
+ node.SetValue(with.Local.Value);
+ // REVIEW: Note that some uses might get pruned, but this gives us
+ // an upper bound on the time of places in the code where this value is needed.
+ with.Local.UseCount++;
+ return;
+ }
+ }
+
+ // Look for parameters.
+ ParamNode param;
+ if (_lambda != null && (param = _lambda.FindParam(node.Value)) != null)
+ {
+ node.Referent = param;
+ node.SetType(param.ExprType);
+ return;
+ }
+
+ PostError(node, "Unresolved identifier '{0}'", node.Value);
+ node.SetType(ExprTypeKind.Error);
+ }
+
+ public override void Visit(ParamNode node)
+ {
+ _host.AssertValue(node);
+ _host.Assert(node.ExprType != 0);
+ }
+
+ public override bool PreVisit(LambdaNode node)
+ {
+ _host.AssertValue(node);
+ _host.Assert(_lambda == null, "Can't support nested lambdas");
+
+ _lambda = node;
+
+ node.Expr.Accept(this);
+
+ _host.Assert(_lambda == node);
+ _lambda = null;
+
+ return false;
+ }
+
+ public override void PostVisit(LambdaNode node)
+ {
+ _host.Assert(false);
+ }
+
+ private string GetStr(ExprTypeKind kind)
+ {
+ switch (kind)
+ {
+ case ExprTypeKind.BL:
+ return "boolean";
+ case ExprTypeKind.R4:
+ case ExprTypeKind.R8:
+ return "numeric";
+ case ExprTypeKind.I4:
+ case ExprTypeKind.I8:
+ return "integer";
+ case ExprTypeKind.TX:
+ return "text";
+ }
+
+ return null;
+ }
+
+ private void BadNum(ExprNode arg)
+ {
+ if (!arg.IsError)
+ PostError(arg, "Invalid numeric operand");
+ _host.Assert(HasErrors);
+ }
+
+ private void BadNum(ExprNode node, ExprNode arg)
+ {
+ BadNum(arg);
+ _host.Assert(HasErrors);
+ node.SetType(ExprTypeKind.Error);
+ }
+
+ private void BadText(ExprNode arg)
+ {
+ if (!arg.IsError)
+ PostError(arg, "Invalid text operand");
+ _host.Assert(HasErrors);
+ }
+
+ private void BadArg(ExprNode arg, ExprTypeKind kind)
+ {
+ if (!arg.IsError)
+ {
+ var str = GetStr(kind);
+ if (str != null)
+ PostError(arg, "Invalid {0} operand", str);
+ else
+ PostError(arg, "Invalid operand");
+ }
+ _host.Assert(HasErrors);
+ }
+
+ public override void PostVisit(UnaryOpNode node)
+ {
+ _host.AssertValue(node);
+ var arg = node.Arg;
+ switch (node.Op)
+ {
+ case UnaryOp.Minus:
+ switch (arg.ExprType)
+ {
+ default:
+ BadNum(node, arg);
+ break;
+ case ExprTypeKind.I4:
+ node.SetValue(-(I4?)arg.ExprValue);
+ break;
+ case ExprTypeKind.I8:
+ node.SetValue(-(I8?)arg.ExprValue);
+ break;
+ case ExprTypeKind.R4:
+ node.SetValue(-(R4?)arg.ExprValue);
+ break;
+ case ExprTypeKind.R8:
+ node.SetValue(-(R8?)arg.ExprValue);
+ break;
+ }
+ break;
+
+ case UnaryOp.Not:
+ BL? bl = GetBoolOp(node.Arg);
+ if (bl != null)
+ node.SetValue(!bl.Value);
+ else
+ node.SetValue(bl);
+ break;
+
+ default:
+ _host.Assert(false);
+ PostError(node, "Unknown unary operator");
+ node.SetType(ExprTypeKind.Error);
+ break;
+ }
+ }
+
+ private BL? GetBoolOp(ExprNode arg)
+ {
+ _host.AssertValue(arg);
+ if (arg.IsBool)
+ return (BL?)arg.ExprValue;
+ BadArg(arg, ExprTypeKind.BL);
+ return null;
+ }
+
+ public override void PostVisit(BinaryOpNode node)
+ {
+ _host.AssertValue(node);
+
+ // REVIEW: We should really use the standard function overload resolution
+ // mechanism that CallNode binding uses. That would ensure that our type promotion
+ // and resolution mechanisms are consistent.
+ switch (node.Op)
+ {
+ case BinaryOp.Coalesce:
+ if (!node.Left.IsRx)
+ {
+ BadArg(node, node.Left.ExprType);
+ node.SetType(ExprTypeKind.Error);
+ }
+ else // Default to numeric.
+ ApplyNumericBinOp(node);
+ break;
+
+ case BinaryOp.Or:
+ case BinaryOp.And:
+ ApplyBoolBinOp(node);
+ break;
+
+ case BinaryOp.Add:
+ case BinaryOp.Sub:
+ case BinaryOp.Mul:
+ case BinaryOp.Div:
+ case BinaryOp.Mod:
+ case BinaryOp.Power:
+ ApplyNumericBinOp(node);
+ break;
+
+ case BinaryOp.Error:
+ _host.Assert(HasErrors);
+ node.SetType(ExprTypeKind.Error);
+ break;
+
+ default:
+ _host.Assert(false);
+ PostError(node, "Unknown binary operator");
+ node.SetType(ExprTypeKind.Error);
+ break;
+ }
+ }
+
+ private void ApplyBoolBinOp(BinaryOpNode node)
+ {
+ _host.AssertValue(node);
+ _host.Assert(node.Op == BinaryOp.And || node.Op == BinaryOp.Or || node.Op == BinaryOp.Coalesce);
+
+ node.SetType(ExprTypeKind.BL);
+
+ BL? v1 = GetBoolOp(node.Left);
+ BL? v2 = GetBoolOp(node.Right);
+ switch (node.Op)
+ {
+ case BinaryOp.Or:
+ if (v1 != null && v2 != null)
+ node.SetValue(v1.Value | v2.Value);
+ else if (v1 != null && v1.Value || v2 != null && v2.Value)
+ node.SetValue(true);
+ else if (v1 != null && !v1.Value)
+ node.ReduceToRight = true;
+ else if (v2 != null && !v2.Value)
+ node.ReduceToLeft = true;
+ break;
+
+ case BinaryOp.And:
+ if (v1 != null && v2 != null)
+ node.SetValue(v1.Value & v2.Value);
+ else if (v1 != null && !v1.Value || v2 != null && !v2.Value)
+ node.SetValue(false);
+ else if (v1 != null && v1.Value)
+ node.ReduceToRight = true;
+ else if (v2 != null && v2.Value)
+ node.ReduceToLeft = true;
+ break;
+
+ case BinaryOp.Coalesce:
+ if (v1 != null)
+ node.SetValue(v1);
+ break;
+ }
+
+ _host.Assert(node.IsBool);
+ }
+
+ ///
+ /// Reconcile the types of the two ExprNodes. Favor numeric types in cases
+ /// where the types can't be reconciled. This does not guarantee that
+ /// the resulting kind is numeric, eg, if both a and b are of type Text, it
+ /// simply sets kind to Text.
+ ///
+ private void ReconcileNumericTypes(ExprNode a, ExprNode b, out ExprTypeKind kind)
+ {
+ _host.AssertValue(a);
+ _host.AssertValue(b);
+
+ // REVIEW: Consider converting I4 + R4 to R8, unless the I4
+ // is a constant known to not lose precision when converted to R4.
+ if (!CanPromote(false, a.ExprType, b.ExprType, out kind))
+ {
+ // If either is numeric, use that numeric type.
+ if (a.IsNumber)
+ kind = a.ExprType;
+ else if (b.IsNumber)
+ kind = b.ExprType;
+ else // Default to Float (for error reporting).
+ kind = ExprTypeKind.Float;
+ _host.Assert(MapKindToIndex(kind) >= 0);
+ }
+ }
+
+ private void ApplyNumericBinOp(BinaryOpNode node)
+ {
+ _host.AssertValue(node);
+
+ var left = node.Left;
+ var right = node.Right;
+ ExprTypeKind kind;
+ ReconcileNumericTypes(left, right, out kind);
+
+ // REVIEW: Should we prohibit constant evaluations that produce NA?
+ switch (kind)
+ {
+ default:
+ // Default to Float (for error reporting).
+ goto case ExprTypeKind.Float;
+
+ case ExprTypeKind.I4:
+ {
+ node.SetType(ExprTypeKind.I4);
+ I4? v1;
+ I4? v2;
+ // Boiler plate below here...
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ else if (!f2)
+ BadNum(right);
+ else
+ ReduceBinOp(node, v1, v2);
+ }
+ break;
+ case ExprTypeKind.I8:
+ {
+ node.SetType(ExprTypeKind.I8);
+ I8? v1;
+ I8? v2;
+ // Boiler plate below here...
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ else if (!f2)
+ BadNum(right);
+ else
+ ReduceBinOp(node, v1, v2);
+ }
+ break;
+ case ExprTypeKind.R4:
+ {
+ node.SetType(ExprTypeKind.R4);
+ R4? v1;
+ R4? v2;
+ // Boiler plate below here...
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ else if (!f2)
+ BadNum(right);
+ else
+ ReduceBinOp(node, v1, v2);
+ }
+ break;
+ case ExprTypeKind.R8:
+ {
+ node.SetType(ExprTypeKind.R8);
+ R8? v1;
+ R8? v2;
+ // Boiler plate below here...
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ else if (!f2)
+ BadNum(right);
+ else
+ ReduceBinOp(node, v1, v2);
+ }
+ break;
+ }
+ }
+
+ #region ReduceBinOp
+
+ // The I4 and I8 methods are identical, as are the R4 and R8 methods.
+ private void ReduceBinOp(BinaryOpNode node, I4? a, I4? b)
+ {
+ if (a != null && b != null)
+ node.SetValue(BinOp(node, a.Value, b.Value));
+ else if (a != null)
+ {
+ // Special reductions when only the left value is known.
+ var v = a.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Add:
+ if (v == 0)
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Mul:
+ if (v == 1)
+ node.ReduceToRight = true;
+ break;
+ }
+ }
+ else if (b != null)
+ {
+ // Special reductions when only the right value is known.
+ var v = b.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Add:
+ if (v == 0)
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Mul:
+ if (v == 1)
+ node.ReduceToLeft = true;
+ break;
+ }
+ }
+ }
+
+ private void ReduceBinOp(BinaryOpNode node, I8? a, I8? b)
+ {
+ if (a != null && b != null)
+ node.SetValue(BinOp(node, a.Value, b.Value));
+ else if (a != null)
+ {
+ // Special reductions when only the left value is known.
+ var v = a.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Add:
+ if (v == 0)
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Mul:
+ if (v == 1)
+ node.ReduceToRight = true;
+ break;
+ }
+ }
+ else if (b != null)
+ {
+ // Special reductions when only the right value is known.
+ var v = b.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Add:
+ if (v == 0)
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Mul:
+ if (v == 1)
+ node.ReduceToLeft = true;
+ break;
+ }
+ }
+ }
+
+ private void ReduceBinOp(BinaryOpNode node, R4? a, R4? b)
+ {
+ if (a != null && b != null)
+ node.SetValue(BinOp(node, a.Value, b.Value));
+ else if (a != null)
+ {
+ // Special reductions when only the left value is known.
+ var v = a.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Coalesce:
+ if (!R4.IsNaN(v))
+ node.SetValue(v);
+ else
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Add:
+ if (R4.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 0)
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Mul:
+ if (R4.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 1)
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Sub:
+ case BinaryOp.Div:
+ case BinaryOp.Mod:
+ if (R4.IsNaN(v))
+ node.SetValue(v);
+ break;
+ }
+ }
+ else if (b != null)
+ {
+ // Special reductions when only the right value is known.
+ var v = b.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Coalesce:
+ if (R4.IsNaN(v))
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Add:
+ if (R4.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 0)
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Mul:
+ if (R4.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 1)
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Sub:
+ case BinaryOp.Div:
+ case BinaryOp.Mod:
+ if (R4.IsNaN(v))
+ node.SetValue(v);
+ break;
+ }
+ }
+ }
+
+ private void ReduceBinOp(BinaryOpNode node, R8? a, R8? b)
+ {
+ if (a != null && b != null)
+ node.SetValue(BinOp(node, a.Value, b.Value));
+ else if (a != null)
+ {
+ // Special reductions when only the left value is known.
+ var v = a.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Coalesce:
+ if (!R8.IsNaN(v))
+ node.SetValue(v);
+ else
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Add:
+ if (R8.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 0)
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Mul:
+ if (R8.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 1)
+ node.ReduceToRight = true;
+ break;
+ case BinaryOp.Sub:
+ case BinaryOp.Div:
+ case BinaryOp.Mod:
+ if (R8.IsNaN(v))
+ node.SetValue(v);
+ break;
+ }
+ }
+ else if (b != null)
+ {
+ // Special reductions when only the right value is known.
+ var v = b.Value;
+ switch (node.Op)
+ {
+ case BinaryOp.Coalesce:
+ if (R8.IsNaN(v))
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Add:
+ if (R8.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 0)
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Mul:
+ if (R8.IsNaN(v))
+ node.SetValue(v);
+ else if (v == 1)
+ node.ReduceToLeft = true;
+ break;
+ case BinaryOp.Sub:
+ case BinaryOp.Div:
+ case BinaryOp.Mod:
+ if (R8.IsNaN(v))
+ node.SetValue(v);
+ break;
+ }
+ }
+ }
+
+ #endregion ReduceBinOp
+
+ #region BinOp
+
+ private I4 BinOp(BinaryOpNode node, I4 v1, I4 v2)
+ {
+ switch (node.Op)
+ {
+ case BinaryOp.Add:
+ return v1 + v2;
+ case BinaryOp.Sub:
+ return v1 - v2;
+ case BinaryOp.Mul:
+ return v1 * v2;
+ case BinaryOp.Div:
+ return v1 / v2;
+ case BinaryOp.Mod:
+ return v1 % v2;
+ case BinaryOp.Power:
+ return BuiltinFunctions.Pow(v1, v2);
+ default:
+ _host.Assert(false);
+ throw Contracts.Except();
+ }
+ }
+
+ private I8 BinOp(BinaryOpNode node, I8 v1, I8 v2)
+ {
+ switch (node.Op)
+ {
+ case BinaryOp.Add:
+ return v1 + v2;
+ case BinaryOp.Sub:
+ return v1 - v2;
+ case BinaryOp.Mul:
+ return v1 * v2;
+ case BinaryOp.Div:
+ return v1 / v2;
+ case BinaryOp.Mod:
+ return v1 % v2;
+ case BinaryOp.Power:
+ return BuiltinFunctions.Pow(v1, v2);
+ default:
+ _host.Assert(false);
+ throw Contracts.Except();
+ }
+ }
+
+ private R4 BinOp(BinaryOpNode node, R4 v1, R4 v2)
+ {
+ switch (node.Op)
+ {
+ case BinaryOp.Coalesce:
+ return !R4.IsNaN(v1) ? v1 : v2;
+ case BinaryOp.Add:
+ return v1 + v2;
+ case BinaryOp.Sub:
+ return v1 - v2;
+ case BinaryOp.Mul:
+ return v1 * v2;
+ case BinaryOp.Div:
+ return v1 / v2;
+ case BinaryOp.Mod:
+ return v1 % v2;
+ case BinaryOp.Power:
+ return BuiltinFunctions.Pow(v1, v2);
+ default:
+ _host.Assert(false);
+ return R4.NaN;
+ }
+ }
+
+ private R8 BinOp(BinaryOpNode node, R8 v1, R8 v2)
+ {
+ switch (node.Op)
+ {
+ case BinaryOp.Coalesce:
+ return !R8.IsNaN(v1) ? v1 : v2;
+ case BinaryOp.Add:
+ return v1 + v2;
+ case BinaryOp.Sub:
+ return v1 - v2;
+ case BinaryOp.Mul:
+ return v1 * v2;
+ case BinaryOp.Div:
+ return v1 / v2;
+ case BinaryOp.Mod:
+ return v1 % v2;
+ case BinaryOp.Power:
+ return Math.Pow(v1, v2);
+ default:
+ _host.Assert(false);
+ return R8.NaN;
+ }
+ }
+ #endregion BinOp
+
+ public override void PostVisit(ConditionalNode node)
+ {
+ _host.AssertValue(node);
+
+ BL? cond = GetBoolOp(node.Cond);
+
+ var left = node.Left;
+ var right = node.Right;
+ ExprTypeKind kind;
+ if (!CanPromote(false, left.ExprType, right.ExprType, out kind))
+ {
+ // If either is numeric, use that numeric type. Otherwise, use the first
+ // that isn't error or none.
+ if (left.IsNumber)
+ kind = left.ExprType;
+ else if (right.IsNumber)
+ kind = right.ExprType;
+ else if (!left.IsError && !left.IsNone)
+ kind = left.ExprType;
+ else if (!right.IsError && !right.IsNone)
+ kind = right.ExprType;
+ else
+ kind = ExprTypeKind.None;
+ }
+
+ switch (kind)
+ {
+ default:
+ PostError(node, "Invalid conditional expression");
+ node.SetType(ExprTypeKind.Error);
+ break;
+
+ case ExprTypeKind.BL:
+ {
+ node.SetType(ExprTypeKind.BL);
+ BL? v1 = GetBoolOp(node.Left);
+ BL? v2 = GetBoolOp(node.Right);
+ if (cond != null)
+ {
+ if (cond.Value)
+ node.SetValue(v1);
+ else
+ node.SetValue(v2);
+ }
+ }
+ break;
+ case ExprTypeKind.I4:
+ {
+ node.SetType(ExprTypeKind.I4);
+ I4? v1;
+ I4? v2;
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ if (!f2)
+ BadNum(right);
+ if (cond != null)
+ {
+ if (cond.Value)
+ node.SetValue(v1);
+ else
+ node.SetValue(v2);
+ }
+ }
+ break;
+ case ExprTypeKind.I8:
+ {
+ node.SetType(ExprTypeKind.I8);
+ I8? v1;
+ I8? v2;
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ if (!f2)
+ BadNum(right);
+ if (cond != null)
+ {
+ if (cond.Value)
+ node.SetValue(v1);
+ else
+ node.SetValue(v2);
+ }
+ }
+ break;
+ case ExprTypeKind.R4:
+ {
+ node.SetType(ExprTypeKind.R4);
+ R4? v1;
+ R4? v2;
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ if (!f2)
+ BadNum(right);
+ if (cond != null)
+ {
+ if (cond.Value)
+ node.SetValue(v1);
+ else
+ node.SetValue(v2);
+ }
+ }
+ break;
+ case ExprTypeKind.R8:
+ {
+ node.SetType(ExprTypeKind.R8);
+ R8? v1;
+ R8? v2;
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadNum(left);
+ if (!f2)
+ BadNum(right);
+ if (cond != null)
+ {
+ if (cond.Value)
+ node.SetValue(v1);
+ else
+ node.SetValue(v2);
+ }
+ }
+ break;
+ case ExprTypeKind.TX:
+ {
+ node.SetType(ExprTypeKind.TX);
+ TX? v1;
+ TX? v2;
+ bool f1 = left.TryGet(out v1);
+ bool f2 = right.TryGet(out v2);
+ if (!f1)
+ BadText(left);
+ if (!f2)
+ BadText(right);
+ if (cond != null)
+ {
+ if (cond.Value)
+ node.SetValue(v1);
+ else
+ node.SetValue(v2);
+ }
+ }
+ break;
+ }
+ }
+
+ public override void PostVisit(CompareNode node)
+ {
+ _host.AssertValue(node);
+
+ TokKind tidLax = node.TidLax;
+ TokKind tidStrict = node.TidStrict;
+ ExprTypeKind kind = ExprTypeKind.None;
+
+ // First validate the types.
+ ExprNode arg;
+ bool hasErrors = false;
+ var items = node.Operands.Items;
+ for (int i = 0; i < items.Length; i++)
+ {
+ arg = items[i].AsExpr;
+ if (!ValidateType(arg, ref kind))
+ {
+ BadArg(arg, kind);
+ hasErrors = true;
+ }
+ }
+
+ // Set the arg type and the type of this node.
+ node.ArgTypeKind = kind;
+ node.SetType(ExprTypeKind.BL);
+
+ if (hasErrors)
+ {
+ _host.Assert(HasErrors);
+ return;
+ }
+
+ // Find the number of initial constant inputs in "lim" and convert the args to "kind".
+ int lim = items.Length;
+ int count = lim;
+ for (int i = 0; i < count; i++)
+ {
+ arg = items[i].AsExpr;
+ arg.Convert(kind);
+ if (i < lim && arg.ExprValue == null)
+ lim = i;
+ }
+
+ // Now try to compute the value.
+ int ifn = (int)kind;
+ if (ifn >= _fnEqual.Length || ifn < 0)
+ {
+ _host.Assert(false);
+ PostError(node, "Internal error in CompareNode");
+ return;
+ }
+
+ Cmp cmpLax;
+ Cmp cmpStrict;
+ switch (node.Op)
+ {
+ case CompareOp.DecrChain:
+ cmpLax = _fnGreaterEqual[ifn];
+ cmpStrict = _fnGreater[ifn];
+ break;
+ case CompareOp.IncrChain:
+ cmpLax = _fnLessEqual[ifn];
+ cmpStrict = _fnLess[ifn];
+ break;
+ case CompareOp.Equal:
+ cmpLax = _fnEqual[ifn];
+ cmpStrict = cmpLax;
+ break;
+ case CompareOp.NotEqual:
+ cmpLax = _fnNotEqual[ifn];
+ cmpStrict = cmpLax;
+ break;
+ default:
+ _host.Assert(false);
+ return;
+ }
+
+ _host.Assert((cmpLax == null) == (cmpStrict == null));
+ if (cmpLax == null)
+ {
+ PostError(node, "Bad operands for comparison");
+ return;
+ }
+
+ // If one of the first two operands is NA, the result is NA, even if the other operand
+ // is not a constant.
+ object value;
+ if (lim < 2 && (value = items[1 - lim].AsExpr.ExprValue) != null && !cmpLax(value, value).HasValue)
+ {
+ node.SetValue(default(BL?));
+ return;
+ }
+
+ // See if we can reduce to a constant BL value.
+ if (lim >= 2)
+ {
+ if (node.Op != CompareOp.NotEqual)
+ {
+ // Note: this loop doesn't work for != when there are more than two operands,
+ // so != is handled separately below.
+ bool isStrict = false;
+ arg = items[0].AsExpr;
+ _host.Assert(arg.ExprType == kind);
+ var valuePrev = arg.ExprValue;
+ _host.Assert(valuePrev != null);
+ for (int i = 1; i < lim; i++)
+ {
+ TokKind tid = node.Operands.Delimiters[i - 1].Kind;
+ _host.Assert(tid == tidLax || tid == tidStrict);
+
+ if (tid == tidStrict)
+ isStrict = true;
+
+ arg = items[i].AsExpr;
+ _host.Assert(arg.ExprType == kind);
+
+ value = arg.ExprValue;
+ _host.Assert(value != null);
+ BL? res = isStrict ? cmpStrict(valuePrev, value) : cmpLax(valuePrev, value);
+ if (res == null || !res.Value)
+ {
+ node.SetValue(false);
+ return;
+ }
+ valuePrev = value;
+ isStrict = false;
+ }
+ }
+ else
+ {
+ // NotEqual is special - it means that the values are all distinct, so comparing adjacent
+ // items is not enough.
+ for (int i = 1; i < lim; i++)
+ {
+ arg = items[i].AsExpr;
+ _host.Assert(arg.ExprType == kind);
+
+ value = arg.ExprValue;
+ _host.Assert(value != null);
+ for (int j = 0; j < i; j++)
+ {
+ var arg2 = items[j].AsExpr;
+ _host.Assert(arg2.ExprType == kind);
+
+ var value2 = arg2.ExprValue;
+ _host.Assert(value2 != null);
+ BL? res = cmpStrict(value2, value);
+ if (res == null || !res.Value)
+ {
+ node.SetValue(res);
+ return;
+ }
+ }
+ }
+ }
+
+ if (lim == count)
+ node.SetValue(true);
+ }
+ }
+
+ private sealed class Candidate
+ {
+ public readonly IFunctionProvider Provider;
+ public readonly MethodInfo Method;
+ public readonly ExprTypeKind[] Kinds;
+ public readonly ExprTypeKind ReturnKind;
+ public readonly bool IsVariable;
+
+ public bool MatchesArity(int arity)
+ {
+ if (!IsVariable)
+ return arity == Kinds.Length;
+ Contracts.Assert(Kinds.Length > 0);
+ return arity >= Kinds.Length - 1;
+ }
+
+ public int Arity
+ {
+ get { return Kinds.Length; }
+ }
+
+ public bool IsIdentity
+ {
+ get { return Method.ReturnType == typeof(void); }
+ }
+
+ public static bool TryGetCandidate(CallNode node, IFunctionProvider provider, MethodInfo meth, Action printError, out Candidate cand)
+ {
+ cand = default(Candidate);
+ if (meth == null)
+ return false;
+
+ // An "identity" function has one parameter and returns void.
+ var ps = meth.GetParameters();
+ bool isIdent = ps.Length == 1 && meth.ReturnType == typeof(void);
+
+ if (!meth.IsStatic || !meth.IsPublic && !isIdent)
+ {
+ // This is an error in the extension functions, not in the user code.
+ printError(string.Format(
+ "Error in ExprTransform: Function '{0}' in namespace '{1}' must be static and public",
+ node.Head.Value, provider.NameSpace));
+ return false;
+ }
+
+ // Verify the parameter types.
+ bool isVar = false;
+ var kinds = new ExprTypeKind[ps.Length];
+ for (int i = 0; i < ps.Length; i++)
+ {
+ var type = ps[i].ParameterType;
+ if (i == ps.Length - 1 && !isIdent && type.IsArray)
+ {
+ // Last parameter is variable.
+ isVar = true;
+ type = type.GetElementType();
+ }
+ var extCur = ExprNode.ToExprTypeKind(type);
+ if (extCur <= ExprTypeKind.Error || extCur >= ExprTypeKind._Lim)
+ {
+ printError(string.Format(
+ "Error in ExprTransform: Function '{0}' in namespace '{1}' has invalid parameter type '{2}'",
+ node.Head.Value, provider.NameSpace, type));
+ return false;
+ }
+ kinds[i] = extCur;
+ }
+
+ // Verify the return type.
+ ExprTypeKind kindRet;
+ if (isIdent)
+ {
+ Contracts.Assert(kinds.Length == 1);
+ kindRet = kinds[0];
+ }
+ else
+ {
+ var extRet = ExprNode.ToExprTypeKind(meth.ReturnType);
+ kindRet = extRet;
+ if (kindRet <= ExprTypeKind.Error || kindRet >= ExprTypeKind._Lim)
+ {
+ printError(string.Format(
+ "Error in ExprTransform: Function '{0}' in namespace '{1}' has invalid return type '{2}'",
+ node.Head.Value, provider.NameSpace, meth.ReturnType));
+ return false;
+ }
+ }
+
+ cand = new Candidate(provider, meth, kinds, kindRet, isVar);
+ return true;
+ }
+
+ private Candidate(IFunctionProvider provider, MethodInfo meth, ExprTypeKind[] kinds, ExprTypeKind kindRet, bool isVar)
+ {
+ Contracts.AssertValue(provider);
+ Contracts.AssertValue(meth);
+ Contracts.AssertValue(kinds);
+ Provider = provider;
+ Method = meth;
+ Kinds = kinds;
+ ReturnKind = kindRet;
+ IsVariable = isVar;
+ }
+
+ ///
+ /// Returns whether this candidate is applicable to the given argument types.
+ ///
+ public bool IsApplicable(ExprTypeKind[] kinds, out int bad)
+ {
+ Contracts.Assert(kinds.Length == Kinds.Length || IsVariable && kinds.Length >= Kinds.Length - 1);
+
+ bad = 0;
+ int head = IsVariable ? Kinds.Length - 1 : Kinds.Length;
+
+ for (int i = 0; i < head; i++)
+ {
+ if (!CanConvert(kinds[i], Kinds[i]))
+ bad++;
+ }
+
+ if (IsVariable)
+ {
+ // Handle the tail.
+ var kind = Kinds[Kinds.Length - 1];
+ for (int i = head; i < kinds.Length; i++)
+ {
+ if (!CanConvert(kinds[i], kind))
+ bad++;
+ }
+ }
+
+ return bad == 0;
+ }
+
+ ///
+ /// Returns -1 if 'this' is better than 'other', 0 if they are the same, +1 otherwise.
+ /// Non-variable is always better than variable. When both are variable, longer prefix is
+ /// better than shorter prefix.
+ ///
+ public int CompareSignatures(Candidate other)
+ {
+ Contracts.AssertValue(other);
+
+ if (IsVariable)
+ {
+ if (!other.IsVariable)
+ return +1;
+ if (Kinds.Length != other.Kinds.Length)
+ return Kinds.Length > other.Kinds.Length ? -1 : +1;
+ }
+ else if (other.IsVariable)
+ return -1;
+
+ int cmp = 0;
+ for (int k = 0; k < Kinds.Length; k++)
+ {
+ var t1 = Kinds[k];
+ var t2 = other.Kinds[k];
+ if (t1 == t2)
+ continue;
+ if (!CanConvert(t1, t2))
+ return +1;
+ cmp = -1;
+ }
+ return cmp;
+ }
+ }
+
+ public override void PostVisit(CallNode node)
+ {
+ _host.AssertValue(node);
+
+ // Get the argument types and number of arguments.
+ var kinds = node.Args.Items.Select(item => item.AsExpr.ExprType).ToArray();
+ var arity = kinds.Length;
+
+ // Find the candidates.
+ bool hasGoodArity = false;
+ var candidates = new List();
+ foreach (var prov in _providers)
+ {
+ if (node.NameSpace != null && prov.NameSpace != node.NameSpace.Value)
+ continue;
+
+ var meths = prov.Lookup(node.Head.Value);
+ if (Utils.Size(meths) == 0)
+ continue;
+
+ foreach (var meth in meths)
+ {
+ Candidate cand;
+ if (!Candidate.TryGetCandidate(node, prov, meth, _printError, out cand))
+ continue;
+
+ bool good = cand.MatchesArity(arity);
+ if (hasGoodArity)
+ {
+ // We've seen one or more with good arity, so ignore wrong arity.
+ if (!good)
+ continue;
+ }
+ else if (good)
+ {
+ // This is the first one with good arity.
+ candidates.Clear();
+ hasGoodArity = true;
+ }
+
+ candidates.Add(cand);
+ }
+ }
+
+ if (candidates.Count == 0)
+ {
+ // Unknown function.
+ PostError(node.Head, "Unknown function");
+ node.SetType(ExprTypeKind.Error);
+ return;
+ }
+
+ if (!hasGoodArity)
+ {
+ // No overloads have the target arity. Generate an appropriate error.
+ // REVIEW: Will this be good enough with variable arity functions?
+ var arities = candidates.Select(c => c.Arity).Distinct().OrderBy(x => x).ToArray();
+ if (arities.Length == 1)
+ {
+ if (arities[0] == 1)
+ PostError(node, "Expected one argument to function '{1}'", arities[0], node.Head.Value);
+ else
+ PostError(node, "Expected {0} arguments to function '{1}'", arities[0], node.Head.Value);
+ }
+ else if (arities.Length == 2)
+ PostError(node, "Expected {0} or {1} arguments to function '{2}'", arities[0], arities[1], node.Head.Value);
+ else
+ PostError(node, "No overload of function '{0}' takes {1} arguments", node.Head.Value, arity);
+
+ // Set the type of the node. If there is only one possible type, use that, otherwise, use Error.
+ var kindsRet = candidates.Select(c => c.ReturnKind).Distinct().ToArray();
+ if (kindsRet.Length == 1)
+ node.SetType(kindsRet[0]);
+ else
+ node.SetType(ExprTypeKind.Error);
+ return;
+ }
+
+ // Count applicable candidates and move them to the front.
+ int count = 0;
+ int minBad = int.MaxValue;
+ int icandMinBad = -1;
+ for (int i = 0; i < candidates.Count; i++)
+ {
+ var cand = candidates[i];
+ int bad;
+ if (cand.IsApplicable(kinds, out bad))
+ candidates[count++] = cand;
+ else if (bad < minBad)
+ {
+ minBad = bad;
+ icandMinBad = i;
+ }
+ }
+ if (0 < count && count < candidates.Count)
+ candidates.RemoveRange(count, candidates.Count - count);
+ _host.Assert(candidates.Count > 0);
+ _host.Assert(count == 0 || count == candidates.Count);
+
+ // When there are multiple, GetBestOverload picks the one to use and emits an
+ // error message if there isn't a unique best answer.
+ Candidate best;
+ if (count > 1)
+ best = GetBestOverload(node, candidates);
+ else if (count == 1)
+ best = candidates[0];
+ else
+ {
+ _host.Assert(0 <= icandMinBad & icandMinBad < candidates.Count);
+ best = candidates[icandMinBad];
+ PostError(node, "The best overload of '{0}' has some invalid arguments", node.Head.Value);
+ }
+
+ // First convert the arguments to the proper types and get any constant values.
+ var args = new object[node.Args.Items.Length];
+ bool all = true;
+ // For variable, limit the index into best.Kinds to ivMax.
+ int ivMax = best.Kinds.Length - 1;
+ for (int i = 0; i < node.Args.Items.Length; i++)
+ {
+ args[i] = Convert(node.Args.Items[i].AsExpr, best.Kinds[Math.Min(i, ivMax)]);
+ all &= args[i] != null;
+ }
+
+ object res;
+ if (best.IsIdentity)
+ {
+ _host.Assert(!best.IsVariable);
+ _host.Assert(best.Arity == 1);
+ res = args[0];
+ }
+ else if (!all)
+ {
+ res = best.Provider.ResolveToConstant(node.Head.Value, best.Method, args);
+ if (res != null && res.GetType() != best.Method.ReturnType)
+ {
+ _printError(string.Format(
+ "Error in ExprTransform: Function '{0}' in namespace '{1}' produced wrong constant value type '{2}' vs '{3}'",
+ node.Head.Value, best.Provider.NameSpace, res.GetType(), best.Method.ReturnType));
+ res = null;
+ }
+ }
+ else
+ {
+ if (best.IsVariable)
+ {
+ int head = best.Kinds.Length - 1;
+ int tail = args.Length - head;
+ _host.Assert(tail >= 0);
+ var type = best.Method.GetParameters()[ivMax].ParameterType;
+ _host.Assert(type.IsArray);
+ type = type.GetElementType();
+ Array rest = Array.CreateInstance(type, tail);
+ for (int i = 0; i < tail; i++)
+ rest.SetValue(args[head + i], i);
+ Array.Resize(ref args, head + 1);
+ args[head] = rest;
+ }
+
+ res = best.Method.Invoke(null, args);
+ _host.Assert(res != null);
+ _host.Assert(res.GetType() == best.Method.ReturnType);
+ }
+
+ node.SetType(best.ReturnKind, res);
+ node.SetMethod(best.Method);
+ }
+
+ ///
+ /// Returns whether the given source type can be converted to the given destination type,
+ /// for the purposes of function invocation. Returns true if src is null and dst is any
+ /// valid type.
+ ///
+ private static bool CanConvert(ExprTypeKind src, ExprTypeKind dst)
+ {
+ // src can be Error, but dst should not be.
+ Contracts.Assert(ExprTypeKind.Error <= src & src < ExprTypeKind._Lim);
+ Contracts.Assert(ExprTypeKind.Error < dst & dst < ExprTypeKind._Lim);
+
+ if (src == ExprTypeKind.Error)
+ return true;
+
+ if (src == dst)
+ return true;
+ if (src == ExprTypeKind.I4)
+ return dst == ExprTypeKind.I8 || dst == ExprTypeKind.R4 || dst == ExprTypeKind.R8;
+ if (src == ExprTypeKind.I8)
+ return dst == ExprTypeKind.R8;
+ if (src == ExprTypeKind.R4)
+ return dst == ExprTypeKind.R8;
+ return false;
+ }
+
+ ///
+ /// Convert the given ExprNode to the given type and get its value, when constant.
+ ///
+ private object Convert(ExprNode expr, ExprTypeKind kind)
+ {
+ switch (kind)
+ {
+ case ExprTypeKind.BL:
+ {
+ BL? val;
+ if (!expr.TryGet(out val))
+ BadArg(expr, ExprTypeKind.BL);
+ return val;
+ }
+ case ExprTypeKind.I4:
+ {
+ I4? val;
+ if (!expr.TryGet(out val))
+ BadArg(expr, ExprTypeKind.I4);
+ return val;
+ }
+ case ExprTypeKind.I8:
+ {
+ I8? val;
+ if (!expr.TryGet(out val))
+ BadArg(expr, ExprTypeKind.I8);
+ return val;
+ }
+ case ExprTypeKind.R4:
+ {
+ R4? val;
+ if (!expr.TryGet(out val))
+ BadArg(expr, ExprTypeKind.R4);
+ return val;
+ }
+ case ExprTypeKind.R8:
+ {
+ R8? val;
+ if (!expr.TryGet(out val))
+ BadArg(expr, ExprTypeKind.R8);
+ return val;
+ }
+ case ExprTypeKind.TX:
+ {
+ TX? val;
+ if (!expr.TryGet(out val))
+ BadArg(expr, ExprTypeKind.TX);
+ return val;
+ }
+ default:
+ _host.Assert(false, "Unexpected type in Convert");
+ PostError(expr, "Internal error in Convert");
+ return null;
+ }
+ }
+
+ ///
+ /// Multiple applicable candidates; pick the best one. We use a simplification of
+ /// C#'s rules, only considering the types TX, BL, I4, I8, R4, R8. Basically, parameter
+ /// type X is better than parameter type Y if X can be converted to Y. The conversions are:
+ /// I4 => I8, I4 => R4, I4 => R8, I8 => R8, R4 => R8.
+ ///
+ private Candidate GetBestOverload(CallNode node, List candidates)
+ {
+ _host.Assert(Utils.Size(candidates) >= 2);
+
+ var dup1 = default(Candidate);
+ var dup2 = default(Candidate);
+ for (int i = 0; i < candidates.Count; i++)
+ {
+ var c1 = candidates[i];
+ int dup = -1;
+ for (int j = 0; ; j++)
+ {
+ if (j == i)
+ continue;
+
+ if (j >= candidates.Count)
+ {
+ if (dup < 0)
+ return c1;
+ if (dup1 == null)
+ {
+ dup1 = c1;
+ dup2 = candidates[dup];
+ }
+ break;
+ }
+
+ int cmp = c1.CompareSignatures(candidates[j]);
+
+ // Break if c1 isn't better.
+ if (cmp > 0)
+ break;
+ if (cmp == 0)
+ dup = j;
+ }
+ }
+
+ _host.Assert((dup1 != null) == (dup2 != null));
+ if (dup1 != null)
+ {
+ if (dup1.Provider.NameSpace.CompareTo(dup2.Provider.NameSpace) > 0)
+ Utils.Swap(ref dup1, ref dup2);
+ PostError(node, "Duplicate candidate functions in namespaces '{0}' and '{1}'",
+ dup1.Provider.NameSpace, dup2.Provider.NameSpace);
+ }
+ else
+ PostError(node, "Ambiguous invocation of function '{0}'", node.Head.Value);
+
+ return dup1 ?? candidates[0];
+ }
+
+ public override void PostVisit(ListNode node)
+ {
+ _host.AssertValue(node);
+ }
+
+ public override bool PreVisit(WithNode node)
+ {
+ _host.AssertValue(node);
+
+ // First bind the value expressions.
+ node.Local.Accept(this);
+
+ // Push the with.
+ int iwith = _rgwith.Count;
+ _rgwith.Add(node);
+
+ // Bind the body.
+ node.Body.Accept(this);
+
+ // Pop the var context.
+ _host.Assert(_rgwith.Count == iwith + 1);
+ _host.Assert(_rgwith[iwith] == node);
+ _rgwith.RemoveAt(iwith);
+ _host.Assert(_rgwith.Count == iwith);
+
+ node.SetValue(node.Body);
+
+ return false;
+ }
+
+ public override void PostVisit(WithNode node)
+ {
+ _host.Assert(false);
+ }
+
+ public override void PostVisit(WithLocalNode node)
+ {
+ _host.AssertValue(node);
+ }
+
+ // This aggregates the type of expr into itemKind. It returns false
+ // if an error condition is encountered. This takes into account
+ // possible conversions.
+ private bool ValidateType(ExprNode expr, ref ExprTypeKind itemKind)
+ {
+ _host.AssertValue(expr);
+ _host.Assert(expr.ExprType != 0);
+
+ ExprTypeKind kind = expr.ExprType;
+ switch (kind)
+ {
+ case ExprTypeKind.Error:
+ _host.Assert(HasErrors);
+ return false;
+ case ExprTypeKind.None:
+ return false;
+ }
+
+ if (kind == itemKind)
+ return true;
+
+ switch (itemKind)
+ {
+ case ExprTypeKind.Error:
+ // This is the first non-error item type we've seen.
+ _host.Assert(HasErrors);
+ itemKind = kind;
+ return true;
+ case ExprTypeKind.None:
+ // This is the first non-error item type we've seen.
+ itemKind = kind;
+ return true;
+ }
+
+ ExprTypeKind kindNew;
+ if (!CanPromote(true, kind, itemKind, out kindNew))
+ return false;
+
+ itemKind = kindNew;
+ return true;
+ }
+
+ internal static bool CanPromote(bool precise, ExprTypeKind k1, ExprTypeKind k2, out ExprTypeKind res)
+ {
+ if (k1 == k2)
+ {
+ res = k1;
+ if (res != ExprTypeKind.Error && res != ExprTypeKind.None)
+ return true;
+ res = ExprTypeKind.Error;
+ return false;
+ }
+
+ // Encode numeric types in a two-bit value.
+ int i1 = MapKindToIndex(k1);
+ int i2 = MapKindToIndex(k2);
+ if (i1 < 0 || i2 < 0)
+ {
+ res = ExprTypeKind.Error;
+ return false;
+ }
+
+ Contracts.Assert(0 <= i1 & i1 < 4);
+ Contracts.Assert(0 <= i2 & i2 < 4);
+ Contracts.Assert(i1 != i2);
+
+ // Combine the two two-bit values.
+ int index = i1 | (i2 << 2);
+ Contracts.Assert(0 <= index & index < 16);
+ switch (index)
+ {
+ // Only integer types -> I8
+ case 0x1:
+ case 0x4:
+ res = ExprTypeKind.I8;
+ return true;
+
+ // R4 and I4 -> R8 for precise and R4 otherwise.
+ case 0x2:
+ case 0x8:
+ res = precise ? ExprTypeKind.R8 : ExprTypeKind.R4;
+ return true;
+
+ // At least one RX type and at least one 8-byte type -> R8
+ case 0x3:
+ case 0x6:
+ case 0x7:
+ case 0x9:
+ case 0xB:
+ case 0xC:
+ case 0xD:
+ case 0xE:
+ res = ExprTypeKind.R8;
+ return true;
+
+ default:
+ Contracts.Assert(false);
+ res = ExprTypeKind.Error;
+ return false;
+ }
+ }
+
+ ///
+ /// Maps numeric type kinds to an index 0,1,2,3. All others map to -1.
+ ///
+ private static int MapKindToIndex(ExprTypeKind kind)
+ {
+ switch (kind)
+ {
+ case ExprTypeKind.I4:
+ return 0;
+ case ExprTypeKind.I8:
+ return 1;
+ case ExprTypeKind.R4:
+ return 2;
+ case ExprTypeKind.R8:
+ return 3;
+ }
+ return -1;
+ }
+ }
+
+ internal sealed partial class LambdaBinder : NodeVisitor
+ {
+ // This partial contains stuff needed for equality and ordered comparison.
+ private static T Cast(object a)
+ {
+ Contracts.Assert(a is T);
+ return (T)a;
+ }
+
+ private delegate BL? Cmp(object a, object b);
+
+ // Indexed by ExprTypeKind.
+ private static readonly Cmp[] _fnEqual = new Cmp[(int)ExprTypeKind._Lim]
+ {
+ // None, Error
+ null, null,
+
+ (a,b) => Cast(a) == Cast(b),
+ (a,b) => Cast(a) == Cast(b),
+ (a,b) => Cast(a) == Cast(b),
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x == y) return true; if (!R4.IsNaN(x) && !R4.IsNaN(y)) return false; return null; },
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x == y) return true; if (!R8.IsNaN(x) && !R8.IsNaN(y)) return false; return null; },
+ (a,b) => Cast(a).Span.SequenceEqual(Cast(b).Span),
+ };
+
+ // Indexed by ExprTypeKind.
+ private static readonly Cmp[] _fnNotEqual = new Cmp[(int)ExprTypeKind._Lim]
+ {
+ // None, Error
+ null, null,
+
+ (a,b) => Cast(a) != Cast(b),
+ (a,b) => Cast(a) != Cast(b),
+ (a,b) => Cast(a) != Cast(b),
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x == y) return false; if (!R4.IsNaN(x) && !R4.IsNaN(y)) return true; return null; },
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x == y) return false; if (!R8.IsNaN(x) && !R8.IsNaN(y)) return true; return null; },
+ (a,b) => !Cast(a).Span.SequenceEqual(Cast(b).Span),
+ };
+
+ // Indexed by ExprTypeKind.
+ private static readonly Cmp[] _fnLess = new Cmp[(int)ExprTypeKind._Lim]
+ {
+ // None, Error
+ null, null,
+
+ null,
+ (a,b) => Cast(a) < Cast(b),
+ (a,b) => Cast(a) < Cast(b),
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x < y) return true; if (x >= y) return false; return null; },
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x < y) return true; if (x >= y) return false; return null; },
+ null,
+ };
+
+ // Indexed by ExprTypeKind.
+ private static readonly Cmp[] _fnLessEqual = new Cmp[(int)ExprTypeKind._Lim]
+ {
+ // None, Error
+ null, null,
+
+ null,
+ (a,b) => Cast(a) <= Cast(b),
+ (a,b) => Cast(a) <= Cast(b),
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x <= y) return true; if (x > y) return false; return null; },
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x <= y) return true; if (x > y) return false; return null; },
+ null,
+ };
+
+ // Indexed by ExprTypeKind.
+ private static readonly Cmp[] _fnGreater = new Cmp[(int)ExprTypeKind._Lim]
+ {
+ // None, Error
+ null, null,
+
+ null,
+ (a,b) => Cast(a) > Cast(b),
+ (a,b) => Cast(a) > Cast(b),
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x > y) return true; if (x <= y) return false; return null; },
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x > y) return true; if (x <= y) return false; return null; },
+ null,
+ };
+
+ // Indexed by ExprTypeKind.
+ private static readonly Cmp[] _fnGreaterEqual = new Cmp[(int)ExprTypeKind._Lim]
+ {
+ // None, Error
+ null, null,
+
+ null,
+ (a,b) => Cast(a) >= Cast(b),
+ (a,b) => Cast(a) >= Cast(b),
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x >= y) return true; if (x < y) return false; return null; },
+ (a,b) => { var x = Cast(a); var y = Cast(b); if (x >= y) return true; if (x < y) return false; return null; },
+ null,
+ };
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Transforms/Expression/LambdaParser.cs b/src/Microsoft.ML.Transforms/Expression/LambdaParser.cs
new file mode 100644
index 0000000000..c003265bc6
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Expression/LambdaParser.cs
@@ -0,0 +1,795 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms
+{
+ internal sealed class LambdaParser
+ {
+ public struct SourcePos
+ {
+ public readonly int IchMin;
+ public readonly int IchLim;
+ public readonly int LineMin;
+ public readonly int ColumnMin;
+ public readonly int LineLim;
+ public readonly int ColumnLim;
+
+ public SourcePos(List lineMap, TextSpan span, int lineMin = 1)
+ {
+ Contracts.AssertValue(lineMap);
+ Contracts.Assert(span.Min <= span.Lim);
+
+ IchMin = span.Min;
+ IchLim = span.Lim;
+
+ if (Utils.Size(lineMap) == 0)
+ {
+ LineMin = lineMin;
+ ColumnMin = IchMin + 1;
+ LineLim = lineMin;
+ ColumnLim = IchLim + 1;
+ return;
+ }
+
+ int index = FindIndex(lineMap, IchMin, 0);
+ LineMin = index + lineMin;
+ int ichBase = index == 0 ? 0 : lineMap[index - 1];
+ ColumnMin = IchMin - ichBase + 1;
+
+ if (index == lineMap.Count || IchLim < lineMap[index])
+ {
+ // Same line.
+ LineLim = LineMin;
+ ColumnLim = IchLim - ichBase + 1;
+ }
+ else
+ {
+ index = FindIndex(lineMap, IchLim, index);
+ Contracts.Assert(index > 0);
+ ichBase = lineMap[index - 1];
+ LineLim = index + lineMin;
+ ColumnLim = IchLim - ichBase + 1;
+ }
+ }
+
+ private static int FindIndex(List map, int value, int ivMin)
+ {
+ Contracts.Assert(ivMin <= map.Count);
+ int ivLim = map.Count;
+ while (ivMin < ivLim)
+ {
+ int iv = (ivMin + ivLim) / 2;
+ if (value >= map[iv])
+ ivMin = iv + 1;
+ else
+ ivLim = iv;
+ }
+ Contracts.Assert(0 <= ivMin & ivMin <= map.Count);
+ Contracts.Assert(ivMin == map.Count || value < map[ivMin]);
+ Contracts.Assert(ivMin == 0 || value >= map[ivMin - 1]);
+ return ivMin;
+ }
+ }
+
+ // This is re-usable state (if we choose to re-use).
+ private readonly NormStr.Pool _pool;
+ private readonly KeyWordTable _kwt;
+ private readonly Lexer _lex;
+
+ // Created lazily. If we choose to share static state in the future, this
+ // should be volatile and set using Interlocked.CompareExchange.
+ private Dictionary _mapTidStr;
+
+ // This is the parsing state.
+ private int[] _perm; // The parameter permutation.
+ private DataViewType[] _types;
+ private TokenCursor _curs;
+ private List _errors;
+ private List _lineMap;
+
+ private LambdaParser()
+ {
+ _pool = new NormStr.Pool();
+ _kwt = new KeyWordTable(_pool);
+ InitKeyWordTable();
+ _lex = new Lexer(_pool, _kwt);
+ }
+
+ private void InitKeyWordTable()
+ {
+ Action p = _kwt.AddPunctuator;
+
+ p("^", TokKind.Car);
+
+ p("*", TokKind.Mul);
+ p("/", TokKind.Div);
+ p("%", TokKind.Per);
+ p("+", TokKind.Add);
+ p("-", TokKind.Sub);
+
+ p("&&", TokKind.AmpAmp);
+ p("||", TokKind.BarBar);
+
+ p("!", TokKind.Bng);
+ p("!=", TokKind.BngEqu);
+
+ p("=", TokKind.Equ);
+ p("==", TokKind.EquEqu);
+ p("=>", TokKind.EquGrt);
+ p("<", TokKind.Lss);
+ p("<=", TokKind.LssEqu);
+ p("<>", TokKind.LssGrt);
+ p(">", TokKind.Grt);
+ p(">=", TokKind.GrtEqu);
+
+ p(".", TokKind.Dot);
+ p(",", TokKind.Comma);
+ p(":", TokKind.Colon);
+ p(";", TokKind.Semi);
+ p("?", TokKind.Que);
+ p("??", TokKind.QueQue);
+
+ p("(", TokKind.OpenParen);
+ p(")", TokKind.CloseParen);
+
+ Action w = _kwt.AddKeyWord;
+
+ w("false", TokKind.False);
+ w("true", TokKind.True);
+ w("not", TokKind.Not);
+ w("and", TokKind.And);
+ w("or", TokKind.Or);
+ w("with", TokKind.With);
+ }
+
+ public static LambdaNode Parse(out List errors, out List lineMap, CharCursor chars, int[] perm, params DataViewType[] types)
+ {
+ Contracts.AssertValue(chars);
+ Contracts.AssertNonEmpty(types);
+ Contracts.Assert(types.Length <= LambdaCompiler.MaxParams);
+ Contracts.Assert(Utils.Size(perm) == types.Length);
+
+ LambdaParser psr = new LambdaParser();
+ return psr.ParseCore(out errors, out lineMap, chars, perm, types);
+ }
+
+ private LambdaNode ParseCore(out List errors, out List lineMap, CharCursor chars, int[] perm, DataViewType[] types)
+ {
+ Contracts.AssertValue(chars);
+ Contracts.AssertNonEmpty(types);
+ Contracts.Assert(Utils.Size(perm) == types.Length);
+
+ _errors = null;
+ _lineMap = new List();
+ _curs = new TokenCursor(_lex.LexSource(chars));
+ _types = types;
+ _perm = perm;
+
+ // Skip over initial comments, new lines, lexing errors, etc.
+ SkipJunk();
+
+ LambdaNode node = ParseLambda(TokCur);
+ if (TidCur != TokKind.Eof)
+ PostError(TokCur, "Expected end of input");
+
+ errors = _errors;
+ lineMap = _lineMap;
+
+ _errors = null;
+ _lineMap = null;
+ _curs = null;
+
+ return node;
+ }
+
+ private void AddError(Error err)
+ {
+ Contracts.Assert(_errors == null || _errors.Count > 0);
+
+ if (Utils.Size(_errors) > 0 && _errors[_errors.Count - 1].Token == err.Token)
+ {
+ // There's already an error report on this token, so don't issue another.
+ return;
+ }
+
+ if (_errors == null)
+ _errors = new List();
+ _errors.Add(err);
+ }
+
+ private void PostError(Token tok, string msg)
+ {
+ var err = new Error(tok, msg);
+ AddError(err);
+ }
+
+ private void PostError(Token tok, string msg, params object[] args)
+ {
+ var err = new Error(tok, msg, args);
+ AddError(err);
+ }
+
+ private void PostTidError(Token tok, TokKind tidWanted)
+ {
+ Contracts.Assert(tidWanted != tok.Kind);
+ Contracts.Assert(tidWanted != tok.KindContext);
+ PostError(tok, "Expected: '{0}', Found: '{1}'", Stringize(tidWanted), Stringize(tok));
+ }
+
+ private string Stringize(Token tok)
+ {
+ Contracts.AssertValue(tok);
+ switch (tok.Kind)
+ {
+ case TokKind.Ident:
+ return tok.As().Value;
+ default:
+ return Stringize(tok.Kind);
+ }
+ }
+
+ private string Stringize(TokKind tid)
+ {
+ if (_mapTidStr == null)
+ {
+ // Build the inverse key word table, mapping token kinds to strings.
+ _mapTidStr = new Dictionary