@@ -62,7 +62,7 @@ LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev,
6262 LstmActivations activations) {
6363
6464 auto intType = b.getType <IntType>();
65- ValueTensorType hTy = H_prev.getType ().cast <ValueTensorType>();
65+ auto hTy = H_prev.getType ().cast <ValueTensorType>();
6666
6767 Value cstOne = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (1 ));
6868
@@ -123,14 +123,14 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
123123 Location loc = b.getLoc ();
124124
125125 auto xTy = X.getType ().cast <ValueTensorType>();
126- ValueTensorType hTy = initial_h.getType ().cast <ValueTensorType>();
126+ auto hTy = initial_h.getType ().cast <ValueTensorType>();
127127 // these names are snake_case for consistency with onnx.LSTM documentation
128128 int64_t seq_len = xTy.getSizes ()[0 ];
129129 int64_t batch_size = xTy.getSizes ()[1 ];
130130 int64_t input_size = xTy.getSizes ()[2 ];
131131 int64_t hidden_size = hTy.getSizes ()[1 ];
132132
133- ValueTensorType cTy = hTy;
133+ auto cTy = hTy;
134134
135135 auto intType = b.getType <IntType>();
136136
@@ -161,35 +161,32 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
161161 // Create a for-like PrimLoopOp.
162162 Value maxTripCount =
163163 b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (seq_len));
164- Value cTrue = b.create <ConstantBoolOp>(true );
164+ Value loopConditionTrue = b.create <ConstantBoolOp>(true );
165165
166166 Type loopIndexType = intType;
167167 auto loop = b.create <PrimLoopOp>(
168- /* results=*/ TypeRange ({yTy, hTy, cTy}), maxTripCount,
169- /* initialCondition=*/ cTrue,
170- /* iterArgsInit=*/ ValueRange ({Y_initial, initial_h, initial_c}));
168+ TypeRange ({yTy, hTy, cTy}), maxTripCount, loopConditionTrue,
169+ ValueRange ({Y_initial, initial_h, initial_c}));
171170 {
172171 OpBuilder::InsertionGuard guard (b);
173- Block *loopBody = b.createBlock (
174- /* parentRegion=*/ &loop.getRegion (),
175- /* insertionPoint=*/ loop.getRegion ().begin (),
176- /* argumentTypes=*/
177- TypeRange ({
178- loopIndexType,
179- yTy,
180- hTy,
181- cTy,
182- }),
183- {loc, loc, loc, loc} // locs for the loop body arguments
184- );
172+ Block *loopBody =
173+ b.createBlock (&loop.getRegion (), loop.getRegion ().begin (),
174+ TypeRange ({
175+ loopIndexType,
176+ yTy,
177+ hTy,
178+ cTy,
179+ }),
180+ {loc, loc, loc, loc} // locs for the loop body arguments
181+ );
185182
186183 Value loopIndex = loopBody->getArgument (0 );
187184 Value Y_prev = loopBody->getArgument (1 );
188185 Value H_prev = loopBody->getArgument (2 );
189186 Value C_prev = loopBody->getArgument (3 );
190187
191- ValueTensorType xTy = X.getType ().cast <ValueTensorType>();
192- ValueTensorType XtType = b.getType <ValueTensorType>(
188+ auto xTy = X.getType ().cast <ValueTensorType>();
189+ auto XtType = b.getType <ValueTensorType>(
193190 llvm::SmallVector<int64_t >{batch_size, input_size}, xTy.getDtype ());
194191
195192 Value Xt = b.create <AtenSelectIntOp>(XtType, X, cstZero, loopIndex);
@@ -207,9 +204,8 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
207204 b.create <AtenSliceScatterOp>(yTy, Y_prev, H_new_unsqueezed, cstZero,
208205 loopIndex, loopIndexPlusOne, cstOne);
209206
210- b.create <PrimLoopConditionOp>(
211- /* shouldContinue=*/ cTrue,
212- /* iterArgs=*/ ValueRange ({Y_new, H_new, C_new}));
207+ b.create <PrimLoopConditionOp>(loopConditionTrue,
208+ ValueRange ({Y_new, H_new, C_new}));
213209 }
214210 LstmLayerOutput output;
215211 output.Y = loop.getResult (0 );
@@ -289,8 +285,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
289285 return rewriter.notifyMatchFailure (
290286 binder.op , " Missing required attribute hidden_size" );
291287
292- ValueTensorType xTy = X.getType ().cast <ValueTensorType>();
293- ValueTensorType wTy = W.getType ().cast <ValueTensorType>();
288+ auto xTy = X.getType ().cast <ValueTensorType>();
289+ auto wTy = W.getType ().cast <ValueTensorType>();
294290 Value B;
295291 if (binder.tensorOperandAtIndex (B, 3 )) {
296292 B = b.create <AtenZerosOp>(W.getType (), W);
@@ -332,14 +328,24 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
332328 int64_t batch_size = XShape[1 ];
333329 int64_t input_size = XShape[2 ];
334330 if (num_directions != wTy.getSizes ()[0 ])
335- return rewriter.notifyMatchFailure (" num_directions (" + std::to_string (num_directions) + " ) does not match the first dimension of wTy (" + std::to_string (wTy.getSizes ()[0 ]) + " )" );
331+ return rewriter.notifyMatchFailure (
332+ binder.op , " num_directions (" + std::to_string (num_directions) +
333+ " ) does not match the first dimension of wTy (" +
334+ std::to_string (wTy.getSizes ()[0 ]) + " )" );
336335 if (num_directions != 1 )
337- return rewriter.notifyMatchFailure (" num_directions (" + std::to_string (num_directions) + " ) is not equal to 1" );
336+ return rewriter.notifyMatchFailure (
337+ binder.op , " num_directions (" + std::to_string (num_directions) +
338+ " ) is not equal to 1" );
338339 if (4 * hidden_size != wTy.getSizes ()[1 ])
339- return rewriter.notifyMatchFailure (" 4 times hidden_size (" + std::to_string (4 * hidden_size) + " ) does not match the second dimension of wTy (" + std::to_string (wTy.getSizes ()[1 ]) + " )" );
340+ return rewriter.notifyMatchFailure (
341+ binder.op , " 4 times hidden_size (" + std::to_string (4 * hidden_size) +
342+ " ) does not match the second dimension of wTy (" +
343+ std::to_string (wTy.getSizes ()[1 ]) + " )" );
340344 if (wTy.getSizes ()[2 ] != input_size)
341- return rewriter.notifyMatchFailure (" The third dimension of wTy (" + std::to_string (wTy.getSizes ()[2 ]) + " ) does not match input_size (" + std::to_string (input_size) + " )" );
342-
345+ return rewriter.notifyMatchFailure (
346+ binder.op ,
347+ " The third dimension of wTy (" + std::to_string (wTy.getSizes ()[2 ]) +
348+ " ) does not match input_size (" + std::to_string (input_size) + " )" );
343349
344350 /* *
345351 * @brief Splits the input tensor based on the provided direction.
@@ -353,15 +359,15 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
353359 * @return The split tensor for the specified direction.
354360 */
355361 auto getDirection = [&](int64_t direction, Value input) {
356- ValueTensorType inputType = input.getType ().cast <ValueTensorType>();
362+ auto inputType = input.getType ().cast <ValueTensorType>();
357363
358364 // drop 0th dimension
359- ValueTensorType outputType =
360- inputType
361- .getWithSizesAndDtype (
362- llvm::SmallVector<int64_t >{inputType.getSizes ().drop_front ()},
363- inputType.getDtype ())
364- .cast <ValueTensorType>();
365+ auto outputType =
366+ inputType
367+ .getWithSizesAndDtype (
368+ llvm::SmallVector<int64_t >{inputType.getSizes ().drop_front ()},
369+ inputType.getDtype ())
370+ .cast <ValueTensorType>();
365371
366372 auto intType = b.getType <IntType>();
367373 Value selectDim = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (0 ));
@@ -374,7 +380,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
374380 Value R_forward = getDirection (0 , R);
375381 Value B_forward = getDirection (0 , B);
376382
377- ValueTensorType hTy = b.getType <ValueTensorType>(
383+ auto hTy = b.getType <ValueTensorType>(
378384 llvm::SmallVector<int64_t >{num_directions, batch_size, hidden_size},
379385 xTy.getDtype ());
380386
@@ -430,7 +436,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
430436 Value inputWeightsEndIdx = intConst (4 * hidden_size);
431437 Value recurrentWeightsStartIdx = inputWeightsEndIdx;
432438 Value recurrentWeightsEndIdx = intConst (8 * hidden_size);
433- ValueTensorType biasType = b.getType <ValueTensorType>(
439+ auto biasType = b.getType <ValueTensorType>(
434440 llvm::SmallVector<int64_t >{hidden_size * 4 }, wTy.getDtype ());
435441 Value Wb = b.create <AtenSliceTensorOp>(biasType,
436442 /* input=*/ B_forward,
@@ -446,13 +452,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
446452 /* step=*/ cstOne);
447453
448454 // gate splitting
449- ValueTensorType gateBiasType = b.getType <ValueTensorType>(
455+ auto gateBiasType = b.getType <ValueTensorType>(
450456 llvm::SmallVector<int64_t >{hidden_size},
451457 Wb.getType ().cast <ValueTensorType>().getDtype ());
452- ValueTensorType gateWeightsTypeIH = b.getType <ValueTensorType>(
458+ auto gateWeightsTypeIH = b.getType <ValueTensorType>(
453459 llvm::SmallVector<int64_t >{hidden_size, input_size},
454460 W_forward.getType ().cast <ValueTensorType>().getDtype ());
455- ValueTensorType gateWeightsTypeHH = b.getType <ValueTensorType>(
461+ auto gateWeightsTypeHH = b.getType <ValueTensorType>(
456462 llvm::SmallVector<int64_t >{hidden_size, hidden_size},
457463 R_forward.getType ().cast <ValueTensorType>().getDtype ());
458464
@@ -500,7 +506,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
500506 LstmLayerOutput lstmLayerOutput = lstm_layer (
501507 b, X, initial_h_forward, initial_c_forward, weights, activations);
502508
503- ValueTensorType Y_h_Y_c_unsqueezed_type = b.getType <ValueTensorType>(
509+ auto Y_h_Y_c_unsqueezed_type = b.getType <ValueTensorType>(
504510 llvm::SmallVector<int64_t >{num_directions, batch_size, hidden_size},
505511 lstmLayerOutput.Y_h .getType ().cast <ValueTensorType>().getDtype ());
506512 Value Y_h_unsqueezed = b.create <AtenUnsqueezeOp>(
0 commit comments