
Commit dfae14c

Author: Xida Ren

remove redundant comments

1 parent cb97048

File tree: 1 file changed, +50 −44 lines


lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp

@@ -62,7 +62,7 @@ LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev,
                          LstmActivations activations) {
 
   auto intType = b.getType<IntType>();
-  ValueTensorType hTy = H_prev.getType().cast<ValueTensorType>();
+  auto hTy = H_prev.getType().cast<ValueTensorType>();
 
   Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
 
@@ -123,14 +123,14 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   Location loc = b.getLoc();
 
   auto xTy = X.getType().cast<ValueTensorType>();
-  ValueTensorType hTy = initial_h.getType().cast<ValueTensorType>();
+  auto hTy = initial_h.getType().cast<ValueTensorType>();
   // these names are snake_case for consistency with onnx.LSTM documentation
   int64_t seq_len = xTy.getSizes()[0];
   int64_t batch_size = xTy.getSizes()[1];
   int64_t input_size = xTy.getSizes()[2];
   int64_t hidden_size = hTy.getSizes()[1];
 
-  ValueTensorType cTy = hTy;
+  auto cTy = hTy;
 
   auto intType = b.getType<IntType>();
 
@@ -161,35 +161,32 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   // Create a for-like PrimLoopOp.
   Value maxTripCount =
       b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(seq_len));
-  Value cTrue = b.create<ConstantBoolOp>(true);
+  Value loopConditionTrue = b.create<ConstantBoolOp>(true);
 
   Type loopIndexType = intType;
   auto loop = b.create<PrimLoopOp>(
-      /*results=*/TypeRange({yTy, hTy, cTy}), maxTripCount,
-      /*initialCondition=*/cTrue,
-      /*iterArgsInit=*/ValueRange({Y_initial, initial_h, initial_c}));
+      TypeRange({yTy, hTy, cTy}), maxTripCount, loopConditionTrue,
+      ValueRange({Y_initial, initial_h, initial_c}));
   {
     OpBuilder::InsertionGuard guard(b);
-    Block *loopBody = b.createBlock(
-        /*parentRegion=*/&loop.getRegion(),
-        /*insertionPoint=*/loop.getRegion().begin(),
-        /*argumentTypes=*/
-        TypeRange({
-            loopIndexType,
-            yTy,
-            hTy,
-            cTy,
-        }),
-        {loc, loc, loc, loc} // locs for the loop body arguments
-    );
+    Block *loopBody =
+        b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
+                      TypeRange({
+                          loopIndexType,
+                          yTy,
+                          hTy,
+                          cTy,
+                      }),
+                      {loc, loc, loc, loc} // locs for the loop body arguments
+        );
 
     Value loopIndex = loopBody->getArgument(0);
     Value Y_prev = loopBody->getArgument(1);
     Value H_prev = loopBody->getArgument(2);
     Value C_prev = loopBody->getArgument(3);
 
-    ValueTensorType xTy = X.getType().cast<ValueTensorType>();
-    ValueTensorType XtType = b.getType<ValueTensorType>(
+    auto xTy = X.getType().cast<ValueTensorType>();
+    auto XtType = b.getType<ValueTensorType>(
         llvm::SmallVector<int64_t>{batch_size, input_size}, xTy.getDtype());
 
     Value Xt = b.create<AtenSelectIntOp>(XtType, X, cstZero, loopIndex);
@@ -207,9 +204,8 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
         b.create<AtenSliceScatterOp>(yTy, Y_prev, H_new_unsqueezed, cstZero,
                                      loopIndex, loopIndexPlusOne, cstOne);
 
-    b.create<PrimLoopConditionOp>(
-        /*shouldContinue=*/cTrue,
-        /*iterArgs=*/ValueRange({Y_new, H_new, C_new}));
+    b.create<PrimLoopConditionOp>(loopConditionTrue,
+                                  ValueRange({Y_new, H_new, C_new}));
   }
   LstmLayerOutput output;
   output.Y = loop.getResult(0);
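
A note on the loop idiom touched in the two hunks above, since the rename from cTrue to loopConditionTrue is what makes it legible: PrimLoopOp models both while-like and for-like loops, and a for-like loop is expressed by passing a constant-true condition so that only maxTripCount bounds the iteration. Below is a minimal sketch of that pattern with a single carried value; b, loc, elemTy, init, and n are assumed from context, and the names are illustrative rather than taken from the patch.

// Sketch only: a for-like torch.prim.Loop with one iter arg.
Value tripCount =
    b.create<ConstantIntOp>(b.getType<IntType>(), b.getI64IntegerAttr(n));
// For-like loops never exit early, so the condition is a constant true.
Value loopConditionTrue = b.create<ConstantBoolOp>(true);
auto loop = b.create<PrimLoopOp>(TypeRange({elemTy}), tripCount,
                                 loopConditionTrue, ValueRange({init}));
{
  OpBuilder::InsertionGuard guard(b); // restores the insertion point on exit
  // Block args: the loop index first, then one arg per carried value.
  Block *body = b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
                              TypeRange({b.getType<IntType>(), elemTy}),
                              {loc, loc});
  Value loopIndex = body->getArgument(0);
  Value carried = body->getArgument(1);
  // ... compute `next` from loopIndex and carried ...
  Value next = carried; // placeholder
  // Terminator: keep iterating (condition stays true), carry `next` forward.
  b.create<PrimLoopConditionOp>(loopConditionTrue, ValueRange({next}));
}
Value result = loop.getResult(0);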
@@ -289,8 +285,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
     return rewriter.notifyMatchFailure(
         binder.op, "Missing required attribute hidden_size");
 
-  ValueTensorType xTy = X.getType().cast<ValueTensorType>();
-  ValueTensorType wTy = W.getType().cast<ValueTensorType>();
+  auto xTy = X.getType().cast<ValueTensorType>();
+  auto wTy = W.getType().cast<ValueTensorType>();
   Value B;
   if (binder.tensorOperandAtIndex(B, 3)) {
     B = b.create<AtenZerosOp>(W.getType(), W);
@@ -332,14 +328,24 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
   int64_t batch_size = XShape[1];
   int64_t input_size = XShape[2];
   if (num_directions != wTy.getSizes()[0])
-    return rewriter.notifyMatchFailure("num_directions (" + std::to_string(num_directions) + ") does not match the first dimension of wTy (" + std::to_string(wTy.getSizes()[0]) + ")");
+    return rewriter.notifyMatchFailure(
+        binder.op, "num_directions (" + std::to_string(num_directions) +
+                       ") does not match the first dimension of wTy (" +
+                       std::to_string(wTy.getSizes()[0]) + ")");
   if (num_directions != 1)
-    return rewriter.notifyMatchFailure("num_directions (" + std::to_string(num_directions) + ") is not equal to 1");
+    return rewriter.notifyMatchFailure(
+        binder.op, "num_directions (" + std::to_string(num_directions) +
+                       ") is not equal to 1");
   if (4 * hidden_size != wTy.getSizes()[1])
-    return rewriter.notifyMatchFailure("4 times hidden_size (" + std::to_string(4 * hidden_size) + ") does not match the second dimension of wTy (" + std::to_string(wTy.getSizes()[1]) + ")");
+    return rewriter.notifyMatchFailure(
+        binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
+                       ") does not match the second dimension of wTy (" +
+                       std::to_string(wTy.getSizes()[1]) + ")");
   if (wTy.getSizes()[2] != input_size)
-    return rewriter.notifyMatchFailure("The third dimension of wTy (" + std::to_string(wTy.getSizes()[2]) + ") does not match input_size (" + std::to_string(input_size) + ")");
-
+    return rewriter.notifyMatchFailure(
+        binder.op,
+        "The third dimension of wTy (" + std::to_string(wTy.getSizes()[2]) +
+            ") does not match input_size (" + std::to_string(input_size) + ")");
 
   /**
    * @brief Splits the input tensor based on the provided direction.
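
Beyond reflowing the message strings, the hunk above also passes binder.op as the first argument to each call: RewriterBase::notifyMatchFailure expects the op (or a location) whose match failed together with the message, and it returns the failure as a LogicalResult so the call can be returned directly. A minimal sketch of the idiom; the check itself is hypothetical, not from this patch.

// Sketch of the idiom used throughout this expander: validate a
// precondition and, on failure, attach a descriptive reason to the op.
if (hidden_size <= 0) // hypothetical precondition, for illustration only
  return rewriter.notifyMatchFailure(
      binder.op, "expected positive hidden_size, got " +
                     std::to_string(hidden_size));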
@@ -353,15 +359,15 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
    * @return The split tensor for the specified direction.
    */
   auto getDirection = [&](int64_t direction, Value input) {
-    ValueTensorType inputType = input.getType().cast<ValueTensorType>();
+    auto inputType = input.getType().cast<ValueTensorType>();
 
     // drop 0th dimension
-    ValueTensorType outputType =
-        inputType
-            .getWithSizesAndDtype(
-                llvm::SmallVector<int64_t>{inputType.getSizes().drop_front()},
-                inputType.getDtype())
-            .cast<ValueTensorType>();
+    auto outputType =
+        inputType
+            .getWithSizesAndDtype(
+                llvm::SmallVector<int64_t>{inputType.getSizes().drop_front()},
+                inputType.getDtype())
+            .cast<ValueTensorType>();
 
     auto intType = b.getType<IntType>();
     Value selectDim = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
@@ -374,7 +380,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
   Value R_forward = getDirection(0, R);
   Value B_forward = getDirection(0, B);
 
-  ValueTensorType hTy = b.getType<ValueTensorType>(
+  auto hTy = b.getType<ValueTensorType>(
       llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
       xTy.getDtype());
 
@@ -430,7 +436,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
   Value inputWeightsEndIdx = intConst(4 * hidden_size);
   Value recurrentWeightsStartIdx = inputWeightsEndIdx;
   Value recurrentWeightsEndIdx = intConst(8 * hidden_size);
-  ValueTensorType biasType = b.getType<ValueTensorType>(
+  auto biasType = b.getType<ValueTensorType>(
       llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype());
   Value Wb = b.create<AtenSliceTensorOp>(biasType,
                                          /*input=*/B_forward,
@@ -446,13 +452,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
                                          /*step=*/cstOne);
 
   // gate splitting
-  ValueTensorType gateBiasType = b.getType<ValueTensorType>(
+  auto gateBiasType = b.getType<ValueTensorType>(
       llvm::SmallVector<int64_t>{hidden_size},
       Wb.getType().cast<ValueTensorType>().getDtype());
-  ValueTensorType gateWeightsTypeIH = b.getType<ValueTensorType>(
+  auto gateWeightsTypeIH = b.getType<ValueTensorType>(
       llvm::SmallVector<int64_t>{hidden_size, input_size},
       W_forward.getType().cast<ValueTensorType>().getDtype());
-  ValueTensorType gateWeightsTypeHH = b.getType<ValueTensorType>(
+  auto gateWeightsTypeHH = b.getType<ValueTensorType>(
       llvm::SmallVector<int64_t>{hidden_size, hidden_size},
       R_forward.getType().cast<ValueTensorType>().getDtype());
 
@@ -500,7 +506,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
   LstmLayerOutput lstmLayerOutput = lstm_layer(
       b, X, initial_h_forward, initial_c_forward, weights, activations);
 
-  ValueTensorType Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>(
+  auto Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>(
       llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
       lstmLayerOutput.Y_h.getType().cast<ValueTensorType>().getDtype());
   Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(
