[XPU] support ignore_index in c_softmax_with_cross_entropy #65149

Merged: 4 commits, Jul 11, 2024
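
In short, the change makes the XPU kernels of c_softmax_with_cross_entropy honor ignore_index under model parallelism: rows whose label equals ignore_index contribute zero loss in the forward pass and a zero logits gradient in the backward pass. In my own notation (not taken from the PR), the intended per-row behavior is

\mathrm{loss}_i =
\begin{cases}
-\log\big(\operatorname{softmax}(x_i)_{y_i}\big), & y_i \neq \text{ignore\_index} \\
0, & y_i = \text{ignore\_index}
\end{cases}
\qquad
\frac{\partial \mathrm{loss}_i}{\partial x_i} = 0 \ \text{when } y_i = \text{ignore\_index}.
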
@@ -41,10 +41,6 @@ template <typename T, typename DeviceContext>
class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
if (ignore_index >= 0) {
LOG_FIRST_N(INFO, 1) << "XPU does not support ignore_index in mp.";
}
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
@@ -57,6 +53,80 @@ class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
}
};

template <typename T>
void FixLossAccordingToIgnoreIndex(const framework::ExecutionContext& ctx,
const phi::DenseTensor* labels,
const phi::DenseTensor* predicted_logits,
phi::DenseTensor* loss,
const int64_t N,
const int64_t ignore_index) {
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
using XPUType = typename XPUTypeTrait<T>::Type;
// First prepare an all-zero tensor
phi::DenseTensor zeros_constant =
ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
int ret = xpu::constant<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<XPUType*>(zeros_constant.data<T>()),
N,
0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");

// Prepare a bool tensor marking whether each loss entry should be zeroed
phi::DenseTensor bool_tensor_for_mask_label =
ctx.AllocateTmpTensor<bool, phi::XPUContext>({N, 1}, dev_ctx);
// Prepare a tensor with the same dtype as the labels, with every element set to ignore_index
phi::DenseTensor ignore_label_as_tensor;

const auto& label_type = framework::TransToProtoVarType(labels->dtype());
if (label_type == framework::proto::VarType::INT32) {
ignore_label_as_tensor =
ctx.AllocateTmpTensor<int, phi::XPUContext>({N, 1}, dev_ctx);
ret = xpu::constant<int>(dev_ctx.x_context(),
ignore_label_as_tensor.data<int>(),
N,
ignore_index);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
// If a label equals ignore_index, set the corresponding bool entry to 1, meaning that loss will be zeroed later
// int equal(Context* ctx, const T* x, const T* y, bool* z, int64_t len);
ret = xpu::equal<int>(dev_ctx.x_context(),
ignore_label_as_tensor.data<int>(),
labels->data<int>(),
bool_tensor_for_mask_label.data<bool>(),
N);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "equal");
} else if (label_type == framework::proto::VarType::INT64) {
ignore_label_as_tensor =
ctx.AllocateTmpTensor<int64_t, phi::XPUContext>({N, 1}, dev_ctx);
ret = xpu::constant<int64_t>(dev_ctx.x_context(),
ignore_label_as_tensor.data<int64_t>(),
N,
ignore_index);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
// If a label equals ignore_index, set the corresponding bool entry to 1, meaning that loss will be zeroed later
// int equal(Context* ctx, const T* x, const T* y, bool* z, int64_t len);
ret = xpu::equal<int64_t>(dev_ctx.x_context(),
ignore_label_as_tensor.data<int64_t>(),
labels->data<int64_t>(),
bool_tensor_for_mask_label.data<bool>(),
N);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "equal");
}
// Entries whose bool value is 1 matched ignore_index and get zeroed; entries with 0 keep their loss
// int select(Context* ctx, const bool* condition, const T* x, const T* y,
// T* z, const std::vector<int64_t>& condition_shape, const
// std::vector<int64_t>& xshape);
ret = xpu::select(
dev_ctx.x_context(),
reinterpret_cast<const bool*>(bool_tensor_for_mask_label.data<bool>()),
reinterpret_cast<const XPUType*>(zeros_constant.data<T>()),
reinterpret_cast<const XPUType*>(loss->data<T>()),
reinterpret_cast<XPUType*>(loss->data<T>()),
common::vectorize(predicted_logits->dims()),
common::vectorize(predicted_logits->dims()));
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "select");
}
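
For readers skimming the diff, here is a minimal NumPy sketch of what FixLossAccordingToIgnoreIndex computes above; it is my own reconstruction for illustration, not code from this PR, and the helper name fix_loss_ref is hypothetical.

import numpy as np

def fix_loss_ref(loss, labels, ignore_index):
    # loss: float array of shape [N, 1]; labels: int array of shape [N, 1]
    mask = labels == ignore_index       # mirrors the xpu::equal bool tensor
    zeros = np.zeros_like(loss)         # mirrors the xpu::constant zeros tensor
    return np.where(mask, zeros, loss)  # mirrors xpu::select

# Tiny usage example: rows whose label equals ignore_index get zero loss.
loss = np.array([[1.5], [0.7], [2.1]], dtype=np.float32)
labels = np.array([[3], [7], [3]], dtype=np.int64)
print(fix_loss_ref(loss, labels, ignore_index=3))  # [[0.], [0.7], [0.]]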

template <typename T>
struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
@@ -65,7 +135,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
const phi::DenseTensor* labels = ctx.Input<phi::DenseTensor>("Label");
phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");

const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");
const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");
@@ -165,7 +235,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
end_index,
N,
D,
nranks);
nranks,
ignore_index);
} else if (label_type == framework::proto::VarType::INT64) {
ret = xpu::mask_label_by_index<XPUType, int64_t>(
dev_ctx.x_context(),
@@ -176,7 +247,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
end_index,
N,
D,
nranks);
nranks,
ignore_index);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "mask_label_by_index");

@@ -249,6 +321,10 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
N * 1);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sub");

// Zero out the loss entries whose label equals ignore_index
FixLossAccordingToIgnoreIndex<T>(
ctx, labels, &predicted_logits, loss, N, ignore_index);

phi::memory_utils::Copy(ctx.GetPlace(),
softmax->data(),
ctx.GetPlace(),
@@ -265,6 +341,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
const phi::DenseTensor* labels = ctx.Input<phi::DenseTensor>("Label");
phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");
const int64_t ignore_index = ctx.Attr<int64_t>("ignore_index");

const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
@@ -407,7 +484,8 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
end_index,
N,
D,
nranks);
nranks,
ignore_index);
} else if (label_type == framework::proto::VarType::INT64) {
ret = xpu::mask_label_by_index<XPUType, int64_t>(
dev_ctx.x_context(),
@@ -418,7 +496,8 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
end_index,
N,
D,
nranks);
nranks,
ignore_index);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "mask_label_by_index");

Expand Down Expand Up @@ -470,7 +549,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
false,
&sum_exp_logits,
f);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_max");
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
}

if (comm_ctx) {
@@ -514,6 +593,10 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
N * 1);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sub");

// Zero out the loss entries whose label equals ignore_index
FixLossAccordingToIgnoreIndex<T>(
ctx, labels, &predicted_logits, loss, N, ignore_index);

phi::memory_utils::Copy(ctx.GetPlace(),
softmax->data(),
ctx.GetPlace(),
@@ -535,9 +618,6 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
const phi::DenseTensor* softmax =
context.Input<phi::DenseTensor>("Softmax");
const int64_t ignore_index = context.Attr<int64_t>("ignore_index");
if (ignore_index >= 0) {
LOG_FIRST_N(INFO, 1) << "XPU does not support ignore_index in mp.";
}
const int rank = context.Attr<int>("rank");
auto& dev_ctx = context.template device_context<DeviceContext>();

@@ -564,7 +644,8 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
start_index,
end_index,
N,
D);
D,
ignore_index);
} else if (label_type == framework::proto::VarType::INT64) {
ret = xpu::mask_label_by_index_grad<XPUType, int64_t>(
dev_ctx.x_context(),
Expand All @@ -574,7 +655,8 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
start_index,
end_index,
N,
D);
D,
ignore_index);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "mask_label_by_index_grad");
}
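
Both forward functors and the gradient kernel above now forward the new ignore_index argument to xpu::mask_label_by_index / xpu::mask_label_by_index_grad. Below is a hedged NumPy sketch of the forward semantics as I read them from this diff; the xdnn kernel itself is authoritative, the shard-selection detail is an assumption, and mask_label_by_index_ref is a hypothetical name.

import numpy as np

def mask_label_by_index_ref(logits_2d, labels, start_index, end_index, ignore_index):
    # logits_2d: [N, D] local shard of the logits on this rank
    # labels:    [N] global class ids
    # Returns the [N, 1] predicted_logits before the cross-rank allreduce.
    out = np.zeros((logits_2d.shape[0], 1), dtype=logits_2d.dtype)
    for i, y in enumerate(labels):
        if y == ignore_index:
            continue  # ignored rows contribute nothing
        if start_index <= y < end_index:
            out[i, 0] = logits_2d[i, y - start_index]
    return out
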
25 changes: 15 additions & 10 deletions test/xpu/collective_softmax_with_cross_entropy_op_xpu.py
@@ -13,7 +13,6 @@
# limitations under the License.

import os
import pickle
import sys

import numpy as np
@@ -23,6 +22,7 @@
from test_collective_base_xpu import (
DataTypeCast,
TestCollectiveRunnerBase,
dump_output,
runtime_main,
)

@@ -46,7 +46,7 @@ def __init__(self):
self.logits_shape = [self.seq_len, self.local_elements]
self.label_shape = [self.seq_len, 1]

def get_model(self, main_prog, startup_program, rank):
def get_model(self, main_prog, startup_program, rank, ignore_index):
with program_guard(main_prog, startup_program):
logits = data(
name="Logits",
@@ -86,6 +86,7 @@ def get_model(self, main_prog, startup_program, rank):
'ring_id': self.ring_id,
'rank': rank,
'nranks': self.nranks,
'ignore_index': ignore_index,
},
)
# generate backward op_desc
@@ -129,13 +130,6 @@ def run_trainer(self, args):
]
self.label_shape = [self.batch_size, self.seq_len, 1]

np_dtype = DataTypeCast(args["dtype"])
loss, softmax = self.get_model(train_prog, startup_prog, rank)
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = paddle.XPUPlace(device_id)
exe = Executor(place)
exe.run(startup_prog)

# NOTE use uid here to assure that two xpus share the same label
np.random.seed(os.getuid())
label = np.random.randint(
@@ -144,6 +138,17 @@
size=self.label_shape,
dtype='int32',
)
ignore_index = label[0][0]

np_dtype = DataTypeCast(args["dtype"])
loss, softmax = self.get_model(
train_prog, startup_prog, rank, ignore_index
)
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = paddle.XPUPlace(device_id)
exe = Executor(place)
exe.run(startup_prog)

# use FAKE loss_grad here, only to examine the correctness of grad func
loss_grad_fp32 = np.random.uniform(
low=-10.0, high=10.0, size=self.label_shape
@@ -167,7 +172,7 @@ def run_trainer(self, args):
feed={'Logits': logits, 'Label': label, 'Loss@GRAD': loss_grad},
fetch_list=[loss.name, softmax.name, 'Logits@GRAD'],
)
sys.stdout.buffer.write(pickle.dumps(out))
dump_output(out)


if __name__ == "__main__":
32 changes: 22 additions & 10 deletions test/xpu/test_collective_softmax_with_cross_entropy_xpu.py
@@ -58,7 +58,9 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
return result.reshape(label.shape)


def softmax_with_cross_entropy_grad(softmax, label, loss_grad, axis):
def softmax_with_cross_entropy_grad(
softmax, label, loss_grad, axis, ignore_index=-1
):
shape = softmax.shape
axis %= len(shape)
n = int(np.prod(shape[:axis]))
@@ -69,14 +71,17 @@ def softmax_with_cross_entropy_grad(softmax, label, loss_grad, axis):
for i in range(n * d):
row = int(i / d)
col = i % d
if col == label_2d[row]:
logit_grad_2d[row][col] = (
logit_grad_2d[row][col] - 1.0
) * loss_grad_2d[row]
if label_2d[row] == ignore_index:
logit_grad_2d[row][col] = 0
else:
logit_grad_2d[row][col] = (
logit_grad_2d[row][col] * loss_grad_2d[row]
)
if col == label_2d[row]:
logit_grad_2d[row][col] = (
logit_grad_2d[row][col] - 1.0
) * loss_grad_2d[row]
else:
logit_grad_2d[row][col] = (
logit_grad_2d[row][col] * loss_grad_2d[row]
)
logit_grad = logit_grad_2d.reshape(softmax.shape)
return logit_grad
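
A quick sanity check of the reference above (my own example, not part of the test file): a row whose label equals ignore_index should come back with an all-zero gradient.

import numpy as np

softmax = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], dtype=np.float32)
label = np.array([[0], [1]], dtype=np.int32)
loss_grad = np.ones((2, 1), dtype=np.float32)
grad = softmax_with_cross_entropy_grad(
    softmax, label, loss_grad, axis=-1, ignore_index=0
)
# Row 0 is ignored (label == ignore_index), so grad[0] is all zeros;
# row 1 gets the usual (softmax - one_hot) * loss_grad gradient.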

@@ -144,6 +149,7 @@ def check_with_place(
label = np.random.randint(
0, self.num_class, size=label_shape, dtype='int32'
)
ignore_index = label[0][0]
loss_grad = np.random.uniform(
low=-10.0, high=10.0, size=label_shape
).astype(np_dtype)
@@ -165,9 +171,15 @@

# calculate analytic result
need_softmax = np.apply_along_axis(stable_softmax, -1, inputs)
need_loss = cross_entropy(need_softmax, label, False, -1)
need_loss = cross_entropy(
need_softmax, label, False, -1, ignore_index
)
need_logits_grad = softmax_with_cross_entropy_grad(
need_softmax, label, loss_grad, axis=-1
need_softmax,
label,
loss_grad,
axis=-1,
ignore_index=ignore_index,
)

# get real result