Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cd6f059
toIslScheduleTree: swap the order of non-core filters below extension
ftynse Mar 16, 2018
8bc7735
fixThreadsBelowFilter: fix domain used for threadIdxxScheduleDepth
ftynse Mar 16, 2018
ff4c871
add inequality comparison operator for isl::id
ftynse Mar 15, 2018
a1194c8
extract bandsContainingScheduleDepth and bandsSplitAfterDepth
ftynse Mar 8, 2018
d8437a8
TensorReferenceGroup: change printing format
ftynse Mar 9, 2018
1ec2911
do not include filter constraints in original accesses
ftynse Mar 13, 2018
d2e2460
memory promotion: ignore tensors with no rectangular approximation
ftynse Mar 13, 2018
05e56c5
fullSchedule: handle innermost zero-dimensional bands
ftynse Mar 13, 2018
778f274
bump isl for isl_multi_union_pw_aff_intersect_domain
ftynse Mar 16, 2018
0b3379d
basic heuristic for register promotion
ftynse Mar 9, 2018
c133cb5
Call register promotion from MappedScop
ftynse Mar 13, 2018
feed0a4
add primitive test for register promotion not crashing
ftynse Mar 9, 2018
822494f
introduce a Kind of a promoted declaration
ftynse Mar 13, 2018
720d077
emit declarations for register promotion
ftynse Mar 13, 2018
fbbfcb0
promoteGroup: disallow double promotion
ftynse Mar 13, 2018
bd2d0cd
promote TensorReferenceGroups to registers
ftynse Mar 13, 2018
dfe48ca
use domain sets instead of statement ids in memory promotion
ftynse Mar 14, 2018
49200ae
private memory: don't use statement ids
ftynse Mar 14, 2018
f736ddb
drop activeStatements
ftynse Mar 14, 2018
94dc353
move getParamValIfFixed to islpp.h
ftynse Mar 16, 2018
3259e76
syntactic tests for register promotion
ftynse Mar 16, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions include/tc/core/polyhedral/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,18 @@ struct CodegenStatementContext : CodegenContext {
// Identifier of the statement whose instance is emitted at this AST node:
// the output tuple id of the iterator map registered for astNodeId.
isl::id statementId() const {
  return this->iteratorMaps.at(astNodeId).get_tuple_id(isl::dim_type::out);
}
// Statement instances covered by this AST node, computed as the range of
// the node's iterator map (converted from pw_multi_aff to a plain map).
isl::set domain() const {
  return isl::map::from(this->iteratorMaps.at(astNodeId)).range();
}
std::vector<Scop::PromotionInfo> activePromotions() const {
auto stmtId = statementId();
const auto& promotions = this->scop().activePromotions();
if (promotions.count(stmtId) == 0) {
return {};
std::vector<Scop::PromotionInfo> result;
auto dom = isl::union_set(this->domain());
for (const auto& kvp : this->scop().activePromotions()) {
if (!kvp.first.intersect(dom).is_empty()) {
result.emplace_back(kvp.second);
}
}
return promotions.at(stmtId);
return result;
}

isl::id astNodeId;
Expand Down
4 changes: 2 additions & 2 deletions include/tc/core/polyhedral/memory_promotion.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ inline std::ostream& operator<<(std::ostream& os, const TensorReference& tr) {
inline std::ostream& operator<<(
std::ostream& os,
const TensorReferenceGroup& tg) {
os << " with footprint BB: " << tg.approximation << " ";
os << "Reference with footprint: " << tg.approximation << "\n";
for (const auto& tr : tg.references) {
os << *tr << " ";
os << *tr << "\n";
}
return os;
}
Expand Down
6 changes: 6 additions & 0 deletions include/tc/core/polyhedral/memory_promotion_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ using ThreadIdxxScheduleDepthState =
std::vector<std::pair<isl::union_set, size_t>>;

class MappedScop;
class Scop;

// In the given mapped scop "mscop",
// promote to shared memory at "depth" until "sharedMemorySize" is used.
Expand All @@ -40,5 +41,10 @@ void promoteGreedilyAtDepth(
std::size_t depth,
std::size_t sharedMemorySize,
bool unrollCopies);

void promoteToRegistersBelowThreads(
Scop& scop,
const ThreadIdxxScheduleDepthState& threadIdxxScheduleDepthState,
std::size_t nRegisters);
} // namespace polyhedral
} // namespace tc
7 changes: 0 additions & 7 deletions include/tc/core/polyhedral/schedule_transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,6 @@ isl::union_set activeDomainPoints(
const detail::ScheduleTree* root,
const detail::ScheduleTree* node);

// Get the set of statement identifiers whose domains have at least one active
// point at the given node, i.e. the statements that were not filtered away on
// the path from root to node.
std::unordered_set<isl::id, isl::IslIdIslHash> activeStatements(
const detail::ScheduleTree* root,
const detail::ScheduleTree* node);

////////////////////////////////////////////////////////////////////////////////
// Experimental
////////////////////////////////////////////////////////////////////////////////
Expand Down
22 changes: 13 additions & 9 deletions include/tc/core/polyhedral/scop.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,11 @@ struct Scop {
void promoteEverythingAt(std::vector<size_t> pos);

struct PromotedDecl {
enum class Kind { SharedMem, Register };

isl::id tensorId;
std::vector<size_t> sizes;
Kind kind;
};

struct PromotionInfo {
Expand All @@ -321,9 +324,8 @@ struct Scop {
return promotedDecls_;
}

const std::
unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>&
activePromotions() const {
const std::vector<std::pair<isl::union_set, PromotionInfo>>&
activePromotions() const {
return activePromotions_;
}

Expand Down Expand Up @@ -356,7 +358,8 @@ struct Scop {
// Assumes such argument exists.
const Halide::OutputImageParam& findArgument(isl::id id) const;

// Promote a tensor reference group to shared memory, inserting the copy
// Promote a tensor reference group to a storage of a given "kind",
// inserting the copy
statements below the given node. Inserts an Extension node below the given
// node, unless there is already another Extension node which introduces
// copies. The Extension node has a unique Sequence child, whose children
Expand All @@ -368,11 +371,11 @@ struct Scop {
// If "forceLastExtentOdd" is set, the last extent in the declaration is
// incremented if it is even. This serves as a simple heuristic to reduce
// shared memory bank conflicts.
void promoteGroupToShared(
void promoteGroup(
PromotedDecl::Kind kind,
isl::id tensorId,
std::unique_ptr<TensorReferenceGroup>&& gr,
detail::ScheduleTree* tree,
const std::unordered_set<isl::id, isl::IslIdIslHash>& activeStmts,
isl::union_map schedule,
bool forceLastExtentOdd = false);

Expand Down Expand Up @@ -463,9 +466,10 @@ struct Scop {
std::unordered_map<isl::id, size_t, isl::IslIdIslHash> groupCounts_;
// groupId -> (tensorId, groupSizes)
std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash> promotedDecls_;
// stmtId -> (group, partial schedule, groupId)
std::unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>
activePromotions_;
// (domain, group, partial schedule, groupId)
// Note that domain is a non-unique key, i.e. multiple groups can be listed
// for the same domain, or for partially intersecting domains.
std::vector<std::pair<isl::union_set, PromotionInfo>> activePromotions_;
};

std::ostream& operator<<(std::ostream& os, const Scop&);
Expand Down
19 changes: 19 additions & 0 deletions include/tc/external/detail/islpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ inline bool operator==(const isl::id& id1, const isl::id& id2) {
return id1.get() == id2.get();
}

// Inequality for isl ids, defined as the negation of the pointer-based
// equality above; isl ids are interned per context, so this is sound.
inline bool operator!=(const isl::id& id1, const isl::id& id2) {
  return !(id1 == id2);
}

///////////////////////////////////////////////////////////////////////////////
// Helper functions
///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -399,6 +403,21 @@ auto end(L& list) -> ListIter<decltype(list.get(0)), L> {
using detail::begin;
using detail::end;

// If the parameter at position "pos" is fixed to one and the same value in
// every element of the union object "t", return that value; otherwise return
// a NaN value (also when the parameter is not fixed in some element, or when
// "t" is empty).
template <typename T>
isl::val getParamValIfFixed(T t, int pos) {
  auto ctx = t.get_ctx();
  auto fixed = isl::val::nan(ctx);
  bool seen = false;
  for (auto element : isl::UnionAsVector<T>(t)) {
    auto candidate = element.plain_get_val_if_fixed(isl::dim_type::param, pos);
    // Not fixed in this element, or fixed to a different value than before:
    // the parameter has no single fixed value over the whole union.
    if (candidate.is_nan() || (seen && fixed != candidate)) {
      return isl::val::nan(ctx);
    }
    fixed = candidate;
    seen = true;
  }
  return fixed;
}
} // namespace isl

namespace isl {
Expand Down
22 changes: 7 additions & 15 deletions src/core/polyhedral/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,17 +503,6 @@ isl::space findDomainSpaceById(const CodegenStatementContext& context) {
return isl::space();
}

// Extract from the union map "schedule" the single schedule map whose input
// (domain) tuple is identified by "stmtId".
// CHECK-fails if no such map exists; the trailing return is unreachable and
// only present to satisfy the compiler.
isl::map findScheduleByStmtId(isl::union_map schedule, isl::id stmtId) {
  for (auto s : isl::UnionAsVector<isl::union_map>(schedule)) {
    if (s.get_tuple_id(isl::dim_type::in) == stmtId) {
      return s;
    }
  }
  CHECK(false) << "could not find schedule for " << stmtId << " in "
               << schedule;
  return isl::map();
}

isl::multi_aff makeMultiAffAccess(
isl::id tensorId,
const std::vector<Halide::Expr>& subscripts,
Expand Down Expand Up @@ -633,9 +622,9 @@ void emitMappedTensorAccess(
auto promotion = promotionInfo.group->promotion(); // MA :: [S -> O] -> P
promotion = promotion.set_tuple_id(isl::dim_type::out, promotionInfo.groupId);
auto iteratorMap = context.iteratorMap(); // PMA :: A -> D
auto schedule = findScheduleByStmtId(
promotionInfo.outerSchedule,
context.statementId()); // map :: D -> S
auto schedule =
isl::map::from_union_map(promotionInfo.outerSchedule.intersect_domain(
context.domain())); // map :: D -> S

CHECK(schedule.is_single_valued())
<< "expected single-valued schedule, got " << schedule;
Expand Down Expand Up @@ -707,7 +696,10 @@ void emitPromotedArrayViewsHalide(stringstream& ss, const Scop& scop) {
t = i.type();
}
}
ss << "__shared__ " << t << " " << viewName;
if (p.second.kind == Scop::PromotedDecl::Kind::SharedMem) {
ss << "__shared__ ";
}
ss << t << " " << viewName;
for (auto s : p.second.sizes) {
ss << "[" << s << "]";
}
Expand Down
15 changes: 12 additions & 3 deletions src/core/polyhedral/mapped_scop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,11 @@ void fixThreadsBelowFilter(

for (size_t i = begin; i < end; ++i) {
if (mapping::ThreadId::makeId(i) == mapping::ThreadId::x()) {
// Mapping happened below filterTree, so we need points active for its
// children. After insertion, filterTree is guaranteed to have at least
// one child.
mscop.threadIdxxScheduleDepthState.emplace_back(std::make_pair(
activeDomainPoints(mscop.schedule(), filterTree),
activeDomainPoints(mscop.schedule(), filterTree->child({0})),
filterTree->scheduleDepth(mscop.schedule())));
}
}
Expand Down Expand Up @@ -686,10 +689,16 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
}
}

// 8. Insert mapping context
// 8. Promote to registers below the loops mapped to threads.
if (options.proto.use_private_memory()) {
promoteToRegistersBelowThreads(
mappedScop->scop(), mappedScop->threadIdxxScheduleDepthState, -1ull);
}

// 9. Insert mapping context
mappedScop->insertMappingContext();

// 9. Optionally insert reduction synchronizations
// 10. Optionally insert reduction synchronizations
for (auto bandUpdate : mappedScop->reductionBandUpdates_) {
for (auto updateId : bandUpdate.second.ids) {
scop->insertReductionSync1D(
Expand Down
15 changes: 13 additions & 2 deletions src/core/polyhedral/memory_promotion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,21 @@ void addSingletonReferenceGroups(
// access relations have a shape :: [D -> ref] -> O
// use currying to isolate the D part before intersecting with the domain
// Compute initial groups with single reference per group.
accesses = accesses.curry().intersect_domain(domain).uncurry();
std::unordered_set<isl::id, isl::IslIdIslHash> unapproximatable;
for (auto a : isl::UnionAsVector<isl::union_map>(accesses)) {
if (isl::union_map(a.curry()).intersect_domain(domain).is_empty()) {
continue;
}

auto tensorId = a.get_tuple_id(isl::dim_type::out);
addSingletonReferenceGroup(tensorGroups, tensorId, schedule, a, type);
if (unapproximatable.count(tensorId) != 0) {
continue;
}
try {
addSingletonReferenceGroup(tensorGroups, tensorId, schedule, a, type);
} catch (const promotion::GroupingError& err) {
unapproximatable.insert(tensorId);
}
}
}
} // namespace
Expand Down
Loading