paddle/fluid/operators/tril_triu_op.cc (6 additions, 2 deletions)

@@ -99,6 +99,7 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
                   ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
                   ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
@@ -107,10 +108,13 @@ REGISTER_OP_CPU_KERNEL(
     tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>);
 REGISTER_OP_CPU_KERNEL(
     tril_triu_grad,
     ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext,
+                              plat::float16>);

paddle/fluid/operators/tril_triu_op.cu (6 additions, 2 deletions)

@@ -15,16 +15,20 @@ limitations under the License. */
#include "paddle/fluid/operators/tril_triu_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
tril_triu,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
tril_triu_grad,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
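
With both the CPU and CUDA registrations above in place, float16 inputs dispatch to the new kernel instantiations on either device. A minimal dygraph sketch of what this enables (assuming a Paddle build that contains this patch; the calls mirror the updated tests further down):

import numpy as np
import paddle.fluid as fluid
import paddle.tensor as tensor

# float16 now resolves to the newly registered tril_triu kernels
# on CPUPlace as well as CUDAPlace.
data = np.random.random([3, 3]).astype('float16')
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(data)
    out = tensor.tril(x).numpy()  # lower triangle, dtype preserved
    assert out.dtype == np.float16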

paddle/fluid/operators/tril_triu_op.h (1 addition)

@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/for_range.h"

 namespace paddle {

python/paddle/fluid/tests/unittests/test_tril_triu_op.py (54 additions, 29 deletions)

@@ -16,8 +16,10 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle
 import paddle.fluid as fluid
 import paddle.tensor as tensor
+from paddle.fluid.framework import Program, program_guard


 class TrilTriuOpDefaultTest(OpTest):
@@ -68,13 +70,17 @@ def case_generator(op_type, Xshape, diagonal, expected):

     class FailureCase(unittest.TestCase):
         def test_failure(self):
+            paddle.enable_static()
+
             data = fluid.data(shape=Xshape, dtype='float64', name=cls_name)
             with self.assertRaisesRegexp(
                     eval(expected.split(':')[-1]), errmsg[expected]):
                 getattr(tensor, op_type)(x=data, diagonal=diagonal)

     class SuccessCase(TrilTriuOpDefaultTest):
         def initTestCase(self):
+            paddle.enable_static()
+
             self.real_op_type = op_type
             self.diagonal = diagonal
             self.X = np.random.random(Xshape).astype("float64")
@@ -120,39 +126,58 @@ class TestTrilTriuOpAPI(unittest.TestCase):
"""

def test_api(self):
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x')
tril_out, triu_out = tensor.tril(x), tensor.triu(x)

place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
tril_out, triu_out = exe.run(
fluid.default_main_program(),
feed={"x": data},
fetch_list=[tril_out, triu_out], )
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))
paddle.enable_static()

dtypes = ['float16', 'float32']
for dtype in dtypes:
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
data = np.random.random([1, 9, 9, 4]).astype(dtype)
x = fluid.data(shape=[1, 9, -1, 4], dtype=dtype, name='x')
tril_out, triu_out = tensor.tril(x), tensor.triu(x)

place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
tril_out, triu_out = exe.run(
fluid.default_main_program(),
feed={"x": data},
fetch_list=[tril_out, triu_out], )
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))

def test_api_with_dygraph(self):
with fluid.dygraph.guard():
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.dygraph.to_variable(data)
tril_out, triu_out = tensor.tril(x).numpy(), tensor.triu(x).numpy()
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))
paddle.disable_static()

dtypes = ['float16', 'float32']
for dtype in dtypes:
with fluid.dygraph.guard():
data = np.random.random([1, 9, 9, 4]).astype(dtype)
x = fluid.dygraph.to_variable(data)
tril_out, triu_out = tensor.tril(x).numpy(), tensor.triu(
x).numpy()
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))

def test_fluid_api(self):
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x')
triu_out = fluid.layers.triu(x)

place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
triu_out = exe.run(fluid.default_main_program(),
feed={"x": data},
fetch_list=[triu_out])
paddle.enable_static()

dtypes = ['float16', 'float32']
for dtype in dtypes:
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
data = np.random.random([1, 9, 9, 4]).astype(dtype)
x = fluid.data(shape=[1, 9, -1, 4], dtype=dtype, name='x')
triu_out = fluid.layers.triu(x)

place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
triu_out = exe.run(fluid.default_main_program(),
feed={"x": data},
fetch_list=[triu_out])


if __name__ == '__main__':
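
The tests validate against NumPy's reference implementations, so np.tril/np.triu define the expected semantics; Paddle's diagonal argument corresponds to NumPy's k offset. For reference (pure NumPy, independent of this patch):

import numpy as np

a = np.arange(1, 10, dtype=np.float16).reshape(3, 3)
np.tril(a)        # [[1,0,0],[4,5,0],[7,8,9]] -- keep the diagonal and below
np.tril(a, k=-1)  # [[0,0,0],[4,0,0],[7,8,0]] -- strictly below the diagonal
np.triu(a, k=1)   # [[0,2,3],[0,0,6],[0,0,0]] -- strictly above the diagonal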

python/paddle/tensor/creation.py (2 additions, 2 deletions)

@@ -555,8 +555,8 @@ def _tril_triu_op(helper):
     x = helper.kwargs.get('x', None)

     assert x is not None, 'x cannot be None in {}'.format(op_type)
-    check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
-                             op_type)
+    check_variable_and_dtype(
+        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type)
     if len(x.shape) < 2:
         raise ValueError("x shape in {} must be at least 2-D".format(op_type))
     diagonal = helper.kwargs.get('diagonal', 0)
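
The widened whitelist is why float16 variables now pass validation in _tril_triu_op; any dtype outside the list is still rejected by check_variable_and_dtype at graph-construction time, before any kernel runs. A sketch of that failure path (the 'int8' input is illustrative, not taken from the patch):

import paddle
import paddle.fluid as fluid
import paddle.tensor as tensor

paddle.enable_static()
# 'int8' is absent from the whitelist, so graph construction
# raises TypeError instead of launching a kernel.
x = fluid.data(shape=[3, 3], dtype='int8', name='x_int8')
try:
    tensor.tril(x)
except TypeError as err:
    print('rejected as expected:', err)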