
【Inference】support rl #10567


Open · wants to merge 16 commits into base: develop
18 changes: 8 additions & 10 deletions csrc/gpu/cpp_extensions.cu
@@ -236,13 +236,11 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
const paddle::optional<paddle::Tensor>& draft_tokens,
const paddle::optional<paddle::Tensor>& seq_lens_encoder);

void SaveOutMmsg(const paddle::Tensor& x,
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop, // cpu
const paddle::Tensor& msg_queue_id, // cpu
int64_t rank_id);

void GetOutput(const paddle::Tensor& x,
const paddle::Tensor& msg_queue_id, // cpu
void GetOutputStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag);

@@ -301,8 +299,8 @@ PYBIND11_MODULE(paddlenlp_ops, m) {
m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
m.def("f_get_output", &GetOutput, "GetOutput");
m.def("f_save_output", &SaveOutMmsgStatic, "SaveOutMmsgStatic");
m.def("f_get_output", &GetOutputStatic, "GetOutputStatic");
m.def("f_step_paddle", &StepPaddle, "StepPaddle");
m.def("f_save_output_dygraph", &SaveOutputDygraph, "SaveOutputDygraph");
// m.def("f_cutlass_fp8_fp8_half_block_gemm_fused", &cutlass_fp8_fp8_half_block_gemm_fused_func, "cutlass_fp8_fp8_half_block_gemm_fused_func");
@@ -331,8 +329,8 @@ PYBIND11_MODULE(paddlenlp_ops_80, m) {
m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
m.def("f_get_output", &GetOutput, "GetOutput");
m.def("f_save_output", &SaveOutMmsgStatic, "SaveOutMmsgStatic");
m.def("f_get_output", &GetOutputStatic, "GetOutputStatic");
m.def("f_step_paddle", &StepPaddle, "StepPaddle");
m.def("f_save_output_dygraph", &SaveOutputDygraph, "SaveOutputDygraph");
}
@@ -360,8 +358,8 @@ PYBIND11_MODULE(paddlenlp_ops_90, m) {
m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
m.def("f_get_output", &GetOutput, "GetOutput");
m.def("f_save_output", &SaveOutMmsgStatic, "SaveOutMmsgStatic");
m.def("f_get_output", &GetOutputStatic, "GetOutputStatic");
m.def("f_step_paddle", &StepPaddle, "StepPaddle");
m.def("f_save_output_dygraph", &SaveOutputDygraph, "SaveOutputDygraph");
}
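
The binding changes above split each message-queue op into a static variant (queue id fixed to 1, keeping the old Python-facing names `f_save_output`/`f_get_output`) and a dynamic variant that takes the queue id explicitly. A minimal sketch of that wrapper pattern, with illustrative names (only the default queue id of 1 comes from this diff; the real entry points are the `PD_BUILD_OP` registrations further down):

```cpp
#include <cstdint>
#include <iostream>

// shared implementation: the queue id is now an explicit parameter
void get_output_impl(int64_t rank_id, bool wait_flag, int msg_queue_id) {
  std::cout << "reading from message queue " << msg_queue_id << std::endl;
}

// static variant: preserves the old single-queue behaviour behind the old name
void get_output_static(int64_t rank_id, bool wait_flag) {
  get_output_impl(rank_id, wait_flag, /*msg_queue_id=*/1);
}

// dynamic variant: lets each caller (e.g. an RL rollout worker) pick a queue
void get_output_dynamic(int64_t rank_id, bool wait_flag, int msg_queue_id) {
  get_output_impl(rank_id, wait_flag, msg_queue_id);
}

int main() {
  get_output_static(0, true);      // old behaviour: queue id fixed to 1
  get_output_dynamic(0, true, 7);  // new: caller-chosen queue id
}
```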
100 changes: 68 additions & 32 deletions csrc/gpu/get_output.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,52 +19,88 @@
#include <sys/types.h>
#include "paddle/extension.h"

#define MAX_BSZ 512

#define MAX_BSZ 256
// #define GET_OUTPUT_DEBUG
struct msgdata {
long mtype;
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
};

void GetOutput(const paddle::Tensor& x,
const paddle::Tensor& msg_queue_id,
int64_t rank_id,
bool wait_flag) {
if (rank_id > 0) return;
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
if (rank_id > 0) {
return;
}
static struct msgdata msg_rcv;
if (const char* inference_msg_queue_id_env_p = std::getenv("INFERENCE_MSG_QUEUE_ID")){
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env = std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " << inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);

#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output_key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
#endif

static struct msgdata msg_rcv;
int queue_id_val = msg_queue_id.data<int>()[0];
static key_t key = ftok("./", queue_id_val);
int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];

static int msgid = msgget(key, IPC_CREAT | 0666);
for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finished: " << msgid << std::endl;
#endif

int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
if(ret == -1)
{
// read none
out_data[0] = -2;
out_data[1] = 0;
return;
}
return;
}

int bsz = msg_rcv.mtext[1];
void GetOutputStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
GetOutput(x, rank_id, wait_flag, 1);
}

for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
return;
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {

GetOutput(x, rank_id, wait_flag, msg_queue_id);
}

PD_BUILD_OP(get_output)
.Inputs({"x", "msg_queue_id"})
.Inputs({"x"})
.Attrs({"rank_id: int64_t",
"wait_flag: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(GetOutput));
.SetKernelFn(PD_KERNEL(GetOutputStatic));

PD_BUILD_OP(get_output_dynamic)
.Inputs({"x"})
.Attrs({"rank_id: int64_t",
"wait_flag: bool",
"msg_queue_id: int"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(GetOutputDynamic));
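
For reference, here is a self-contained sketch of the receive side of this protocol as the new code reads: the System V key is derived via `ftok("/dev/shm", queue_id)`, `INFERENCE_MSG_QUEUE_ID` overrides the queue id, and each message carries `stop_flag`, `bsz`, then `bsz` tokens. This is an illustration under those assumptions (plain ints, blocking receive, `sizeof(int) == 4`), not the Paddle op itself:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>
#include <sys/ipc.h>
#include <sys/msg.h>

#define MAX_BSZ 256

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

int main() {
  int msg_queue_id = 1;  // default used by the static op
  if (const char* env = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
    msg_queue_id = std::stoi(env);  // env override, as in GetOutput above
  }
  key_t key = ftok("/dev/shm", msg_queue_id);
  int msgid = msgget(key, IPC_CREAT | 0666);

  msgdata msg_rcv;
  // wait_flag=true path: block until a message arrives
  // (the wait_flag=false path passes IPC_NOWAIT and treats -1 as "no data")
  if (msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0) == -1) {
    std::cerr << "msgrcv failed" << std::endl;
    return 1;
  }
  int bsz = msg_rcv.mtext[1];
  std::cout << "stop_flag=" << msg_rcv.mtext[0] << " bsz=" << bsz << "\n";
  for (int i = 0; i < bsz; i++) {
    std::cout << "token[" << i << "] = " << msg_rcv.mtext[i + 2] << "\n";
  }
  return 0;
}
```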
94 changes: 80 additions & 14 deletions csrc/gpu/save_with_output_msg.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,43 +19,109 @@
#include <sys/types.h>
#include "paddle/extension.h"

#define MAX_BSZ 512
#define MAX_BSZ 256

// #define SAVE_WITH_OUTPUT_DEBUG
struct msgdata {
long mtype;
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
};

void SaveOutMmsg(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop, // cpu
const paddle::Tensor& msg_queue_id, // cpu
int64_t rank_id) {
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id) {

if (rank_id > 0) return;
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t *x_data = x_cpu.data<int64_t>();
auto not_need_stop_data = not_need_stop.data<bool>()[0];

static struct msgdata msg_sed;
int queue_id_val = msg_queue_id.data<int>()[0];
static key_t key = ftok("./", queue_id_val);
static int msgid = msgget(key, IPC_CREAT | 0666);

if (const char* inference_msg_queue_id_env_p = std::getenv("INFERENCE_MSG_QUEUE_ID")){
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env = std::stoi(inference_msg_queue_id_env_str);
msg_queue_id = inference_msg_queue_id_from_env;
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " << inference_msg_queue_id_from_env << std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default." << std::endl;
#endif
}
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")){
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2){
// 2 and -2 are reserved for the no-output indication.
throw std::runtime_error("INFERENCE_MSG_ID cannot be 2, please use another number.");
}
if (inference_msg_id_from_env < 0) {
throw std::runtime_error(" INFERENCE_MSG_ID cannot be negative, please use other number.");
}

#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env << std::endl;
#endif
} else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default." << std::endl;
#endif
}
static key_t key = ftok("/dev/shm", msg_queue_id);

static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output_key: " << key << std::endl;
std::cout << "save msgid: " << msgid << std::endl;
#endif
msg_sed.mtype = 1;
msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
bool not_need_stop_data = not_need_stop.data<bool>()[0];
// printf("not_need_stop_data %d\n", (int)not_need_stop_data);
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env : -inference_msg_id_from_env;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
for (int i = 2; i < bsz + 2; i++) {
msg_sed.mtext[i] = (int)x_data[i - 2];
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
for (int i = 0; i < bsz; i++) {
std::cout << " " << (int)x_data[i];
}
std::cout << std::endl;
#endif
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
// printf("full msg buffer\n");
printf("full msg buffer\n");
}
return;
}


void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id) {
SaveOutMmsg(x, not_need_stop, rank_id, 1);
}

void SaveOutMmsgDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id) {
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id);
}

PD_BUILD_OP(save_output)
.Inputs({"x", "not_need_stop", "msg_queue_id"})
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SaveOutMmsg));
.SetKernelFn(PD_KERNEL(SaveOutMmsgStatic));

PD_BUILD_OP(save_output_dynamic)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SaveOutMmsgDynamic));
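
And the matching send side, again as a standalone sketch rather than the op itself: `mtext[0]` carries ±`INFERENCE_MSG_ID` as the continue/stop signal (which is why the value 2 is rejected: -2 already marks "no data" on the receive side), `mtext[1]` the batch size, and the tokens follow. Token values here are stand-ins:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include <sys/ipc.h>
#include <sys/msg.h>

#define MAX_BSZ 256

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

int main() {
  std::vector<int> tokens = {11, 42, 7};  // stand-in for the sampled next_tokens
  bool not_need_stop = true;              // generation should continue

  int msg_queue_id = 1;
  if (const char* qid = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
    msg_queue_id = std::stoi(qid);
  }
  int msg_id = 1;  // INFERENCE_MSG_ID; the op rejects 2 and negative values
  if (const char* mid = std::getenv("INFERENCE_MSG_ID")) {
    msg_id = std::stoi(mid);
  }

  key_t key = ftok("/dev/shm", msg_queue_id);
  int msgid = msgget(key, IPC_CREAT | 0666);

  msgdata msg_sed;
  msg_sed.mtype = 1;
  msg_sed.mtext[0] = not_need_stop ? msg_id : -msg_id;  // signed stop signal
  msg_sed.mtext[1] = static_cast<int>(tokens.size());
  for (size_t i = 0; i < tokens.size(); ++i) {
    msg_sed.mtext[i + 2] = tokens[i];
  }
  if (msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0) == -1) {
    std::cerr << "full msg buffer" << std::endl;  // queue full / send failed
  }
  return 0;
}
```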
1 change: 0 additions & 1 deletion paddlenlp/experimental/transformers/generation_utils.py
@@ -774,7 +774,6 @@ def _post_process_(
save_output(
next_tokens,
model_kwargs["not_need_stop"],
model_kwargs["msg_queue_id"],
self.config.tensor_parallel_rank,
)
return next_tokens