Skip to content

Commit 534d79f

Browse files
authored
Merge pull request #81 from lazaro-ansaldi/master
Add support to ExecuteDelete and ExecuteUpdate
2 parents 7f02ca9 + 2023477 commit 534d79f

File tree

6 files changed

+513
-428
lines changed

6 files changed

+513
-428
lines changed

src/MockQueryable/MockQueryable.Core/TestQueryProvider.cs

Lines changed: 87 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -6,78 +6,92 @@
66

77
namespace MockQueryable.Core
88
{
9-
public abstract class TestQueryProvider<T> : IOrderedQueryable<T>, IQueryProvider
10-
{
11-
private IEnumerable<T> _enumerable;
12-
13-
protected TestQueryProvider(Expression expression)
14-
{
15-
Expression = expression;
16-
}
17-
18-
protected TestQueryProvider(IEnumerable<T> enumerable)
19-
{
20-
_enumerable = enumerable;
21-
Expression = enumerable.AsQueryable().Expression;
22-
}
23-
24-
public IQueryable CreateQuery(Expression expression)
25-
{
26-
if (expression is MethodCallExpression m)
27-
{
28-
var resultType = m.Method.ReturnType; // it should be IQueryable<T>
29-
var tElement = resultType.GetGenericArguments().First();
30-
return (IQueryable) CreateInstance(tElement, expression);
31-
}
32-
33-
return CreateQuery<T>(expression);
34-
}
35-
36-
public IQueryable<TEntity> CreateQuery<TEntity>(Expression expression)
37-
{
38-
return (IQueryable<TEntity>) CreateInstance(typeof(TEntity), expression);
39-
}
40-
41-
private object CreateInstance(Type tElement, Expression expression)
9+
public abstract class TestQueryProvider<T> : IOrderedQueryable<T>, IQueryProvider
4210
{
43-
var queryType = GetType().GetGenericTypeDefinition().MakeGenericType(tElement);
44-
return Activator.CreateInstance(queryType, expression);
45-
}
46-
47-
public object Execute(Expression expression)
48-
{
49-
return CompileExpressionItem<object>(expression);
50-
}
51-
52-
public TResult Execute<TResult>(Expression expression)
53-
{
54-
return CompileExpressionItem<TResult>(expression);
55-
}
56-
57-
IEnumerator<T> IEnumerable<T>.GetEnumerator()
58-
{
59-
if (_enumerable == null) _enumerable = CompileExpressionItem<IEnumerable<T>>(Expression);
60-
return _enumerable.GetEnumerator();
61-
}
62-
63-
IEnumerator IEnumerable.GetEnumerator()
64-
{
65-
if (_enumerable == null) _enumerable = CompileExpressionItem<IEnumerable<T>>(Expression);
66-
return _enumerable.GetEnumerator();
67-
}
68-
69-
public Type ElementType => typeof(T);
70-
71-
public Expression Expression { get; }
72-
73-
public IQueryProvider Provider => this;
74-
75-
private static TResult CompileExpressionItem<TResult>(Expression expression)
76-
{
77-
var visitor = new TestExpressionVisitor();
78-
var body = visitor.Visit(expression);
79-
var f = Expression.Lambda<Func<TResult>>(body ?? throw new InvalidOperationException($"{nameof(body)} is null"), (IEnumerable<ParameterExpression>) null);
80-
return f.Compile()();
81-
}
82-
}
11+
// Hardcoding this constants to avoid the reference to EFCore
12+
private const string EF_EXECUTE_UPDATE_METHOD_NAME = "ExecuteUpdate";
13+
private const string EF_EXECUTE_DELETE_METHOD_NAME = "ExecuteDelete";
14+
15+
private IEnumerable<T> _enumerable;
16+
17+
protected TestQueryProvider(Expression expression)
18+
{
19+
Expression = expression;
20+
}
21+
22+
protected TestQueryProvider(IEnumerable<T> enumerable)
23+
{
24+
_enumerable = enumerable;
25+
Expression = enumerable.AsQueryable().Expression;
26+
}
27+
28+
public IQueryable CreateQuery(Expression expression)
29+
{
30+
if (expression is MethodCallExpression m)
31+
{
32+
var resultType = m.Method.ReturnType; // it should be IQueryable<T>
33+
var tElement = resultType.GetGenericArguments().First();
34+
return (IQueryable) CreateInstance(tElement, expression);
35+
}
36+
37+
return CreateQuery<T>(expression);
38+
}
39+
40+
public IQueryable<TEntity> CreateQuery<TEntity>(Expression expression)
41+
{
42+
return (IQueryable<TEntity>) CreateInstance(typeof(TEntity), expression);
43+
}
44+
45+
private object CreateInstance(Type tElement, Expression expression)
46+
{
47+
var queryType = GetType().GetGenericTypeDefinition().MakeGenericType(tElement);
48+
return Activator.CreateInstance(queryType, expression);
49+
}
50+
51+
public object Execute(Expression expression)
52+
{
53+
return CompileExpressionItem<object>(expression);
54+
}
55+
56+
public TResult Execute<TResult>(Expression expression)
57+
{
58+
if (expression is MethodCallExpression methodCall && (methodCall.Method.Name == EF_EXECUTE_UPDATE_METHOD_NAME || methodCall.Method.Name == EF_EXECUTE_DELETE_METHOD_NAME)
59+
&& typeof(TResult) == typeof(int))
60+
{
61+
// Intercept ExecuteDelete and ExecuteUpdate calls
62+
var affectedItems = CompileExpressionItem<IEnumerable<T>>(Expression).ToList();
63+
// Return the count of affected items
64+
return (TResult)(object)affectedItems.Count;
65+
}
66+
67+
// Fall back to default expression execution
68+
return CompileExpressionItem<TResult>(expression);
69+
}
70+
71+
IEnumerator<T> IEnumerable<T>.GetEnumerator()
72+
{
73+
if (_enumerable == null) _enumerable = CompileExpressionItem<IEnumerable<T>>(Expression);
74+
return _enumerable.GetEnumerator();
75+
}
76+
77+
IEnumerator IEnumerable.GetEnumerator()
78+
{
79+
if (_enumerable == null) _enumerable = CompileExpressionItem<IEnumerable<T>>(Expression);
80+
return _enumerable.GetEnumerator();
81+
}
82+
83+
public Type ElementType => typeof(T);
84+
85+
public Expression Expression { get; }
86+
87+
public IQueryProvider Provider => this;
88+
89+
private static TResult CompileExpressionItem<TResult>(Expression expression)
90+
{
91+
var visitor = new TestExpressionVisitor();
92+
var body = visitor.Visit(expression);
93+
var f = Expression.Lambda<Func<TResult>>(body ?? throw new InvalidOperationException($"{nameof(body)} is null"), (IEnumerable<ParameterExpression>) null);
94+
return f.Compile()();
95+
}
96+
}
8397
}

src/MockQueryable/MockQueryable.EntityFrameworkCore/MockQueryable.EntityFrameworkCore.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
</None>
4646
</ItemGroup>
4747
<ItemGroup>
48-
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="7.0.0" />
48+
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="7.0.20" />
4949
</ItemGroup>
5050

5151
<ItemGroup>

src/MockQueryable/MockQueryable.Sample/MockQueryable.Sample.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
<ItemGroup>
88
<PackageReference Include="AutoMapper" Version="8.0.0" />
9+
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="7.0.20" />
910
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.1.0" />
1011
<PackageReference Include="NUnit" Version="3.13.3" />
1112
<PackageReference Include="NUnit3TestAdapter" Version="4.2.1">
Lines changed: 90 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,111 @@
1-
using System;
1+
using AutoMapper;
2+
using AutoMapper.QueryableExtensions;
3+
using Microsoft.EntityFrameworkCore;
4+
using System;
25
using System.Collections.Generic;
36
using System.Linq;
47
using System.Threading.Tasks;
5-
using AutoMapper;
6-
using AutoMapper.QueryableExtensions;
7-
using Microsoft.EntityFrameworkCore;
88

99
namespace MockQueryable.Sample
1010
{
11-
public class MyService
12-
{
13-
private readonly IUserRepository _userRepository;
14-
15-
public static void Initialize()
11+
public class MyService
1612
{
17-
Mapper.Initialize(cfg => cfg.CreateMap<UserEntity, UserReport>()
18-
.ForMember(dto => dto.FirstName, conf => conf.MapFrom(ol => ol.FirstName))
19-
.ForMember(dto => dto.LastName, conf => conf.MapFrom(ol => ol.LastName)));
20-
Mapper.Configuration.AssertConfigurationIsValid();
13+
private readonly IUserRepository _userRepository;
14+
15+
public static void Initialize()
16+
{
17+
Mapper.Initialize(cfg => cfg.CreateMap<UserEntity, UserReport>()
18+
.ForMember(dto => dto.FirstName, conf => conf.MapFrom(ol => ol.FirstName))
19+
.ForMember(dto => dto.LastName, conf => conf.MapFrom(ol => ol.LastName)));
20+
Mapper.Configuration.AssertConfigurationIsValid();
21+
}
22+
23+
public MyService(IUserRepository userRepository)
24+
{
25+
_userRepository = userRepository;
26+
}
27+
28+
public async Task CreateUserIfNotExist(string firstName, string lastName, DateTime dateOfBirth)
29+
{
30+
var query = _userRepository.GetQueryable();
31+
32+
if (await query.AnyAsync(x => x.LastName == lastName && x.DateOfBirth == dateOfBirth))
33+
{
34+
throw new ApplicationException("User already exist");
35+
}
36+
37+
var existUser = await query.FirstOrDefaultAsync(x => x.FirstName == firstName);
38+
if (existUser != null)
39+
{
40+
throw new ApplicationException("User with FirstName already exist");
41+
}
42+
43+
if (await query.CountAsync(x => x.DateOfBirth == dateOfBirth.Date) > 3)
44+
{
45+
throw new ApplicationException("Users with DateOfBirth more than limit");
46+
}
47+
48+
await _userRepository.CreateUser(new UserEntity
49+
{
50+
FirstName = firstName,
51+
LastName = lastName,
52+
DateOfBirth = dateOfBirth.Date,
53+
});
54+
55+
}
56+
57+
public async Task<List<UserReport>> GetUserReports(DateTime dateFrom, DateTime dateTo)
58+
{
59+
var query = _userRepository.GetQueryable();
60+
61+
query = query.Where(x => x.DateOfBirth >= dateFrom.Date);
62+
query = query.Where(x => x.DateOfBirth <= dateTo.Date);
63+
64+
return await query.Select(x => new UserReport
65+
{
66+
FirstName = x.FirstName,
67+
LastName = x.LastName,
68+
}).ToListAsync();
69+
}
70+
71+
72+
public async Task<List<UserReport>> GetUserReportsAutoMap(DateTime dateFrom, DateTime dateTo)
73+
{
74+
var query = _userRepository.GetQueryable();
75+
76+
query = query.Where(x => x.DateOfBirth >= dateFrom.Date);
77+
query = query.Where(x => x.DateOfBirth <= dateTo.Date);
78+
79+
return await query.ProjectTo<UserReport>().ToListAsync();
80+
}
2181
}
2282

23-
public MyService(IUserRepository userRepository)
83+
public interface IUserRepository
2484
{
25-
_userRepository = userRepository;
26-
}
85+
IQueryable<UserEntity> GetQueryable();
2786

28-
public async Task CreateUserIfNotExist(string firstName, string lastName, DateTime dateOfBirth)
29-
{
30-
var query = _userRepository.GetQueryable();
31-
32-
if (await query.AnyAsync(x => x.LastName == lastName && x.DateOfBirth == dateOfBirth))
33-
{
34-
throw new ApplicationException("User already exist");
35-
}
36-
37-
var existUser = await query.FirstOrDefaultAsync(x => x.FirstName == firstName);
38-
if (existUser != null)
39-
{
40-
throw new ApplicationException("User with FirstName already exist");
41-
}
42-
43-
if (await query.CountAsync(x => x.DateOfBirth == dateOfBirth.Date) > 3)
44-
{
45-
throw new ApplicationException("Users with DateOfBirth more than limit");
46-
}
47-
48-
await _userRepository.CreateUser(new UserEntity
49-
{
50-
FirstName = firstName,
51-
LastName = lastName,
52-
DateOfBirth = dateOfBirth.Date,
53-
});
87+
Task CreateUser(UserEntity user);
5488

55-
}
89+
Task<List<UserEntity>> GetAll();
5690

57-
public async Task<List<UserReport>> GetUserReports(DateTime dateFrom, DateTime dateTo)
58-
{
59-
var query = _userRepository.GetQueryable();
91+
IAsyncEnumerable<UserEntity> GetAllAsync();
6092

61-
query = query.Where(x => x.DateOfBirth >= dateFrom.Date);
62-
query = query.Where(x => x.DateOfBirth <= dateTo.Date);
93+
Task<int> DeleteUserAsync(Guid id);
6394

64-
return await query.Select(x => new UserReport
65-
{
66-
FirstName = x.FirstName,
67-
LastName = x.LastName,
68-
}).ToListAsync();
95+
Task<int> UpdateFirstNameByIdAsync(Guid id, string firstName);
6996
}
7097

71-
72-
public async Task<List<UserReport>> GetUserReportsAutoMap(DateTime dateFrom, DateTime dateTo)
98+
public class UserReport
7399
{
74-
var query = _userRepository.GetQueryable();
75-
76-
query = query.Where(x => x.DateOfBirth >= dateFrom.Date);
77-
query = query.Where(x => x.DateOfBirth <= dateTo.Date);
78-
79-
return await query.ProjectTo<UserReport>().ToListAsync();
100+
public string FirstName { get; set; }
101+
public string LastName { get; set; }
80102
}
81-
}
82-
83-
public interface IUserRepository
84-
{
85-
IQueryable<UserEntity> GetQueryable();
86-
87-
Task CreateUser(UserEntity user);
88-
89-
Task<List<UserEntity>> GetAll();
90103

91-
IAsyncEnumerable<UserEntity> GetAllAsync();
92-
}
93-
94-
95-
public class UserReport
96-
{
97-
public string FirstName { get; set; }
98-
public string LastName { get; set; }
99-
}
100-
101-
public class UserEntity
102-
{
103-
public Guid Id { get; set; }
104-
public string FirstName { get; set; }
105-
public string LastName { get; set; }
106-
public DateTime DateOfBirth { get; set; }
107-
}
104+
public class UserEntity
105+
{
106+
public Guid Id { get; set; }
107+
public string FirstName { get; set; }
108+
public string LastName { get; set; }
109+
public DateTime DateOfBirth { get; set; }
110+
}
108111
}

0 commit comments

Comments
 (0)