Skip to content

Commit ffe332e

Browse files
author
Xida Ren
committed
Use non-member `llvm::cast<T>(...)` instead of the deprecated member `.cast<T>()` on `mlir::Type`/`mlir::Value` types
1 parent dfae14c commit ffe332e

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev,
6262
LstmActivations activations) {
6363

6464
auto intType = b.getType<IntType>();
65-
auto hTy = H_prev.getType().cast<ValueTensorType>();
65+
auto hTy = cast<ValueTensorType>(H_prev.getType());
6666

6767
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
6868

@@ -122,8 +122,8 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
122122

123123
Location loc = b.getLoc();
124124

125-
auto xTy = X.getType().cast<ValueTensorType>();
126-
auto hTy = initial_h.getType().cast<ValueTensorType>();
125+
auto xTy = cast<ValueTensorType>(X.getType());
126+
auto hTy = cast<ValueTensorType>(initial_h.getType());
127127
// these names are snake_case for consistency with onnx.LSTM documentation
128128
int64_t seq_len = xTy.getSizes()[0];
129129
int64_t batch_size = xTy.getSizes()[1];
@@ -185,7 +185,7 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
185185
Value H_prev = loopBody->getArgument(2);
186186
Value C_prev = loopBody->getArgument(3);
187187

188-
auto xTy = X.getType().cast<ValueTensorType>();
188+
auto xTy = cast<ValueTensorType>(X.getType());
189189
auto XtType = b.getType<ValueTensorType>(
190190
llvm::SmallVector<int64_t>{batch_size, input_size}, xTy.getDtype());
191191

@@ -285,8 +285,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
285285
return rewriter.notifyMatchFailure(
286286
binder.op, "Missing required attribute hidden_size");
287287

288-
auto xTy = X.getType().cast<ValueTensorType>();
289-
auto wTy = W.getType().cast<ValueTensorType>();
288+
auto xTy = cast<ValueTensorType>(X.getType());
289+
auto wTy = cast<ValueTensorType>(W.getType());
290290
Value B;
291291
if (binder.tensorOperandAtIndex(B, 3)) {
292292
B = b.create<AtenZerosOp>(W.getType(), W);
@@ -322,8 +322,6 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
322322
direction + "' is provided.");
323323
int64_t num_directions = 1 + (direction == "bidirectional");
324324

325-
326-
327325
auto XShape = xTy.getSizes();
328326
int64_t batch_size = XShape[1];
329327
int64_t input_size = XShape[2];
@@ -349,31 +347,30 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
349347

350348
/**
351349
* @brief Splits the input tensor based on the provided direction.
352-
*
353-
* This function is used to split the LSTM parameters (W, R, B) into forward and backward directions.
354-
* The input tensor is expected to have the forward and backward parameters concatenated along the 0th dimension.
355-
* The function returns a tensor that contains the parameters for the specified direction.
350+
*
351+
* This function is used to split the LSTM parameters (W, R, B) into forward
352+
* and backward directions. The input tensor is expected to have the forward
353+
* and backward parameters concatenated along the 0th dimension. The function
354+
* returns a tensor that contains the parameters for the specified direction.
356355
*
357356
* @param direction The direction to split out. 0 for forward, 1 for backward.
358357
* @param input The input tensor to split.
359358
* @return The split tensor for the specified direction.
360359
*/
361360
auto getDirection = [&](int64_t direction, Value input) {
362-
auto inputType = input.getType().cast<ValueTensorType>();
361+
auto inputType = cast<ValueTensorType>(input.getType());
363362

364363
// drop 0th dimension
365-
auto outputType =
366-
inputType
367-
.getWithSizesAndDtype(
368-
llvm::SmallVector<int64_t>{inputType.getSizes().drop_front()},
369-
inputType.getDtype())
370-
.cast<ValueTensorType>();
364+
auto outputType = cast<ValueTensorType>(inputType.getWithSizesAndDtype(
365+
llvm::SmallVector<int64_t>{inputType.getSizes().drop_front()},
366+
inputType.getDtype()));
371367

372368
auto intType = b.getType<IntType>();
373369
Value selectDim = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
374370
Value cstDirection =
375-
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(direction));
376-
return b.create<AtenSelectIntOp>(outputType, input, selectDim, cstDirection);
371+
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(direction));
372+
return b.create<AtenSelectIntOp>(outputType, input, selectDim,
373+
cstDirection);
377374
};
378375

379376
Value W_forward = getDirection(0, W);
@@ -454,13 +451,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
454451
// gate splitting
455452
auto gateBiasType = b.getType<ValueTensorType>(
456453
llvm::SmallVector<int64_t>{hidden_size},
457-
Wb.getType().cast<ValueTensorType>().getDtype());
454+
cast<ValueTensorType>(Wb.getType()).getDtype());
458455
auto gateWeightsTypeIH = b.getType<ValueTensorType>(
459456
llvm::SmallVector<int64_t>{hidden_size, input_size},
460-
W_forward.getType().cast<ValueTensorType>().getDtype());
457+
cast<ValueTensorType>(W_forward.getType()).getDtype());
461458
auto gateWeightsTypeHH = b.getType<ValueTensorType>(
462459
llvm::SmallVector<int64_t>{hidden_size, hidden_size},
463-
R_forward.getType().cast<ValueTensorType>().getDtype());
460+
cast<ValueTensorType>(R_forward.getType()).getDtype());
464461

465462
Value inputGateWeightsEndIdx = intConst(hidden_size);
466463
Value outputGateWeightsEndIdx = intConst(2 * hidden_size);
@@ -508,7 +505,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
508505

509506
auto Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>(
510507
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
511-
lstmLayerOutput.Y_h.getType().cast<ValueTensorType>().getDtype());
508+
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype());
512509
Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(
513510
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero);
514511
Value Y_c_unsqueezed = b.create<AtenUnsqueezeOp>(

0 commit comments

Comments
 (0)