Skip to content

Commit

Permalink
Fix test_functional_regressions.py::test_vecmat
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang committed Jan 19, 2025
1 parent 083eafa commit fa341e9
Showing 1 changed file with 5 additions and 28 deletions.
33 changes: 5 additions & 28 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,36 +299,13 @@ struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
LogicalResult
matchAndRewrite(TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = cast<TensorOrMemDesc>(op.getType());
if (auto enc = dyn_cast<SharedEncodingAttr>(resultTy.getEncoding())) {
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
auto dstSmemObj = SharedMemoryObject(
srcSmemObj.getBase(), srcSmemObj.getBaseElemType(),
/*offsets=*/applyPermutation(srcSmemObj.getOffsets(), op.getOrder()));
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
} else if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(
resultTy.getEncoding())) {
// If the dst encoding is blocked, then TransOp::inferReturnTypes
// ensures that:
// - the src encoding is also blocked, and
// - the translation from src to dst is just a "renaming" of the
// registers, i.e. each thread has exactly the same values.
// Thus the transpose op simply returns the same values it got.
auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
Value ret = packLLElements(loc, this->getTypeConverter(), vals, rewriter,
resultTy);
rewriter.replaceOp(op, ret);
return success();
}
return emitOptionalError(loc, "unsupported encoding for TransOp");
// By construction, TransOp::inferReturnTypes ensures that the src encoding
// is the same as the dst encoding so that this op is a no-op.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
};

struct BroadcastOpConversion
: public ConvertOpToLLVMPattern<triton::BroadcastOp> {
using ConvertOpToLLVMPattern<triton::BroadcastOp>::ConvertOpToLLVMPattern;
Expand Down

0 comments on commit fa341e9

Please sign in to comment.