Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tc/core/polyhedral/memory_promotion.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,15 @@ detail::ScheduleTree* insertCopiesUnder(
const TensorReferenceGroup& group,
isl::id tensorId,
isl::id groupId = isl::id());

// Insert copy statements for "group" of tensor "tensorId" under "tree" that
// copy between two levels of promotion: the inner promoted copy (identified
// by "groupId") and the already-promoted footprint of "outerScopeGroup"
// (identified by "outerScopeGroupId"), rather than between the promoted copy
// and the original tensor.  Returns the tree node below which the copies were
// inserted.
detail::ScheduleTree* insertIntraCopiesUnder(
    Scop& scop,
    detail::ScheduleTree* tree,
    const TensorReferenceGroup& group,
    const TensorReferenceGroup& outerScopeGroup,
    isl::id tensorId,
    isl::id groupId,
    isl::id outerScopeGroupId);

} // namespace polyhedral
} // namespace tc
24 changes: 24 additions & 0 deletions include/tc/core/polyhedral/scop.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ struct Scop {
return activePromotions_;
}

// Return the promotions from activePromotions_ that are relevant to the
// given set of active domain points and to the tensor identified by
// "tensorId", paired with their activation domains.  Filtering is delegated
// to activePromotionsIndexes; this overload merely materializes the selected
// entries.
std::vector<std::pair<isl::union_set, Scop::PromotionInfo>> activePromotions(
    isl::union_set activePoints,
    isl::id tensorId) const {
  return promotionsAtIndexes(activePromotionsIndexes(activePoints, tensorId));
}

// Non-const access to the root of the schedule tree owned by this Scop;
// used by transformations that modify the schedule in place.
detail::ScheduleTree* scheduleRoot() {
  return scheduleTreeUPtr.get();
}
Expand Down Expand Up @@ -379,6 +385,8 @@ struct Scop {
isl::union_map schedule,
bool forceLastExtentOdd = false);

void demoteGroup(isl::id groupId);

// Given a tree node under which the promotion copy statements were
// introduced, insert syncthread statements before and after the copies.
// The tree should match the structure:
Expand Down Expand Up @@ -408,6 +416,22 @@ struct Scop {
isl::schedule_constraints constraints,
const SchedulerOptionsView& schedulerOptions);

// Get the indexes of active promotions in the activePromotions_.
std::vector<size_t> activePromotionsIndexes(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

activePromotionsIndices

isl::union_set domain,
isl::id tensorId) const;
std::vector<std::pair<isl::union_set, Scop::PromotionInfo>>
promotionsAtIndexes(const std::vector<size_t>& indexes) const;

void promoteWithCopyFromGlobal(
isl::union_set activePoints,
PromotedDecl::Kind kind,
isl::id tensorId,
std::unique_ptr<TensorReferenceGroup>&& gr,
detail::ScheduleTree* tree,
isl::union_map schedule,
bool forceLastExtentOdd = false);

public:
// Halide stuff
struct {
Expand Down
87 changes: 73 additions & 14 deletions src/core/polyhedral/memory_promotion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,25 +452,20 @@ isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) {
}
} // namespace

ScheduleTree* insertCopiesUnder(
ScheduleTree* insertCopiesUnder_(
Scop& scop,
ScheduleTree* tree,
const TensorReferenceGroup& group,
isl::id tensorId,
isl::id groupId) {
isl::map promotion,
isl::set originalElements,
isl::set readElements,
isl::map exactWrites) {
auto groupId = promotion.get_tuple_id(isl::dim_type::out);
const ScheduleTree* root = scop.scheduleRoot();
auto ctx = root->ctx_;
isl::id readId = isl::id(ctx, std::string(kReadIdName));
isl::id writeId = isl::id(ctx, std::string(kWriteIdName));

// Take the set of all tensor elements.
auto tensorElements = tensorElementsSet(scop, tensorId);

if (groupId.is_null()) {
throw promotion::GroupingError("expected group id");
}
auto promotion =
isl::map(group.promotion()).set_tuple_id(isl::dim_type::out, groupId);
auto promotionSpace = promotion.get_space();

auto identityCopySchedule =
Expand Down Expand Up @@ -500,15 +495,15 @@ ScheduleTree* insertCopiesUnder(
auto approximattedRead =
isl::map(
scheduleUniverse,
group.approximateFootprint().set_tuple_id(arrayId).intersect(
tensorElements))
readElements.set_tuple_id(arrayId).intersect(originalElements))
.wrap();
approximattedRead = isl::map(approximattedRead, promotedFootprint).wrap();
auto readExtension = extension.intersect_range(approximattedRead)
.set_tuple_id(isl::dim_type::out, readId);

auto writtenElements =
isl::map(
group.scopedWrites().intersect_range(tensorElements).wrap(),
exactWrites.intersect_range(originalElements).wrap(),
promotedFootprint)
.wrap();
auto writeExtension = extension.intersect_range(writtenElements)
Expand Down Expand Up @@ -568,5 +563,69 @@ ScheduleTree* insertCopiesUnder(
tree->appendChild(std::move(extensionNode));
return tree;
}

// Insert copy statements for "group" under "tree" that copy between the
// inner promoted array (identified by "groupId") and the outer-scope
// promoted array of "outerScopeGroup" (identified by "outerScopeGroupId"),
// instead of between the inner promoted array and the original tensor.
//
// Both promotions map scoped accesses to promoted array elements.  The inner
// promotion is expected to have strictly more (curried) input schedule
// dimensions than the outer one (CHECK_GT below); the outer promotion is
// padded with the missing dimensions so the two relations become composable,
// and the inner promotion is then rebased so that its domain refers to
// elements of the outer promoted array.
//
// NOTE(review): the same expression
// outerScopeGroup.promotedFootprint().set_tuple_id(outerScopeGroupId) is
// passed for both the "originalElements" and "readElements" arguments of
// insertCopiesUnder_ — presumably intentional (the outer promoted footprint
// plays the role of the original tensor at this level), but worth confirming.
ScheduleTree* insertIntraCopiesUnder(
    Scop& scop,
    ScheduleTree* tree,
    const TensorReferenceGroup& group,
    const TensorReferenceGroup& outerScopeGroup,
    isl::id tensorId,
    isl::id groupId,
    isl::id outerScopeGroupId) {
  // Tag each promotion's output tuple with its group id so the copies refer
  // to the corresponding promoted arrays by name.
  auto innerScopePromotion =
      isl::map(group.promotion()).set_tuple_id(isl::dim_type::out, groupId);
  auto outerScopePromotion =
      isl::map(outerScopeGroup.promotion())
          .set_tuple_id(isl::dim_type::out, outerScopeGroupId);

  // The inner scope is nested inside the outer one, so it must have extra
  // schedule input dimensions; pad the outer promotion with them.
  auto outerScopeInDims =
      outerScopePromotion.get_space().curry().dim(isl::dim_type::in);
  auto innerScopeInDims =
      innerScopePromotion.get_space().curry().dim(isl::dim_type::in);
  CHECK_GT(innerScopeInDims, outerScopeInDims);
  outerScopePromotion =
      outerScopePromotion.curry()
          .add_dims(isl::dim_type::in, innerScopeInDims - outerScopeInDims)
          .uncurry();
  // Keep the schedule part of the domain while replacing accessed tensor
  // elements by their images under the outer promotion, then rebase the
  // inner promotion onto the outer promoted array.
  auto domainAccessToDomainMap = isl::map(isl::multi_aff::domain_map(
      innerScopePromotion.get_space().domain().unwrap()));
  outerScopePromotion =
      domainAccessToDomainMap.range_product(outerScopePromotion);
  innerScopePromotion = innerScopePromotion.apply_domain(outerScopePromotion);

  return insertCopiesUnder_(
      scop,
      tree,
      group,
      innerScopePromotion,
      outerScopeGroup.promotedFootprint().set_tuple_id(outerScopeGroupId),
      outerScopeGroup.promotedFootprint().set_tuple_id(outerScopeGroupId),
      group.scopedWrites().wrap().apply(outerScopePromotion).unwrap());
}

// Insert copy statements for "group" of tensor "tensorId" under "tree",
// copying between the original tensor and its promoted counterpart
// (identified by "groupId").  Thin wrapper around insertCopiesUnder_ that
// supplies the footprint and write sets taken directly from the group.
// Throws promotion::GroupingError if "groupId" is null.
ScheduleTree* insertCopiesUnder(
    Scop& scop,
    ScheduleTree* tree,
    const TensorReferenceGroup& group,
    isl::id tensorId,
    isl::id groupId) {
  // All elements of the original tensor; the copy footprint is restricted
  // to this set inside insertCopiesUnder_.
  auto originalElements = tensorElementsSet(scop, tensorId);

  if (groupId.is_null()) {
    throw promotion::GroupingError("expected group id");
  }

  // Tag the promotion's output tuple with the group id so copy statements
  // refer to the promoted array by that name.
  auto taggedPromotion =
      isl::map(group.promotion()).set_tuple_id(isl::dim_type::out, groupId);

  return insertCopiesUnder_(
      scop,
      tree,
      group,
      taggedPromotion,
      originalElements,
      group.approximateFootprint(),
      group.scopedWrites());
}
} // namespace polyhedral
} // namespace tc
13 changes: 10 additions & 3 deletions src/core/polyhedral/memory_promotion_heuristic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,15 @@ void promoteGreedilyAtDepth(
mapCopiesToThreads(mscop, unrollCopies);
}

namespace {
// Return "t" with the parametric dimension identified by "paramId" projected
// out; if the space of "t" has no such parameter, return "t" unchanged.
template <typename T>
T projectOutNamedParam(T t, isl::id paramId) {
  const int position =
      t.get_space().find_dim_by_id(isl::dim_type::param, paramId);
  if (position == -1) {
    return t;
  }
  return t.project_out(isl::dim_type::param, position, 1);
}
} // namespace

// Assuming the mapping to threads happens in inverse order, i.e. the innermost
// loop is mapped to thread x, promote below that depth.
void promoteToRegistersBelowThreads(
Expand Down Expand Up @@ -640,9 +649,7 @@ void promoteToRegistersBelowThreads(
if (!hasReuse(*group, fullSched, depth)) {
continue;
}
// TODO: if something is already in shared, but reuse it within one
// thread only, there is no point in keeping it in shared _if_ it
// gets promoted into a register.

scop.promoteGroup(
Scop::PromotedDecl::Kind::Register,
tensorId,
Expand Down
7 changes: 6 additions & 1 deletion src/core/polyhedral/schedule_print.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,12 @@ std::ostream& ScheduleTreeElemDomain::write(std::ostream& os) const {

std::ostream& ScheduleTreeElemExtension::write(std::ostream& os) const {
WS w;
os << w.tab() << "extension(" << extension_ << ")";
os << w.tab() << "extension(";
for (const auto& u : isl::UnionAsVector<isl::union_map>(extension_)) {
WS w2;
os << std::endl << w2.tab() << u;
}
os << ")";
return os;
}

Expand Down
Loading