@@ -344,84 +344,67 @@ def connection_pattern(self, node):
 
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
+            core_inputs = [
                 tensor(
                     dtype=inp.type.dtype,
-                    shape=inp.type.shape[inp.type.ndim - len(sig) :],
+                    shape=inp.type.shape[batch_ndim:],
                 )
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
-            core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
+            core_node = self._create_dummy_core_node(core_inputs)
             core_outputs = core_node.outputs
 
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
                 )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
+            ]
 
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize gradients to batch inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
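
Illustrative sketch (not part of the commit): the final loop of the new L_op sums gradients over batch dimensions along which an input was broadcast, keeping dims so the gradient matches the input's shape. The NumPy snippet below mirrors that `to_sum` logic in isolation; the helper name `sum_broadcasted_batch_dims` and the concrete shapes are hypothetical, chosen only to demonstrate the rule.

import numpy as np


def sum_broadcasted_batch_dims(grad, inp_shape, batch_out_shape):
    # Mirror of the `to_sum` logic in Blockwise.L_op: collect batch axes where
    # the input had size 1 but the output batch shape did not (i.e. the input
    # was broadcast), then sum the gradient over those axes with keepdims=True.
    to_sum = [
        j
        for j, (inp_s, out_s) in enumerate(zip(inp_shape, batch_out_shape))
        if inp_s == 1 and out_s != 1
    ]
    if to_sum:
        return grad.sum(axis=tuple(to_sum), keepdims=True)
    return grad


# Hypothetical shapes: input batched as (1, 5) + core (3,), output batch shape (4, 5).
grad = np.ones((4, 5, 3))  # gradient w.r.t. the broadcasted input
reduced = sum_broadcasted_batch_dims(grad, inp_shape=(1, 5, 3), batch_out_shape=(4, 5))
print(reduced.shape)  # (1, 5, 3), matching the unbroadcasted input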