@@ -84,13 +84,14 @@ def lqr_direct(rho):
8484 return anp .concatenate ([K .flatten (), P .flatten (), C1 .flatten (), C2 .flatten ()])
8585
8686 derivs = jacobian (lqr_direct )(rho_value )
87- # split into four blocks and reshape
88- sizes = [4 * 12 , 12 * 12 , 4 * 4 , 12 * 12 ]
89- parts = np .split (np .array (derivs ), np .cumsum (sizes )[:- 1 ])
90- dK = parts [0 ].reshape (4 , 12 )
91- dP = parts [1 ].reshape (12 , 12 )
92- dC1 = parts [2 ].reshape (4 , 4 )
93- dC2 = parts [3 ].reshape (12 , 12 )
87+ # Dynamically split the derivative vector based on matrix sizes
88+ deriv_array = np .array (derivs )
89+ nu , nx = Bdyn .shape [1 ], Adyn .shape [0 ]
90+ idx = 0
91+ dK = deriv_array [idx :idx + nu * nx ].reshape (nu , nx ); idx += nu * nx
92+ dP = deriv_array [idx :idx + nx * nx ].reshape (nx , nx ); idx += nx * nx
93+ dC1 = deriv_array [idx :idx + nu * nu ].reshape (nu , nu ); idx += nu * nu
94+ dC2 = deriv_array [idx :idx + nx * nx ].reshape (nx , nx )
9495
9596 # Generate code with sensitivity matrices
9697 prob .codegen_with_sensitivity ("out" , dK , dP , dC1 , dC2 , verbose = 1 )
0 commit comments