@@ -17,7 +17,7 @@ import {
1717import  {  getLiteral ,  GUARD_FIELD_NAME ,  PluginError  }  from  '@zenstackhq/sdk' ; 
1818import  {  CodeBlockWriter  }  from  'ts-morph' ; 
1919import  {  FILTER_OPERATOR_FUNCTIONS  }  from  '../../language-server/constants' ; 
20- import  {  getIdField ,  isAuthInvocation  }  from  '../../utils/ast-utils' ; 
20+ import  {  getIdFields ,  isAuthInvocation  }  from  '../../utils/ast-utils' ; 
2121import  TypeScriptExpressionTransformer  from  './typescript-expression-transformer' ; 
2222import  {  isFutureExpr  }  from  './utils' ; 
2323
@@ -99,12 +99,17 @@ export class ExpressionWriter {
9999
100100    private  writeMemberAccess ( expr : MemberAccessExpr )  { 
101101        this . block ( ( )  =>  { 
102-             // must be a boolean member 
103-             this . writeFieldCondition ( expr . operand ,  ( )  =>  { 
104-                 this . block ( ( )  =>  { 
105-                     this . writer . write ( `${ expr . member . ref ?. name }  ) ; 
102+             if  ( this . isAuthOrAuthMemberAccess ( expr ) )  { 
103+                 // member access of `auth()`, generate plain expression 
104+                 this . guard ( ( )  =>  this . plain ( expr ) ,  true ) ; 
105+             }  else  { 
106+                 // must be a boolean member 
107+                 this . writeFieldCondition ( expr . operand ,  ( )  =>  { 
108+                     this . block ( ( )  =>  { 
109+                         this . writer . write ( `${ expr . member . ref ?. name }  ) ; 
110+                     } ) ; 
106111                } ) ; 
107-             } ) ; 
112+             } 
108113        } ) ; 
109114    } 
110115
@@ -190,9 +195,14 @@ export class ExpressionWriter {
190195        return  false ; 
191196    } 
192197
193-     private  guard ( write : ( )  =>  void )  { 
198+     private  guard ( write : ( )  =>  void ,   cast   =   false )  { 
194199        this . writer . write ( `${ GUARD_FIELD_NAME }  ) ; 
195-         write ( ) ; 
200+         if  ( cast )  { 
201+             this . writer . write ( '!!' ) ; 
202+             write ( ) ; 
203+         }  else  { 
204+             write ( ) ; 
205+         } 
196206    } 
197207
198208    private  plain ( expr : Expression )  { 
@@ -211,12 +221,9 @@ export class ExpressionWriter {
211221            // compile down to a plain expression 
212222            this . block ( ( )  =>  { 
213223                this . guard ( ( )  =>  { 
214-                     this . plain ( expr . left ) ; 
215-                     this . writer . write ( ' '  +  operator  +  ' ' ) ; 
216-                     this . plain ( expr . right ) ; 
224+                     this . plain ( expr ) ; 
217225                } ) ; 
218226            } ) ; 
219- 
220227            return ; 
221228        } 
222229
@@ -242,65 +249,105 @@ export class ExpressionWriter {
242249            }  as  ReferenceExpr ; 
243250        } 
244251
245-         // if the operand refers to auth(), need to build a guard to avoid 
246-         // using undefined user as filter (which means no filter to Prisma) 
247-         // if auth() evaluates falsy, just treat the condition as false 
248-         if  ( this . isAuthOrAuthMemberAccess ( operand ) )  { 
249-             this . writer . write ( `!user ? { ${ GUARD_FIELD_NAME }  ) ; 
252+         // guard member access of `auth()` with null check 
253+         if  ( this . isAuthOrAuthMemberAccess ( operand )  &&  ! fieldAccess . $resolvedType ?. nullable )  { 
254+             this . writer . write ( 
255+                 `(${ this . plainExprBuilder . transform ( operand ) } ${ GUARD_FIELD_NAME } ${  
256+                     // auth().x != user.x is true when auth().x is null and user is not nullable  
257+                     // other expressions are evaluated to false when null is involved  
258+                     operator  ===  '!='  ? 'true'  : 'false'  
259+                 }   } : `
260+             ) ; 
250261        } 
251262
252-         this . block ( ( )   =>   { 
253-             this . writeFieldCondition ( fieldAccess ,   ( )  =>  { 
254-                 this . block ( 
255-                     ( )  =>  { 
263+         this . block ( 
264+             ( )  =>  { 
265+                 this . writeFieldCondition ( fieldAccess ,   ( )   =>   { 
266+                     this . block ( ( )  =>  { 
256267                        const  dataModel  =  this . isModelTyped ( fieldAccess ) ; 
257-                         if  ( dataModel )  { 
258-                             const  idField  =  getIdField ( dataModel ) ; 
259-                             if  ( ! idField )  { 
268+                         if  ( dataModel  &&  isAuthInvocation ( operand ) )  { 
269+                             // right now this branch only serves comparison with `auth`, like 
270+                             //     @@allow ('all', owner == auth()) 
271+ 
272+                             const  idFields  =  getIdFields ( dataModel ) ; 
273+                             if  ( ! idFields  ||  idFields . length  ===  0 )  { 
260274                                throw  new  PluginError ( `Data model ${ dataModel . name }  ) ; 
261275                            } 
262-                             // comparing with an object, convert to "id" comparison instead 
263-                             this . writer . write ( `${ idField . name }  ) ; 
276+ 
277+                             if  ( operator  !==  '=='  &&  operator  !==  '!=' )  { 
278+                                 throw  new  PluginError ( 'Only == and != operators are allowed' ) ; 
279+                             } 
280+ 
281+                             if  ( ! isThisExpr ( fieldAccess ) )  { 
282+                                 this . writer . writeLine ( operator  ===  '=='  ? 'is:'  : 'isNot:' ) ; 
283+                                 const  fieldIsNullable  =  ! ! fieldAccess . $resolvedType ?. nullable ; 
284+                                 if  ( fieldIsNullable )  { 
285+                                     // if field is nullable, we can generate "null" check condition 
286+                                     this . writer . write ( `(user == null) ? null : ` ) ; 
287+                                 } 
288+                             } 
289+ 
264290                            this . block ( ( )  =>  { 
265-                                 this . writeOperator ( operator ,  ( )  =>  { 
266-                                     // operand ? operand.field : null 
267-                                     this . writer . write ( '(' ) ; 
268-                                     this . plain ( operand ) ; 
269-                                     this . writer . write ( ' ? ' ) ; 
270-                                     this . plain ( operand ) ; 
271-                                     this . writer . write ( `.${ idField . name }  ) ; 
272-                                     this . writer . write ( ' : null' ) ; 
273-                                     this . writer . write ( ')' ) ; 
291+                                 idFields . forEach ( ( idField ,  idx )  =>  { 
292+                                     const  writeIdsCheck  =  ( )  =>  { 
293+                                         // id: user.id 
294+                                         this . writer . write ( `${ idField . name }  ) ; 
295+                                         this . plain ( operand ) ; 
296+                                         this . writer . write ( `.${ idField . name }  ) ; 
297+                                         if  ( idx  !==  idFields . length  -  1 )  { 
298+                                             this . writer . write ( ',' ) ; 
299+                                         } 
300+                                     } ; 
301+ 
302+                                     if  ( isThisExpr ( fieldAccess )  &&  operator  ===  '!=' )  { 
303+                                         // wrap a not 
304+                                         this . writer . writeLine ( 'NOT:' ) ; 
305+                                         this . block ( ( )  =>  writeIdsCheck ( ) ) ; 
306+                                     }  else  { 
307+                                         writeIdsCheck ( ) ; 
308+                                     } 
274309                                } ) ; 
275310                            } ) ; 
276311                        }  else  { 
277-                             this . writeOperator ( operator ,  ( )  =>  { 
312+                             this . writeOperator ( operator ,  fieldAccess ,   ( )  =>  { 
278313                                this . plain ( operand ) ; 
279314                            } ) ; 
280315                        } 
281-                     } , 
282-                      // "this" expression is compiled away (to .id access), so we should 
283-                      // avoid generating a new layer 
284-                      ! isThisExpr ( fieldAccess ) 
285-                  ) ; 
286-             } ) ; 
287-         } ) ; 
316+                     } ,   ! isThisExpr ( fieldAccess ) ) ; 
317+                 } ) ; 
318+             } , 
319+             // "this" expression is compiled away (to .id access), so we should 
320+             // avoid generating a new layer 
321+             ! isThisExpr ( fieldAccess ) 
322+         ) ; 
288323    } 
289324
290325    private  isAuthOrAuthMemberAccess ( expr : Expression )  { 
291326        return  isAuthInvocation ( expr )  ||  ( isMemberAccessExpr ( expr )  &&  isAuthInvocation ( expr . operand ) ) ; 
292327    } 
293328
294-     private  writeOperator ( operator : ComparisonOperator ,  writeOperand : ( )  =>  void )  { 
295-         if  ( operator   ===   '!=' )  { 
296-             // wrap a 'not' 
297-             this . writer . write ( 'not : ' ) ; 
298-             this . block ( ( )   =>  { 
299-                 this . writeOperator ( '==' ,   writeOperand ) ; 
300-             } ) ; 
301-         }   else   { 
302-             this . writer . write ( ` ${ this . mapOperator ( operator ) } : ` ) ; 
329+     private  writeOperator ( operator : ComparisonOperator ,  fieldAccess :  Expression ,   writeOperand : ( )  =>  void )  { 
330+         if  ( isDataModel ( fieldAccess . $resolvedType ?. decl ) )  { 
331+             if   ( operator   ===   '==' )   { 
332+                  this . writer . write ( 'is : ' ) ; 
333+             }   else   if   ( operator   ===   '!=' )  { 
334+                 this . writer . write ( 'isNot: ' ) ; 
335+             }   else   { 
336+                  throw   new   PluginError ( 'Only == and != operators are allowed for data model comparison' ) ; 
337+             } 
303338            writeOperand ( ) ; 
339+         }  else  { 
340+             if  ( operator  ===  '!=' )  { 
341+                 // wrap a 'not' 
342+                 this . writer . write ( 'not: ' ) ; 
343+                 this . block ( ( )  =>  { 
344+                     this . writer . write ( `${ this . mapOperator ( '==' ) }  ) ; 
345+                     writeOperand ( ) ; 
346+                 } ) ; 
347+             }  else  { 
348+                 this . writer . write ( `${ this . mapOperator ( operator ) }  ) ; 
349+                 writeOperand ( ) ; 
350+             } 
304351        } 
305352    } 
306353
@@ -414,10 +461,37 @@ export class ExpressionWriter {
414461    } 
415462
416463    private  writeLogical ( expr : BinaryExpr ,  operator : '&&'  |  '||' )  { 
417-         this . block ( ( )  =>  { 
418-             this . writer . write ( `${ operator  ===  '&&'  ? 'AND'  : 'OR' }  ) ; 
419-             this . writeExprList ( [ expr . left ,  expr . right ] ) ; 
420-         } ) ; 
464+         // TODO: do we need short-circuit for logical operators? 
465+ 
466+         if  ( operator  ===  '&&' )  { 
467+             // // && short-circuit: left && right -> left ? right : { zenstack_guard: false } 
468+             // if (!this.hasFieldAccess(expr.left)) { 
469+             //     this.plain(expr.left); 
470+             //     this.writer.write(' ? '); 
471+             //     this.write(expr.right); 
472+             //     this.writer.write(' : '); 
473+             //     this.block(() => this.guard(() => this.writer.write('false'))); 
474+             // } else { 
475+             this . block ( ( )  =>  { 
476+                 this . writer . write ( 'AND:' ) ; 
477+                 this . writeExprList ( [ expr . left ,  expr . right ] ) ; 
478+             } ) ; 
479+             // } 
480+         }  else  { 
481+             // // || short-circuit: left || right -> left ? { zenstack_guard: true } : right 
482+             // if (!this.hasFieldAccess(expr.left)) { 
483+             //     this.plain(expr.left); 
484+             //     this.writer.write(' ? '); 
485+             //     this.block(() => this.guard(() => this.writer.write('true'))); 
486+             //     this.writer.write(' : '); 
487+             //     this.write(expr.right); 
488+             // } else { 
489+             this . block ( ( )  =>  { 
490+                 this . writer . write ( 'OR:' ) ; 
491+                 this . writeExprList ( [ expr . left ,  expr . right ] ) ; 
492+             } ) ; 
493+             // } 
494+         } 
421495    } 
422496
423497    private  writeUnary ( expr : UnaryExpr )  { 
0 commit comments