Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Bump forward and refactor inline global slots to no longer track via
symlinks. This appears to make the tests past until we manage to remove
torchscript work.
  • Loading branch information
rsuderman authored Sep 10, 2024
1 parent b35675a commit 6934ab8
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 75 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 2353 files
117 changes: 52 additions & 65 deletions lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,15 @@ using namespace mlir::torch::Torch;
/// a single module. If we had to support complex nested symbol references, we
/// would probably want to go through the effort to indirect through the symbol
/// tables to make things clearer.
class FlatSymbolRefProgramPoint
: public GenericProgramPointBase<FlatSymbolRefProgramPoint,
FlatSymbolRefAttr> {
class FlatSymbolRefLatticeAnchor
: public GenericLatticeAnchorBase<FlatSymbolRefLatticeAnchor, Operation *> {
public:
using Base::Base;
void print(raw_ostream &os) const override {
os << "FlatSymbolRefProgramPoint(" << getValue() << ")";
os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")";
}
Location getLoc() const override {
return UnknownLoc::get(getValue().getContext());
return UnknownLoc::get(getValue()->getContext());
}
};

Expand All @@ -84,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// State tracking if an IR construct is "safe".
///
/// This state is tracked on Value's and also on global slots (via a
/// FlatSymbolRefProgramPoint).
/// FlatSymbolRefLatticeAnchor).
///
/// In this context, "safe" means that the object is safe to inline.
/// This covers a few concepts
Expand All @@ -93,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// unsafe
class InlineGlobalSlotsAnalysisState : public AnalysisState {
public:
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) {
(void)setSafe();
}

Expand Down Expand Up @@ -147,33 +146,33 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis {

InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
registerPointKind<FlatSymbolRefProgramPoint>();
registerAnchorKind<FlatSymbolRefLatticeAnchor>();
}

LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
auto walkResult = top->walk([this](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>(
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
propagateIfChanged(state,
state->setSafe(globalSlot.getVisibility() !=
SymbolTable::Visibility::Public));
}
if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
globalSlotSet, globalSlotSet.getSlotAttr());

auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>(
globalSlotSet.getSlotAttr()));
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
propagateIfChanged(state, state->setSafe(false));
}
// Save the InitializeGlobalSlotsOp for later referencee
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
initializeGlobalSlotsOp = initialize;
}
for (Value result : op->getResults()) {
if (failed(visit(result)))
return WalkResult::interrupt();
}
if (failed(visit(op)))
return WalkResult::interrupt();

return WalkResult::advance();
});
if (walkResult.wasInterrupted())
Expand All @@ -182,50 +181,32 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
}

LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
if (Value value = dyn_cast<Value>(point)) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe));

// Handle GlobalSlotGetOp's.
if (auto opResult = dyn_cast<OpResult>(value)) {
if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
globalSlotGet.getSlotAttr());
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
flatSymbolRefPoint, globalSlotGet.getResult());
auto *globalState =
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
globalState->incorporateSafetyOfUse(valueState));
}
}

return success();
}
if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
if (auto *flatSymbolRefPoint =
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
if (initializeGlobalSlotsOp) {
auto it =
llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
static_cast<Attribute>(flatSymbolRefPoint->getValue()));
Value value = initializeGlobalSlotsOp->getOperand(std::distance(
initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
auto *flatSymbolRefState =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
flatSymbolRefPoint);
auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(valueState,
valueState->setSafe(flatSymbolRefState->isSafe));
if (auto op = dyn_cast<Operation *>(point)) {
for (auto value : op->getResults()) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe));

// Handle GlobalSlotGetOp's.
if (auto opResult = dyn_cast<OpResult>(value)) {
if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
globalSlotGet, globalSlotGet.getSlotAttr());
auto *flatSymbolRefPoint =
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
globalSlot, globalSlotGet.getResult());
auto *globalState =
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
globalState->incorporateSafetyOfUse(valueState));
}
}
return success();
}
}
LLVM_DEBUG(
{ llvm::dbgs() << "visit failing because of: " << point << "\n"; });
return failure();

return success();
}

// This is only a member function to access protected get* functions.
Expand All @@ -241,16 +222,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
// safe. This covers, for example, view-like ops that create aliases.
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
llvm::all_of(op->getResults(), [&](Value result) {
auto *state =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value.getDefiningOp(), result);
return state->isSafe;
}))
continue;
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
auto symName = cast<FlatSymbolRefAttr>(
initialize.getSlotSymNames()[use.getOperandNumber()]);
auto globalSlot =
SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);

auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
value.getDefiningOp(),
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
if (state->isSafe)
continue;
}
Expand Down Expand Up @@ -299,8 +284,7 @@ class InlineGlobalSlotsPass
module->walk([&](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
solver.getProgramPoint<FlatSymbolRefProgramPoint>(
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
state->print(llvm::dbgs());
llvm::dbgs() << ": "
<< FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())
Expand Down Expand Up @@ -334,13 +318,16 @@ class InlineGlobalSlotsPass
auto slotSymName =
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
Value operand = initialize.getOperand(i);
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
initialize, slotSymName);

auto symbolRefPoint =
solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *state =
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
// We roll the analysis of whether a slot is set or public into the
// main dataflow analysis, so we need to check the slot's
// FlatSymbolRefProgramPoint itself to see if it is safe to inline.
// FlatSymbolRefLatticeAnchor itself to see if it is safe to inline.
// For example, a public !torch.int is not safe to inline, even though
// it is a value-semantic type and so the actual initializer value
// itself is conceptually safe to inline.
Expand Down
9 changes: 0 additions & 9 deletions test/Conversion/TorchToStablehlo/linear.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK: %[[T_5:.*]] = torch.constant.int 1
// CHECK: %[[T_6:.*]] = torch.constant.int 4
// CHECK: %[[T_7:.*]] = torch.constant.int 3
// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
Expand Down Expand Up @@ -295,7 +294,6 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int4 = torch.constant.int 4
// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
Expand Down Expand Up @@ -336,7 +334,6 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
Expand Down Expand Up @@ -367,7 +364,6 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
Expand Down Expand Up @@ -402,7 +398,6 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
Expand Down Expand Up @@ -438,10 +433,6 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
// CHECK: %c0 = arith.constant 0 : index
Expand Down

0 comments on commit 6934ab8

Please sign in to comment.