Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fixed GRU quality issues exposed by e2e tests #3753

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 42 additions & 42 deletions lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp
Original file line number Diff line number Diff line change
@@ -1072,11 +1072,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstNone = b.create<ConstantNoneOp>();
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
Value cstTwo = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2));

// Binding arguments
ValueTensorType yTy, Y_hType;
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
binder.tensorResultTypeAtIndex(Y_hType, 1)) {
return rewriter.notifyMatchFailure(binder.op,
"At least one output must be present");
@@ -1132,6 +1131,7 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// Validations
auto XShape = xTy.getSizes();
int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0];
int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1];
int64_t input_size = XShape[2];

std::ostringstream oss;
@@ -1173,6 +1173,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
initial_h =
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
} else {
if (layout == 1) {
initial_h = StaticTranspose(b, initial_h, 0, 1);
}
}

if (binder.tensorOperandAtIndex(sequence_lens, 4))
@@ -1192,10 +1196,10 @@ LogicalResult OnnxGruExpander(OpBinder binder,
// fill in B
Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());
if (B == nullptr) {
SmallVector<int64_t> BShape = {num_directions, 2 * hidden_size};
SmallVector<int64_t> BShape = {num_directions, 6 * hidden_size};
SmallVector<Value> BShapeListContents = {
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions)),
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(2 * hidden_size))};
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(6 * hidden_size))};
Value BShapeList = b.create<PrimListConstructOp>(
b.getType<ListType>(intType), BShapeListContents);
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());
@@ -1256,51 +1260,47 @@ LogicalResult OnnxGruExpander(OpBinder binder,
B_slices[4], B_slices[5]);

// Process inputs based on layout
Value X_processed, initial_h_processed;
ValueTensorType yTy_processed, Y_hType_processed;

if (layout == 0) {
X_processed = X;
initial_h_processed = initial_h_forward;
yTy_processed = yTy;
Y_hType_processed = Y_hType;
} else {
X_processed = b.create<AtenTransposeIntOp>(X.getType(), X, cstZero, cstOne);
initial_h_processed = b.create<AtenTransposeIntOp>(
initial_h.getType(), initial_h_forward, cstZero, cstOne);

auto yTySizes = yTy.getSizes();
auto Y_hTypeSizes = Y_hType.getSizes();

yTy_processed = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{yTySizes[1], yTySizes[0], yTySizes[2],
yTySizes[3]},
yTy.getDtype());

Y_hType_processed = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{Y_hTypeSizes[1], Y_hTypeSizes[0],
Y_hTypeSizes[2]},
Y_hType.getDtype());
if (layout == 1) {
X = StaticTranspose(b, X, 0, 1);
}

// Weights and biases ready. Calling GRU layer to insert the actual ops.
GruLayerOutput gruLayerOutput =
gru_layer(b, X_processed, initial_h_processed, weights, activations,
linear_before_reset);
GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights,
activations, linear_before_reset);

// Process outputs based on layout
Value Y_final, Y_h_final;
if (layout == 0) {
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
Value Y_final;
if (binder.tensorResultTypeAtIndex(yTy, 0)) {
Y_final = cstNone;
} else {
auto Y_transposed = b.create<AtenTransposeIntOp>(
gruLayerOutput.Y.getType(), gruLayerOutput.Y, cstZero, cstOne);
Y_final = b.create<AtenUnsqueezeOp>(yTy, Y_transposed, cstTwo);
if (layout == 0) {
Y_final = b.create<AtenUnsqueezeOp>(yTy, gruLayerOutput.Y, cstOne);
} else {
Type yTy_original = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
yTy.getDtype());
Y_final =
b.create<AtenUnsqueezeOp>(yTy_original, gruLayerOutput.Y, cstOne);
Y_final = StaticTranspose(b, Y_final, 1, 2);
Y_final = StaticTranspose(b, Y_final, 0, 1);
}
}

auto Y_h_transposed = b.create<AtenTransposeIntOp>(
gruLayerOutput.Y_h.getType(), gruLayerOutput.Y_h, cstZero, cstOne);
Y_h_final = b.create<AtenUnsqueezeOp>(Y_hType, Y_h_transposed, cstZero);
Value Y_h_final;
if (binder.tensorResultTypeAtIndex(Y_hType, 1)) {
Y_h_final = cstNone;
} else {
if (layout == 0) {
Y_h_final =
b.create<AtenUnsqueezeOp>(Y_hType, gruLayerOutput.Y_h, cstZero);
} else {
Type y_hTy_original = b.getType<ValueTensorType>(
llvm::SmallVector<int64_t>{1, batch_size, hidden_size},
Y_hType.getDtype());
Y_h_final = b.create<AtenUnsqueezeOp>(y_hTy_original, gruLayerOutput.Y_h,
cstZero);
Y_h_final = StaticTranspose(b, Y_h_final, 0, 1);
}
}

rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final});
Expand Down
Loading