@@ -62,7 +62,7 @@ LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev,
6262                        LstmActivations activations) {
6363
6464  auto  intType = b.getType <IntType>();
65-   auto  hTy = H_prev. getType (). cast <ValueTensorType>();
65+   auto  hTy = cast<ValueTensorType>(H_prev. getType () );
6666
6767  Value cstOne = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (1 ));
6868
@@ -122,8 +122,8 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
122122
123123  Location loc = b.getLoc ();
124124
125-   auto  xTy = X. getType (). cast <ValueTensorType>();
126-   auto  hTy = initial_h. getType (). cast <ValueTensorType>();
125+   auto  xTy = cast<ValueTensorType>(X. getType () );
126+   auto  hTy = cast<ValueTensorType>(initial_h. getType () );
127127  //  these names are snake_case for consistency with onnx.LSTM documentation
128128  int64_t  seq_len = xTy.getSizes ()[0 ];
129129  int64_t  batch_size = xTy.getSizes ()[1 ];
@@ -185,7 +185,7 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
185185    Value H_prev = loopBody->getArgument (2 );
186186    Value C_prev = loopBody->getArgument (3 );
187187
188-     auto  xTy = X. getType (). cast <ValueTensorType>();
188+     auto  xTy = cast<ValueTensorType>(X. getType () );
189189    auto  XtType = b.getType <ValueTensorType>(
190190        llvm::SmallVector<int64_t >{batch_size, input_size}, xTy.getDtype ());
191191
@@ -285,8 +285,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
285285    return  rewriter.notifyMatchFailure (
286286        binder.op , " Missing required attribute hidden_size"  );
287287
288-   auto  xTy = X. getType (). cast <ValueTensorType>();
289-   auto  wTy = W. getType (). cast <ValueTensorType>();
288+   auto  xTy = cast<ValueTensorType>(X. getType () );
289+   auto  wTy = cast<ValueTensorType>(W. getType () );
290290  Value B;
291291  if  (binder.tensorOperandAtIndex (B, 3 )) {
292292    B = b.create <AtenZerosOp>(W.getType (), W);
@@ -322,8 +322,6 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
322322                                           direction + " ' is provided."  );
323323  int64_t  num_directions = 1  + (direction == " bidirectional"  );
324324
325- 
326- 
327325  auto  XShape = xTy.getSizes ();
328326  int64_t  batch_size = XShape[1 ];
329327  int64_t  input_size = XShape[2 ];
@@ -349,31 +347,30 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
349347
350348  /* *
351349   * @brief Splits the input tensor based on the provided direction. 
352-    *  
353-    * This function is used to split the LSTM parameters (W, R, B) into forward and backward directions. 
354-    * The input tensor is expected to have the forward and backward parameters concatenated along the 0th dimension. 
355-    * The function returns a tensor that contains the parameters for the specified direction. 
350+    * 
351+    * This function is used to split the LSTM parameters (W, R, B) into forward 
352+    * and backward directions. The input tensor is expected to have the forward 
353+    * and backward parameters concatenated along the 0th dimension. The function 
354+    * returns a tensor that contains the parameters for the specified direction. 
356355   * 
357356   * @param direction The direction to split out. 0 for forward, 1 for backward. 
358357   * @param input The input tensor to split. 
359358   * @return The split tensor for the specified direction. 
360359   */  
361360  auto  getDirection = [&](int64_t  direction, Value input) {
362-     auto  inputType = input. getType (). cast <ValueTensorType>();
361+     auto  inputType = cast<ValueTensorType>(input. getType () );
363362
364363    //  drop 0th dimension
365-     auto  outputType =
366-         inputType
367-             .getWithSizesAndDtype (
368-                 llvm::SmallVector<int64_t >{inputType.getSizes ().drop_front ()},
369-                 inputType.getDtype ())
370-             .cast <ValueTensorType>();
364+     auto  outputType = cast<ValueTensorType>(inputType.getWithSizesAndDtype (
365+         llvm::SmallVector<int64_t >{inputType.getSizes ().drop_front ()},
366+         inputType.getDtype ()));
371367
372368    auto  intType = b.getType <IntType>();
373369    Value selectDim = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (0 ));
374370    Value cstDirection =
375-       b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (direction));
376-     return  b.create <AtenSelectIntOp>(outputType, input, selectDim, cstDirection);
371+         b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (direction));
372+     return  b.create <AtenSelectIntOp>(outputType, input, selectDim,
373+                                      cstDirection);
377374  };
378375
379376  Value W_forward = getDirection (0 , W);
@@ -454,13 +451,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
454451  //  gate splitting
455452  auto  gateBiasType = b.getType <ValueTensorType>(
456453      llvm::SmallVector<int64_t >{hidden_size},
457-       Wb. getType (). cast <ValueTensorType>().getDtype ());
454+       cast<ValueTensorType>(Wb. getType () ).getDtype ());
458455  auto  gateWeightsTypeIH = b.getType <ValueTensorType>(
459456      llvm::SmallVector<int64_t >{hidden_size, input_size},
460-       W_forward. getType (). cast <ValueTensorType>().getDtype ());
457+       cast<ValueTensorType>(W_forward. getType () ).getDtype ());
461458  auto  gateWeightsTypeHH = b.getType <ValueTensorType>(
462459      llvm::SmallVector<int64_t >{hidden_size, hidden_size},
463-       R_forward. getType (). cast <ValueTensorType>().getDtype ());
460+       cast<ValueTensorType>(R_forward. getType () ).getDtype ());
464461
465462  Value inputGateWeightsEndIdx = intConst (hidden_size);
466463  Value outputGateWeightsEndIdx = intConst (2  * hidden_size);
@@ -508,7 +505,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
508505
509506  auto  Y_h_Y_c_unsqueezed_type = b.getType <ValueTensorType>(
510507      llvm::SmallVector<int64_t >{num_directions, batch_size, hidden_size},
511-       lstmLayerOutput.Y_h .getType (). cast <ValueTensorType>( ).getDtype ());
508+       cast<ValueTensorType>( lstmLayerOutput.Y_h .getType ()).getDtype ());
512509  Value Y_h_unsqueezed = b.create <AtenUnsqueezeOp>(
513510      Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h , cstZero);
514511  Value Y_c_unsqueezed = b.create <AtenUnsqueezeOp>(
0 commit comments