Skip to content

Commit a6a6422

Browse files
committed
Fix issue from supabase-community/supabase-csharp#48 where boolean model properties would not be evaluated in predicate expressions
1 parent 41b44fc commit a6a6422

File tree

7 files changed

+183
-15
lines changed

7 files changed

+183
-15
lines changed

Postgrest/Interfaces/IPostgrestTable.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ public interface IPostgrestTable<T> : IGettableHeaders
4747
Table<T> Select(Expression<Func<T, object[]>> predicate);
4848
Table<T> Where(Expression<Func<T, bool>> predicate);
4949
Task<T?> Single(CancellationToken cancellationToken = default);
50+
Table<T> Set(Expression<Func<T, object>> keySelector, object value);
5051
Table<T> Set(Expression<Func<T, KeyValuePair<object, object>>> keyValuePairExpression);
5152
Task<ModeledResponse<T>> Update(QueryOptions? options = null, CancellationToken cancellationToken = default);
5253
Task<ModeledResponse<T>> Update(T model, QueryOptions? options = null, CancellationToken cancellationToken = default);

Postgrest/Linq/SetExpressionVisitor.cs

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,32 +30,93 @@ internal class SetExpressionVisitor : ExpressionVisitor
3030
/// </summary>
3131
public object? Value { get; private set; }
3232

33+
/// <summary>
34+
/// A Unary Node, delved into to represent a property on a BaseModel.
35+
/// </summary>
36+
/// <param name="node"></param>
37+
/// <returns></returns>
38+
protected override Expression VisitUnary(UnaryExpression node)
39+
{
40+
if (node.Operand is MemberExpression memberExpression)
41+
{
42+
var column = GetColumnFromMemberExpression(memberExpression);
43+
44+
if (column != null)
45+
{
46+
Column = column;
47+
ExpectedType = memberExpression.Type;
48+
}
49+
}
50+
51+
return node;
52+
}
53+
54+
/// <summary>
55+
/// A Member Node, representing a property on a BaseModel.
56+
/// </summary>
57+
/// <param name="node"></param>
58+
/// <returns></returns>
59+
protected override Expression VisitMember(MemberExpression node)
60+
{
61+
var column = GetColumnFromMemberExpression(node);
62+
63+
if (column != null)
64+
{
65+
Column = column;
66+
ExpectedType = node.Type;
67+
}
68+
69+
return node;
70+
}
71+
3372
/// <summary>
3473
/// Called when visiting a the expected new KeyValuePair().
3574
/// </summary>
3675
/// <param name="node"></param>
3776
/// <returns></returns>
3877
/// <exception cref="ArgumentException"></exception>
3978
protected override Expression VisitNew(NewExpression node)
79+
{
80+
if (typeof(KeyValuePair<object, object>).IsAssignableFrom(node.Type))
81+
{
82+
HandleKeyValuePair(node);
83+
}
84+
85+
return node;
86+
}
87+
88+
private void HandleKeyValuePair(NewExpression node)
4089
{
4190
if (node.Arguments.Count != 2)
4291
throw new ArgumentException("Unknown expression, should be a `KeyValuePair<object, object>`");
4392

44-
var member = node.Arguments[0] as MemberExpression;
93+
var left = node.Arguments[0];
94+
var right = node.Arguments[1];
4595

46-
if (member == null)
96+
if (left is NewExpression)
97+
{
98+
Visit(left);
99+
}
100+
else if (left is MemberExpression member)
101+
{
102+
Column = GetColumnFromMemberExpression(member);
103+
ExpectedType = member.Type;
104+
}
105+
else if (left is UnaryExpression unaryExpression && unaryExpression.Operand is MemberExpression unaryMemberExpression)
106+
{
107+
Column = GetColumnFromMemberExpression(unaryMemberExpression);
108+
ExpectedType = unaryMemberExpression.Type;
109+
}
110+
else
111+
{
47112
throw new ArgumentException("Key should reference a Model Property.");
113+
}
48114

49-
Column = GetColumnFromMemberExpression(member);
50-
ExpectedType = member.Type;
51-
52-
var valueArgument = Expression.Lambda(node.Arguments[1]).Compile().DynamicInvoke();
115+
var valueArgument = Expression.Lambda(right).Compile().DynamicInvoke();
53116
Value = valueArgument;
54117

55-
if (!ExpectedType.IsAssignableFrom(Value.GetType()))
118+
if (!ExpectedType!.IsAssignableFrom(Value.GetType()))
56119
throw new ArgumentException(string.Format("Expected Value to be of Type: {0}, instead received: {1}.", ExpectedType.Name, Value.GetType().Name));
57-
58-
return node;
59120
}
60121

61122
/// <summary>

Postgrest/Linq/WhereExpressionVisitor.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ protected override Expression VisitBinary(BinaryExpression node)
7575
{
7676
HandleNewExpression(column, op, newExpression);
7777
}
78+
else if (right is UnaryExpression unaryExpression)
79+
{
80+
HandleUnaryExpression(column, op, unaryExpression);
81+
}
7882

7983
return node;
8084
}
@@ -138,6 +142,21 @@ private void HandleMemberExpression(string column, Operator op, MemberExpression
138142
Filter = new QueryFilter(column, op, GetMemberExpressionValue(memberExpression));
139143
}
140144

145+
146+
/// <summary>
147+
/// A unary expression parser (i.e. => x.Id == 1 <- where both `1` is considered unary)
148+
/// </summary>
149+
/// <param name="column"></param>
150+
/// <param name="op"></param>
151+
/// <param name="memberExpression"></param>
152+
private void HandleUnaryExpression(string column, Operator op, UnaryExpression unaryExpression)
153+
{
154+
if (unaryExpression.Operand is ConstantExpression constantExpression)
155+
{
156+
HandleConstantExpression(column, op, constantExpression);
157+
}
158+
}
159+
141160
/// <summary>
142161
/// An instantiated class parser (i.e. x => x.CreatedAt <= new DateTime(2022, 08, 20) <- where `new DateTime(...)` is an instantiated expression.
143162
/// </summary>

Postgrest/Table.cs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Threading.Tasks;
1010
using System.Web;
1111
using Newtonsoft.Json;
12+
using Newtonsoft.Json.Linq;
1213
using Postgrest.Attributes;
1314
using Postgrest.Extensions;
1415
using Postgrest.Interfaces;
@@ -583,6 +584,31 @@ public Task<ModeledResponse<T>> Upsert(ICollection<T> model, QueryOptions? optio
583584
return PerformInsert(model, options, cancellationToken);
584585
}
585586

587+
588+
/// <summary>
589+
/// Specifies a key and value to be updated. Should be combined with filters/where clauses.
590+
///
591+
/// Can be called multiple times to set multiple values.
592+
/// </summary>
593+
/// <param name="keySelector"></param>
594+
/// <param name="value"></param>
595+
/// <returns></returns>
596+
public Table<T> Set(Expression<Func<T, object>> keySelector, object value)
597+
{
598+
var visitor = new SetExpressionVisitor();
599+
visitor.Visit(keySelector);
600+
601+
if (visitor.Column == null || visitor.ExpectedType == null)
602+
throw new ArgumentException("Expression should return a KeyValuePair with a key of a Model Property and a value.");
603+
604+
if (!visitor.ExpectedType.IsAssignableFrom(value.GetType()))
605+
throw new ArgumentException(string.Format("Expected Value to be of Type: {0}, instead received: {1}.", visitor.ExpectedType.Name, value.GetType().Name));
606+
607+
setData.Add(visitor.Column, value);
608+
609+
return this;
610+
}
611+
586612
/// <summary>
587613
/// Specifies a KeyValuePair to be updated. Should be combined with filters/where clauses.
588614
///
@@ -596,7 +622,7 @@ public Table<T> Set(Expression<Func<T, KeyValuePair<object, object>>> keyValuePa
596622
var visitor = new SetExpressionVisitor();
597623
visitor.Visit(keyValuePairExpression);
598624

599-
if (visitor.Column == null || visitor.Value == null)
625+
if (visitor.Column == null || visitor.Value == default)
600626
throw new ArgumentException("Expression should return a KeyValuePair with a key of a Model Property and a value.");
601627

602628
setData.Add(visitor.Column, visitor.Value);

PostgrestTests/Linq.cs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
namespace PostgrestTests
1111
{
12-
[TestClass]
13-
public class LinqTests
14-
{
15-
private static string baseUrl = "http://localhost:3000";
12+
[TestClass]
13+
public class LinqTests
14+
{
15+
private static string baseUrl = "http://localhost:3000";
1616

1717
[TestMethod("Linq: Select")]
1818
public async Task TestLinqSelect()
@@ -213,6 +213,63 @@ public async Task TestLinqUpdate()
213213
Assert.IsNotNull(exists);
214214
Assert.IsTrue(count == 1);
215215

216+
var originalRecord = await client.Table<KitchenSink>().Where(x => x.Id == 1).Single();
217+
218+
var newRecord = await client.Table<KitchenSink>()
219+
.Set(x => new KeyValuePair<object, object>(x.BooleanValue, !originalRecord.BooleanValue))
220+
.Set(x => new KeyValuePair<object, object>(x.IntValue, originalRecord.IntValue + 1))
221+
.Set(x => new KeyValuePair<object, object>(x.FloatValue, originalRecord.FloatValue + 1))
222+
.Set(x => new KeyValuePair<object, object>(x.DoubleValue, originalRecord.DoubleValue + 1))
223+
.Set(x => new KeyValuePair<object, object>(x.DateTimeValue, DateTime.Now))
224+
.Set(x => new KeyValuePair<object, object>(x.ListOfStrings, new List<string>(originalRecord.ListOfStrings)
225+
{
226+
"updated"
227+
}))
228+
.Where(x => x.Id == originalRecord.Id)
229+
.Update(new QueryOptions { Returning = QueryOptions.ReturnType.Representation });
230+
231+
var testRecord1 = newRecord.Models[0];
232+
233+
Assert.AreNotEqual(originalRecord.BooleanValue, testRecord1.BooleanValue);
234+
Assert.AreNotEqual(originalRecord.IntValue, testRecord1.IntValue);
235+
Assert.AreNotEqual(originalRecord.FloatValue, testRecord1.FloatValue);
236+
Assert.AreNotEqual(originalRecord.DoubleValue, testRecord1.DoubleValue);
237+
Assert.AreNotEqual(originalRecord.DateTimeValue, testRecord1.DateTimeValue);
238+
CollectionAssert.AreNotEqual(originalRecord.ListOfStrings, testRecord1.ListOfStrings);
239+
240+
241+
var newRecord2 = await client.Table<KitchenSink>()
242+
.Set(x => x.BooleanValue, !testRecord1.BooleanValue)
243+
.Set(x => x.IntValue, testRecord1.IntValue + 1)
244+
.Set(x => x.FloatValue, testRecord1.FloatValue + 1)
245+
.Set(x => x.DoubleValue, testRecord1.DoubleValue + 1)
246+
.Set(x => x.DateTimeValue, DateTime.Now.AddSeconds(30))
247+
.Set(x => x.ListOfStrings, new List<string>(testRecord1.ListOfStrings)
248+
{
249+
"updated"
250+
})
251+
.Where(x => x.Id == testRecord1.Id)
252+
.Update(new QueryOptions { Returning = QueryOptions.ReturnType.Representation });
253+
254+
var testRecord2 = newRecord2.Models[0];
255+
256+
Assert.AreNotEqual(testRecord1.BooleanValue, testRecord2.BooleanValue);
257+
Assert.AreNotEqual(testRecord1.IntValue, testRecord2.IntValue);
258+
Assert.AreNotEqual(testRecord1.FloatValue, testRecord2.FloatValue);
259+
Assert.AreNotEqual(testRecord1.DoubleValue, testRecord2.DoubleValue);
260+
Assert.AreNotEqual(testRecord1.DateTimeValue, testRecord2.DateTimeValue);
261+
CollectionAssert.AreNotEqual(testRecord1.ListOfStrings, testRecord2.ListOfStrings);
262+
263+
Assert.ThrowsException<ArgumentException>(() =>
264+
{
265+
return client.Table<Movie>().Set(x => x.Name, DateTime.Now).Update();
266+
});
267+
268+
Assert.ThrowsException<ArgumentException>(() =>
269+
{
270+
return client.Table<Movie>().Set(x => DateTime.Now, newName).Update();
271+
});
272+
216273
Assert.ThrowsException<ArgumentException>(() =>
217274
{
218275
return client.Table<Movie>().Set(x => new KeyValuePair<object, object>(x.Name, DateTime.Now)).Update();

PostgrestTests/Models/KitchenSink.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@ namespace PostgrestTests.Models
1212
public class KitchenSink : BaseModel
1313
{
1414
[PrimaryKey("id", false)]
15-
public string? Id { get; set; }
15+
public int? Id { get; set; }
1616

1717
[Column("string_value")]
1818
public string? StringValue { get; set; }
1919

20+
[Column("bool_value")]
21+
public bool? BooleanValue { get; set; }
22+
2023
[Column("unique_value")]
2124
public string? UniqueValue { get; set; }
2225

PostgrestTests/db/00-schema.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ COMMENT ON COLUMN public.messages.data IS 'For unstructured data and prototyping
4747
create table "public"."kitchen_sink" (
4848
"id" serial primary key,
4949
"string_value" varchar(255) null,
50+
"bool_value" BOOL DEFAULT false,
5051
"unique_value" varchar(255) UNIQUE,
5152
"int_value" INT null,
5253
"float_value" FLOAT null,

0 commit comments

Comments
 (0)