Skip to content

Commit

Permalink
[PYTHON] Added atomic_add (triton-lang#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Jul 27, 2021
1 parent d7f8792 commit 2b75158
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 2 deletions.
1 change: 1 addition & 0 deletions include/triton/ir/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct dispatch{
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);

// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
Expand Down
9 changes: 9 additions & 0 deletions lib/ir/dispatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,15 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *bu
return builder->create_atomic_exch(ptr, val);
}

// Builds an atomic-add instruction on `ptr` with operand `val`, predicated by `mask`.
// A null `mask` is replaced by a constant-true int1; when `ptr` has a block type,
// that constant is splat to the block shape of `ptr` before being used.
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  // Caller supplied a predicate: use it as-is.
  if(mask)
    return builder->create_atomic_add(ptr, val, mask);
  // No predicate given: synthesize an all-true one matching the pointer's shape.
  ir::value *all_true = builder->get_int1(true);
  auto *ptr_ty = ptr->get_type();
  if(ptr_ty->is_block_ty())
    all_true = builder->create_splat(all_true, ptr_ty->get_block_shapes());
  return builder->create_atomic_add(ptr, val, all_true);
}

//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ void init_triton_frontend(py::module &&m) {
m.def("store", &ir::dispatch::store, ret::reference);
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
// linear algebra
m.def("dot", &ir::dispatch::dot, ret::reference);
// indexing
Expand Down
8 changes: 6 additions & 2 deletions python/triton/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,13 @@ def visit_Compare(self, node):
ast.Is: '__eq__',
ast.IsNot: '__ne__',
}[type(node.ops[0])]
if self.is_triton_object(lhs) or self.is_triton_object(rhs):
if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, builder=self.builder)
return getattr(lhs, fn)(rhs)
elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, builder=self.builder)
else:
return getattr(lhs, fn)(rhs)

def visit_UnaryOp(self, node):
op = self.visit(node.operand)
Expand Down
48 changes: 48 additions & 0 deletions python/triton/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def __radd__(self, other, builder=None):
def __sub__(self, other, builder=None):
return frontend.sub(self, other, builder)

@builtin
def __rsub__(self, other, builder=None):
    # Reflected subtraction: operands are swapped, so this evaluates `other - self`.
    # Decorated with @builtin like the neighbouring operator methods (e.g. __mul__),
    # since it takes the same `builder` keyword argument.
    return frontend.sub(other, self, builder)

@builtin
def __mul__(self, other, builder=None):
    # Delegates `self * other` to frontend.mul, forwarding the builder.
    product = frontend.mul(self, other, builder)
    return product
Expand Down Expand Up @@ -183,22 +186,42 @@ def __rshift__(self, other, builder=None):

# comparison operators

# >
@builtin
def __gt__(self, other, builder=None):
    # Delegates `self > other` to frontend.greater_than, forwarding the builder.
    is_greater = frontend.greater_than(self, other, builder)
    return is_greater

@builtin
def __rgt__(self, other, builder=None):
    # Reflected `>`: operands are swapped, so this evaluates `other > self`.
    is_greater = frontend.greater_than(other, self, builder)
    return is_greater

# >=
@builtin
def __ge__(self, other, builder=None):
    # Delegates `self >= other` to frontend.greater_equal, forwarding the builder.
    is_greater_or_equal = frontend.greater_equal(self, other, builder)
    return is_greater_or_equal

@builtin
def __rge__(self, other, builder=None):
    # Reflected `>=`: operands are swapped, so this evaluates `other >= self`.
    # Decorated with @builtin for consistency with the other reflected
    # comparisons (__rgt__, __rlt__, __rle__), which take the same `builder`
    # keyword argument.
    return frontend.greater_equal(other, self, builder)

# <
@builtin
def __lt__(self, other, builder=None):
    # Delegates `self < other` to frontend.less_than, forwarding the builder.
    is_less = frontend.less_than(self, other, builder)
    return is_less

@builtin
def __rlt__(self, other, builder=None):
    # Reflected `<`: operands are swapped, so this evaluates `other < self`.
    is_less = frontend.less_than(other, self, builder)
    return is_less

# <=
@builtin
def __le__(self, other, builder=None):
    # Delegates `self <= other` to frontend.less_equal, forwarding the builder.
    is_less_or_equal = frontend.less_equal(self, other, builder)
    return is_less_or_equal

@builtin
def __rle__(self, other, builder=None):
    # Reflected `<=`: operands are swapped, so this evaluates `other <= self`.
    is_less_or_equal = frontend.less_equal(other, self, builder)
    return is_less_or_equal

# ==
@builtin
def __eq__(self, other, builder=None):
    # Delegates `self == other` to frontend.equal, forwarding the builder.
    is_equal = frontend.equal(self, other, builder)
    return is_equal
Expand Down Expand Up @@ -421,6 +444,20 @@ def atomic_xchg(pointer, val, builder=None):
return frontend.atomic_xchg(pointer, val, builder)


@builtin
def atomic_add(pointer, val, mask=None, builder=None):
    """
    Performs an atomic add at the memory locations specified by :code:`pointer`.

    :param pointer: The memory locations which contain the old values
    :type pointer: Block of dtype=triton.PointerDType
    :param val: The values to add
    :type val: Block of dtype=`pointer.dtype.element_ty`
    :param mask: If mask[idx] is false, :code:`pointer[idx]` is unaffected.
    :type mask: Block of triton.int1, optional
    """
    # Thin wrapper: all arguments are forwarded unchanged to the frontend.
    result = frontend.atomic_add(pointer, val, mask, builder)
    return result


# -----------------------
# Conditioning
# -----------------------
Expand Down Expand Up @@ -475,6 +512,17 @@ def log(x, builder=None):
return frontend.log(x, builder)


@builtin
def sqrt(x, builder=None):
    """
    Computes the element-wise square root of :code:`x`

    :param x: the input values
    :type x: Block
    """
    # Thin wrapper: forwards the input and builder to the frontend.
    result = frontend.sqrt(x, builder)
    return result


# -----------------------
# Reductions
# -----------------------
Expand Down

0 comments on commit 2b75158

Please sign in to comment.