Skip to content
226 changes: 127 additions & 99 deletions Source/EntityFramework.Extended/Batch/SqlServerBatchRunner.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.EntityClient;
using System.Data.Objects;
using System.Diagnostics;
using System.Linq;
using System.Linq.Dynamic;
using System.Linq.Expressions;
Expand Down Expand Up @@ -170,112 +172,28 @@ public int Update<TEntity>(ObjectContext objectContext, EntityMap entityMap, Obj
bool wroteSet = false;
foreach (MemberBinding binding in memberInitExpression.Bindings)
{
if (wroteSet)
sqlBuilder.AppendLine(", ");

string propertyName = binding.Member.Name;
string columnName = entityMap.PropertyMaps
.Where(p => p.PropertyName == propertyName)
.Select(p => p.ColumnName)
.FirstOrDefault();


var memberAssignment = binding as MemberAssignment;
if (memberAssignment == null)
throw new ArgumentException("The update expression MemberBinding must only by type MemberAssignment.", "updateExpression");

Expression memberExpression = memberAssignment.Expression;

ParameterExpression parameterExpression = null;
memberExpression.Visit((ParameterExpression p) =>
{
if (p.Type == entityMap.EntityType)
parameterExpression = p;

return p;
});


if (parameterExpression == null)
IPropertyMapElement propertyMap =
entityMap.PropertyMaps.SingleOrDefault(p => p.PropertyName == binding.Member.Name);
if (propertyMap is ComplexPropertyMap)
{
object value;

if (memberExpression.NodeType == ExpressionType.Constant)
{
var constantExpression = memberExpression as ConstantExpression;
if (constantExpression == null)
throw new ArgumentException(
"The MemberAssignment expression is not a ConstantExpression.", "updateExpression");

value = constantExpression.Value;
}
else
{
LambdaExpression lambda = Expression.Lambda(memberExpression, null);
value = lambda.Compile().DynamicInvoke();
}

if (value != null)
{
string parameterName = "p__update__" + nameCount++;
var parameter = updateCommand.CreateParameter();
parameter.ParameterName = parameterName;
parameter.Value = value;
updateCommand.Parameters.Add(parameter);

sqlBuilder.AppendFormat("[{0}] = @{1}", columnName, parameterName);
}
else
ComplexPropertyMap cpm = propertyMap as ComplexPropertyMap;
var memberAssignment = binding as MemberAssignment;
if (memberAssignment == null)
throw new ArgumentException("The update expression MemberBinding must only by type MemberAssignment.", "updateExpression");
var expr = memberAssignment.Expression as MemberInitExpression;
if (expr == null)
throw new ArgumentException("The update expression MemberBinding must only by type MemberAssignment.", "updateExpression");
foreach (var subBinding in expr.Bindings)
{
sqlBuilder.AppendFormat("[{0}] = NULL", columnName);
AddUpdateRow<TEntity>(objectContext, entityMap, subBinding, sqlBuilder, updateCommand,
cpm.TypeElements, ref nameCount, ref wroteSet);
}
}
else
{
// create clean objectset to build query from
var objectSet = objectContext.CreateObjectSet<TEntity>();

Type[] typeArguments = new[] { entityMap.EntityType, memberExpression.Type };

ConstantExpression constantExpression = Expression.Constant(objectSet);
LambdaExpression lambdaExpression = Expression.Lambda(memberExpression, parameterExpression);

MethodCallExpression selectExpression = Expression.Call(
typeof(Queryable),
"Select",
typeArguments,
constantExpression,
lambdaExpression);

// create query from expression
var selectQuery = objectSet.CreateQuery(selectExpression, entityMap.EntityType);
string sql = selectQuery.ToTraceString();

// parse select part of sql to use as update
string regex = @"SELECT\s*\r\n(?<ColumnValue>.+)?\s*AS\s*(?<ColumnAlias>\[\w+\])\r\nFROM\s*(?<TableName>\[\w+\]\.\[\w+\]|\[\w+\])\s*AS\s*(?<TableAlias>\[\w+\])";
Match match = Regex.Match(sql, regex);
if (!match.Success)
throw new ArgumentException("The MemberAssignment expression could not be processed.", "updateExpression");

string value = match.Groups["ColumnValue"].Value;
string alias = match.Groups["TableAlias"].Value;

value = value.Replace(alias + ".", "");

foreach (ObjectParameter objectParameter in selectQuery.Parameters)
{
string parameterName = "p__update__" + nameCount++;

var parameter = updateCommand.CreateParameter();
parameter.ParameterName = parameterName;
parameter.Value = objectParameter.Value;
updateCommand.Parameters.Add(parameter);

value = value.Replace(objectParameter.Name, parameterName);
}
sqlBuilder.AppendFormat("[{0}] = {1}", columnName, value);
AddUpdateRow<TEntity>(objectContext, entityMap, binding, sqlBuilder, updateCommand,
entityMap.PropertyMaps, ref nameCount, ref wroteSet);
}
wroteSet = true;
}

sqlBuilder.AppendLine(" ");
Expand Down Expand Up @@ -316,6 +234,116 @@ public int Update<TEntity>(ObjectContext objectContext, EntityMap entityMap, Obj
}
}


private static void AddUpdateRow<TEntity>(ObjectContext objectContext, EntityMap entityMap, MemberBinding binding, StringBuilder sqlBuilder, DbCommand updateCommand, IEnumerable<IPropertyMapElement> propertyMap, ref int nameCount, ref bool wroteSet)
where TEntity : class
{
if (wroteSet)
sqlBuilder.AppendLine(", ");

string propertyName = binding.Member.Name;
PropertyMap property =
propertyMap.SingleOrDefault(p => p.PropertyName == propertyName) as PropertyMap;
Debug.Assert(property != null, "property != null");
string columnName = property.ColumnName;
var memberAssignment = binding as MemberAssignment;
if (memberAssignment == null)
throw new ArgumentException("The update expression MemberBinding must only by type MemberAssignment.", "binding");

Expression memberExpression = memberAssignment.Expression;

ParameterExpression parameterExpression = null;
memberExpression.Visit((ParameterExpression p) =>
{
if (p.Type == entityMap.EntityType)
parameterExpression = p;

return p;
});


if (parameterExpression == null)
{
object value;

if (memberExpression.NodeType == ExpressionType.Constant)
{
var constantExpression = memberExpression as ConstantExpression;
if (constantExpression == null)
throw new ArgumentException(
"The MemberAssignment expression is not a ConstantExpression.", "binding");

value = constantExpression.Value;
}
else
{
LambdaExpression lambda = Expression.Lambda(memberExpression, null);
value = lambda.Compile().DynamicInvoke();
}

if (value != null)
{
string parameterName = "p__update__" + nameCount++;
var parameter = updateCommand.CreateParameter();
parameter.ParameterName = parameterName;
parameter.Value = value;
updateCommand.Parameters.Add(parameter);

sqlBuilder.AppendFormat("[{0}] = @{1}", columnName, parameterName);
}
else
{
sqlBuilder.AppendFormat("[{0}] = NULL", columnName);
}
}
else
{
// create clean objectset to build query from
var objectSet = objectContext.CreateObjectSet<TEntity>();

Type[] typeArguments = new[] { entityMap.EntityType, memberExpression.Type };

ConstantExpression constantExpression = Expression.Constant(objectSet);
LambdaExpression lambdaExpression = Expression.Lambda(memberExpression, parameterExpression);

MethodCallExpression selectExpression = Expression.Call(
typeof(Queryable),
"Select",
typeArguments,
constantExpression,
lambdaExpression);

// create query from expression
var selectQuery = objectSet.CreateQuery(selectExpression, entityMap.EntityType);
string sql = selectQuery.ToTraceString();

// parse select part of sql to use as update
const string regex = @"SELECT\s*\r\n(?<ColumnValue>.+)?\s*AS\s*(?<ColumnAlias>\[\w+\])\r\nFROM\s*(?<TableName>\[\w+\]\.\[\w+\]|\[\w+\])\s*AS\s*(?<TableAlias>\[\w+\])";
Match match = Regex.Match(sql, regex);
if (!match.Success)
throw new ArgumentException("The MemberAssignment expression could not be processed.", "binding");

string value = match.Groups["ColumnValue"].Value;
string alias = match.Groups["TableAlias"].Value;

value = value.Replace(alias + ".", "");

foreach (ObjectParameter objectParameter in selectQuery.Parameters)
{
string parameterName = "p__update__" + nameCount++;

var parameter = updateCommand.CreateParameter();
parameter.ParameterName = parameterName;
parameter.Value = objectParameter.Value;
updateCommand.Parameters.Add(parameter);

value = value.Replace(objectParameter.Name, parameterName);
}
sqlBuilder.AppendFormat("[{0}] = {1}", columnName, value);
}
wroteSet = true;
}

private static Tuple<DbConnection, DbTransaction> GetStore(ObjectContext objectContext)
{
DbConnection dbConnection = objectContext.Connection;
Expand Down
53 changes: 28 additions & 25 deletions Source/EntityFramework.Extended/Dynamic/DynamicQueryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,19 @@ public Type GetDynamicClass(IEnumerable<DynamicProperty> properties)
Type type;
if (!classes.TryGetValue(signature, out type))
{
type = CreateDynamicClass(signature.properties);
classes.Add(signature, type);
LockCookie cookie = rwLock.UpgradeToWriterLock(Timeout.Infinite);
try
{
if (!classes.TryGetValue(signature, out type))
{
type = CreateDynamicClass(signature.properties);
classes.Add(signature, type);
}
}
finally
{
rwLock.DowngradeFromWriterLock(ref cookie);
}
}
return type;
}
Expand All @@ -298,34 +309,26 @@ public Type GetDynamicClass(IEnumerable<DynamicProperty> properties)

Type CreateDynamicClass(DynamicProperty[] properties)
{
LockCookie cookie = rwLock.UpgradeToWriterLock(Timeout.Infinite);
try
{
string typeName = "DynamicClass" + (classCount + 1);
#if ENABLE_LINQ_PARTIAL_TRUST
new ReflectionPermission(PermissionState.Unrestricted).Assert();
#endif
try
{
TypeBuilder tb = this.module.DefineType(typeName, TypeAttributes.Class |
TypeAttributes.Public, typeof(DynamicClass));
FieldInfo[] fields = GenerateProperties(tb, properties);
GenerateEquals(tb, fields);
GenerateGetHashCode(tb, fields);
Type result = tb.CreateType();
classCount++;
return result;
}
finally
{
string typeName = "DynamicClass" + (classCount + 1);
#if ENABLE_LINQ_PARTIAL_TRUST
PermissionSet.RevertAssert();
new ReflectionPermission(PermissionState.Unrestricted).Assert();
#endif
}
try
{
TypeBuilder tb = this.module.DefineType(typeName, TypeAttributes.Class |
TypeAttributes.Public, typeof(DynamicClass));
FieldInfo[] fields = GenerateProperties(tb, properties);
GenerateEquals(tb, fields);
GenerateGetHashCode(tb, fields);
Type result = tb.CreateType();
classCount++;
return result;
}
finally
{
rwLock.DowngradeFromWriterLock(ref cookie);
#if ENABLE_LINQ_PARTIAL_TRUST
PermissionSet.RevertAssert();
#endif
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
<Compile Include="Caching\Query\LocalCollectionExpander.cs" />
<Compile Include="Caching\Query\QueryCache.cs" />
<Compile Include="Caching\Query\Utility.cs" />
<Compile Include="Extensions\IQueryUnwrapper.cs" />
<Compile Include="Future\IFutureRunner.cs" />
<Compile Include="Locator.cs" />
<Compile Include="Container.cs" />
Expand All @@ -100,10 +101,12 @@
<Compile Include="Extensions\BatchExtensions.cs" />
<Compile Include="Extensions\ObjectContextExtensions.cs" />
<Compile Include="IContainer.cs" />
<Compile Include="Mapping\ComplexPropertyMap.cs" />
<Compile Include="Mapping\EntityMap.cs" />
<Compile Include="Extensions\ExpressionExtensions.cs" />
<Compile Include="Extensions\ObjectQueryExtensions.cs" />
<Compile Include="Mapping\IMappingProvider.cs" />
<Compile Include="Mapping\IPropertyMapElement.cs" />
<Compile Include="Mapping\ReflectionMappingProvider.cs" />
<Compile Include="Mapping\MappingResolver.cs" />
<Compile Include="Mapping\PropertyMap.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,12 @@
<Compile Include="Extensions\BatchExtensions.cs" />
<Compile Include="Extensions\ObjectContextExtensions.cs" />
<Compile Include="IContainer.cs" />
<Compile Include="Mapping\ComplexPropertyMap.cs" />
<Compile Include="Mapping\EntityMap.cs" />
<Compile Include="Extensions\ExpressionExtensions.cs" />
<Compile Include="Extensions\ObjectQueryExtensions.cs" />
<Compile Include="Mapping\IMappingProvider.cs" />
<Compile Include="Mapping\IPropertyMapElement.cs" />
<Compile Include="Mapping\ReflectionMappingProvider.cs" />
<Compile Include="Mapping\MappingResolver.cs" />
<Compile Include="Mapping\PropertyMap.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public static FutureQuery<TEntity> Future<TEntity>(this IQueryable<TEntity> sour
if (source == null)
throw new ArgumentNullException("source");

ObjectQuery<TEntity> sourceQuery = source.ToObjectQuery();
ObjectQuery sourceQuery = source.ToObjectQuery();
if (sourceQuery == null)
throw new ArgumentException("The source query must be of type ObjectQuery or DbQuery.", "source");

Expand Down Expand Up @@ -136,7 +136,7 @@ public static FutureValue<TEntity> FutureFirstOrDefault<TEntity>(this IQueryable
// make sure to only get the first value
IQueryable<TEntity> firstQuery = source.Take(1);

ObjectQuery<TEntity> objectQuery = firstQuery.ToObjectQuery();
ObjectQuery objectQuery = firstQuery.ToObjectQuery();
if (objectQuery == null)
throw new ArgumentException("The source query must be of type ObjectQuery or DbQuery.", "source");

Expand Down
Loading