Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions src/op/atomic_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
* Define elment-wise operators.
*/

#include "atomic_add.h"

#include "./atomic_add.h"
#include "./region.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
Expand Down Expand Up @@ -34,25 +34,35 @@ static int GetArchInt(Target target) {
return arch_int;
}

AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]);
node->coalesced_width = Downcast<IntImm>(args[2]);
}
data_ = std::move(node);
}

Array<IterVar> AtomicAdd::MakeIterVars() const {
TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
return AtomicAdd(op);
}

Array<IterVar> AtomicAddNode::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
Expand All @@ -68,8 +78,8 @@ Array<IterVar> AtomicAdd::MakeIterVars() const {

// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
Expand All @@ -87,9 +97,10 @@ Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
return indices;
}

PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const {
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
Expand Down Expand Up @@ -117,7 +128,7 @@ PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
}
}

For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
if (is_scalar) {
Expand Down Expand Up @@ -180,16 +191,16 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body);
}

Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = std::make_unique<ParallelOp>(fused_loop);
auto par_op = ParallelOp(fused_loop);

std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
(par_op)->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
Expand All @@ -210,10 +221,11 @@ Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
}

LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (par_op_ == nullptr) {
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (!par_op_.defined()) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
}
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
Expand All @@ -236,10 +248,5 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// TVM_REGISTER_OP("tl.atomicadd")
// .set_num_inputs(2)
// .add_argument("ref", "Buffer", "The destination buffer")
// .add_argument("val", "Expr", "The value to be added atomically");

} // namespace tl
} // namespace tvm
46 changes: 21 additions & 25 deletions src/op/atomic_add.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,31 @@
#ifndef TVM_TL_OP_ATOMIC_ADD_H_
#define TVM_TL_OP_ATOMIC_ADD_H_

#include "op.h"
#include "operator.h"
#include "parallel.h"

namespace tvm {
namespace tl {

using namespace tir;

class AtomicAdd : public Operator {
class AtomicAddNode : public TileOperatorNode {
public:
AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
Array<PrimExpr> args_;

static const Op &Get();
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;

AtomicAdd(const AtomicAdd &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width) {
// No clone nullptr
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<AtomicAdd>(*this);
}
mutable ParallelOp par_op_;
static constexpr const char *_type_key = "tl.AtomicAdd";
TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);

Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;

static const Op &Get();
TileOperator Clone() const;

protected:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Expand All @@ -46,14 +43,13 @@ class AtomicAdd : public Operator {

PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
};

Array<PrimExpr> args_;

Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;

std::unique_ptr<ParallelOp> par_op_;
class AtomicAdd : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};

} // namespace tl
Expand Down
2 changes: 1 addition & 1 deletion src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#ifndef TVM_TL_OP_BUILTIN_H_
#define TVM_TL_OP_BUILTIN_H_

#include "op.h"
#include "operator.h"
#include <tvm/ir/transform.h>

namespace tvm {
Expand Down
Loading