Skip to content

Commit 110370d

Browse files
authored
Merge pull request #13 from zz990099/develop_texture_free_feature
Make `FoundationPose` support texture-less rendering
2 parents 90f274a + 6a27646 commit 110370d

File tree

7 files changed

+212
-12
lines changed

7 files changed

+212
-12
lines changed

detection_6d_foundationpose/src/foundationpose_render.cpp

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ FoundationPoseRenderer::PrepareBuffer()
270270
// nvdiffrast render 用到的缓存以及渲染器
271271
size_t pose_clip_size = num_vertices_ * (kVertexPoints + 1) * input_poses_num_ * sizeof(float);
272272
size_t pts_cam_size = num_vertices_ * kVertexPoints * input_poses_num_ * sizeof(float);
273+
size_t diffuse_intensity_size = num_vertices_ * input_poses_num_ * sizeof(float);
274+
size_t diffuse_intensity_map_size = input_poses_num_ * crop_window_H_ * crop_window_W_ * sizeof(float);
273275
size_t rast_out_size = input_poses_num_ * crop_window_H_ * crop_window_W_ * (kVertexPoints + 1) * sizeof(float);
274276
size_t color_size = input_poses_num_ * crop_window_H_ * crop_window_W_ * kNumChannels * sizeof(float);
275277
size_t xyz_map_size = input_poses_num_ * crop_window_H_ * crop_window_W_ * kNumChannels * sizeof(float);
@@ -279,6 +281,8 @@ FoundationPoseRenderer::PrepareBuffer()
279281
float* _pose_clip_device;
280282
float* _rast_out_device;
281283
float* _pts_cam_device;
284+
float* _diffuse_intensity_device;
285+
float* _diffuse_intensity_map_device;
282286
float* _texcoords_out_device;
283287
float* _color_device;
284288
float* _xyz_map_device;
@@ -307,6 +311,14 @@ FoundationPoseRenderer::PrepareBuffer()
307311
"[FoundationPoseRenderer] cudaMalloc `_pts_cam_device` FAILED!!!");
308312
pts_cam_device_ = DeviceBufferUniquePtrType<float>(_pts_cam_device, CudaMemoryDeleter<float>());
309313

314+
CHECK_CUDA(cudaMalloc(&_diffuse_intensity_device, diffuse_intensity_size),
315+
"[FoundationPoseRenderer] cudaMalloc `_diffuse_intensity_device` FAILED!!!");
316+
diffuse_intensity_device_ = DeviceBufferUniquePtrType<float>(_diffuse_intensity_device, CudaMemoryDeleter<float>());
317+
318+
CHECK_CUDA(cudaMalloc(&_diffuse_intensity_map_device, diffuse_intensity_map_size),
319+
"[FoundationPoseRenderer] cudaMalloc `_diffuse_intensity_map_device` FAILED!!!");
320+
diffuse_intensity_map_device_ = DeviceBufferUniquePtrType<float>(_diffuse_intensity_map_device, CudaMemoryDeleter<float>());
321+
310322
CHECK_CUDA(cudaMalloc(&_texcoords_out_device, texcoords_out_size),
311323
"[FoundationPoseRenderer] cudaMalloc `_texcoords_out_device` FAILED!!!");
312324
texcoords_out_device_ = DeviceBufferUniquePtrType<float>(_texcoords_out_device, CudaMemoryDeleter<float>());
@@ -361,18 +373,24 @@ FoundationPoseRenderer::LoadTexturedMesh()
361373
{
362374
const auto& mesh_model_center = mesh_loader_->GetMeshModelCenter();
363375
const auto& mesh_vertices = mesh_loader_->GetMeshVertices();
376+
const auto& mesh_vertex_normals = mesh_loader_->GetMeshVertexNormals();
364377
const auto& mesh_texcoords = mesh_loader_->GetMeshTextureCoords();
365378
const auto& mesh_faces = mesh_loader_->GetMeshFaces();
366379
const auto& rgb_texture_map = mesh_loader_->GetTextureMap();
367380
mesh_diameter_ = mesh_loader_->GetMeshDiameter();
368381

382+
std::vector<float> vertex_normals;
369383

370384
// Walk through each of the mesh's vertices
371385
for (unsigned int v = 0; v < mesh_vertices.size(); v++) {
372386
vertices_.push_back(mesh_vertices[v].x - mesh_model_center[0]);
373387
vertices_.push_back(mesh_vertices[v].y - mesh_model_center[1]);
374388
vertices_.push_back(mesh_vertices[v].z - mesh_model_center[2]);
375389

390+
vertex_normals.push_back(mesh_vertex_normals[v].x);
391+
vertex_normals.push_back(mesh_vertex_normals[v].y);
392+
vertex_normals.push_back(mesh_vertex_normals[v].z);
393+
376394
// Check if the mesh has texture coordinates
377395
if (mesh_texcoords.size() >= 1) {
378396
texcoords_.push_back(mesh_texcoords[0][v].x);
@@ -422,6 +440,7 @@ FoundationPoseRenderer::LoadTexturedMesh()
422440
size_t texcoords_size = texcoords_.size() * sizeof(float);
423441

424442
float* _vertices_device;
443+
float* _vertex_normals_device;
425444
float* _texcoords_device;
426445
int32_t* _mesh_faces_device;
427446
uint8_t* _texture_map_device;
@@ -430,6 +449,10 @@ FoundationPoseRenderer::LoadTexturedMesh()
430449
"[FoundationposeRender] cudaMalloc `mesh_faces_device` FAILED!!!");
431450
vertices_device_ = DeviceBufferUniquePtrType<float>(_vertices_device, CudaMemoryDeleter<float>());
432451

452+
CHECK_CUDA(cudaMalloc(&_vertex_normals_device, vertices_size),
453+
"[FoundationposeRender] cudaMalloc `vertex_normals_device` FAILED!!!");
454+
vertex_normals_device_ = DeviceBufferUniquePtrType<float>(_vertex_normals_device, CudaMemoryDeleter<float>());
455+
433456
CHECK_CUDA(cudaMalloc(&_mesh_faces_device, faces_size),
434457
"[FoundationposeRender] cudaMalloc `mesh_faces_device` FAILED!!!");
435458
mesh_faces_device_ = DeviceBufferUniquePtrType<int32_t>(_mesh_faces_device, CudaMemoryDeleter<int32_t>());
@@ -442,9 +465,14 @@ FoundationPoseRenderer::LoadTexturedMesh()
442465
"[FoundationposeRender] cudaMalloc `texture_map_device_` FAILED!!!");
443466
texture_map_device_ = DeviceBufferUniquePtrType<uint8_t>(_texture_map_device, CudaMemoryDeleter<uint8_t>());
444467

445-
CHECK_CUDA(cudaMemcpy(vertices_device_.get(),
446-
vertices_.data(),
447-
vertices_size,
468+
CHECK_CUDA(cudaMemcpy(vertices_device_.get(),
469+
vertices_.data(),
470+
vertices_size,
471+
cudaMemcpyHostToDevice),
472+
"[FoundationposeRender] cudaMemcpy mesh_faces_host -> mesh_faces_device FAILED!!!");
473+
CHECK_CUDA(cudaMemcpy(vertex_normals_device_.get(),
474+
vertex_normals.data(),
475+
vertices_size,
448476
cudaMemcpyHostToDevice),
449477
"[FoundationposeRender] cudaMemcpy mesh_faces_host -> mesh_faces_device FAILED!!!");
450478
CHECK_CUDA(cudaMemcpy(mesh_faces_device_.get(),
@@ -514,6 +542,36 @@ bool FoundationPoseRenderer::TransformVerticesOnCUDA(cudaStream_t stream,
514542
return true;
515543
}
516544

545+
/**
 * @brief Enqueue the per-vertex diffuse-intensity computation for every pose in `tfs`.
 *
 * The pose matrices are uploaded into a temporary stream-ordered device buffer and
 * `foundationpose_render::transform_normals` is launched over `vertex_normals_device_`.
 * The allocation, the copies, the kernel launch and the free are all enqueued on `stream`;
 * this function does not synchronize the stream.
 *
 * @param stream         CUDA stream used for every operation issued here.
 * @param tfs            Pose matrices. Each one must be 4x4: the kernel steps 16 floats per
 *                       matrix. The raw Eigen storage is copied as-is and the kernel indexes
 *                       it as column-major.
 * @param output_buffer  Device buffer that receives `tfs.size() * num_vertices_` floats
 *                       (one value per pose per vertex).
 * @return true once all work has been enqueued without a reported error.
 *         NOTE(review): the failure path is whatever CHECK_STATE / CHECK_CUDA do (defined
 *         elsewhere; presumably log and return false) - confirm.
 */
bool FoundationPoseRenderer::TransformVertexNormalsOnCUDA(cudaStream_t stream,
                                                           const std::vector<Eigen::MatrixXf>& tfs,
                                                           float* output_buffer)
{
  // The kernel hard-codes a 4x4 layout (16 floats per pose).
  constexpr int kPoseMatrixDim = 4;
  constexpr int kPoseMatrixElements = kPoseMatrixDim * kPoseMatrixDim;

  // Get the dimensions of the inputs
  const int tfs_size = static_cast<int>(tfs.size());
  CHECK_STATE(tfs_size != 0,
              "[FoundationposeRender] The transfomation matrix is empty! ");

  CHECK_STATE(tfs[0].cols() == tfs[0].rows(),
              "[FoundationposeRender] The transfomation matrix has different rows and cols! ");

  // Validate every matrix, not only the first one: a differently sized entry would make the
  // fixed-size copy below read past its storage and the kernel read a misaligned matrix.
  for (const auto& tf : tfs) {
    CHECK_STATE(tf.rows() == kPoseMatrixDim && tf.cols() == kPoseMatrixDim,
                "[FoundationposeRender] Every transformation matrix must be 4x4! ");
  }

  float* transform_device_buffer = nullptr;
  CHECK_CUDA(cudaMallocAsync(&transform_device_buffer,
                             tfs.size() * kPoseMatrixElements * sizeof(float),
                             stream),
             "[FoundationposeRender] cudaMallocAsync `transform_device_buffer` FAILED!!!");

  // Own the temporary buffer so that any early exit below still returns it to the
  // stream-ordered allocator instead of leaking it.
  DeviceBufferUniquePtrType<float> transform_guard(
      transform_device_buffer,
      [stream](float* ptr) { cudaFreeAsync(ptr, stream); });

  for (int i = 0; i < tfs_size; ++i) {
    CHECK_CUDA(cudaMemcpyAsync(transform_device_buffer + i * kPoseMatrixElements,
                               tfs[i].data(),
                               kPoseMatrixElements * sizeof(float),
                               cudaMemcpyHostToDevice,
                               stream),
               "[FoundationposeRender] cudaMemcpyAsync transform host -> device FAILED!!!");
  }

  foundationpose_render::transform_normals(stream,
                                           transform_device_buffer,
                                           tfs_size,
                                           vertex_normals_device_.get(),
                                           num_vertices_,
                                           output_buffer);
  // Kernel launches do not return a status; launch errors only surface here.
  CHECK_CUDA(cudaGetLastError(),
             "[FoundationposeRender] transform_normals launch FAILED!!!");

  // Success path: free explicitly so the result of cudaFreeAsync is checked as well.
  CHECK_CUDA(cudaFreeAsync(transform_guard.release(), stream),
             "[FoundationposeRender] cudaFreeAsync `transform_device_buffer` FAILED!!!");
  return true;
}
517575

518576
bool FoundationPoseRenderer::GeneratePoseClipOnCUDA(cudaStream_t stream,
519577
float* output_buffer,
@@ -595,15 +653,15 @@ FoundationPoseRenderer::NvdiffrastRender(cudaStream_t cuda_stream,
595653
foundationpose_render::interpolate(
596654
cuda_stream,
597655
pts_cam_device_.get(), rast_out_device_.get(), mesh_faces_device_.get(), xyz_map_device_.get(),
598-
num_vertices_, num_faces_, kVertexPoints,
656+
num_vertices_, num_faces_, 3, kVertexPoints,
599657
H, W, N);
600658
CHECK_CUDA(cudaGetLastError(),
601659
"[FoundationPoseRenderer] interpolate failed!!!");
602660

603661
foundationpose_render::interpolate(
604662
cuda_stream,
605663
texcoords_device_.get(), rast_out_device_.get(), mesh_faces_device_.get(), texcoords_out_device_.get(),
606-
num_vertices_, num_faces_, kTexcoordsDim,
664+
num_vertices_, num_faces_, 2, kTexcoordsDim,
607665
H, W, N);
608666
CHECK_CUDA(cudaGetLastError(),
609667
"[FoundationPoseRenderer] interpolate failed!!!");
@@ -619,6 +677,26 @@ FoundationPoseRenderer::NvdiffrastRender(cudaStream_t cuda_stream,
619677
CHECK_CUDA(cudaGetLastError(),
620678
"[FoundationPoseRenderer] texture failed!!!");
621679

680+
CHECK_STATE(TransformVertexNormalsOnCUDA(cuda_stream, poses, diffuse_intensity_device_.get()),
681+
"[FoundationPoseRenderer] Transform vertex normals failed!!!");
682+
683+
foundationpose_render::interpolate(cuda_stream,
684+
diffuse_intensity_device_.get(),
685+
rast_out_device_.get(),
686+
mesh_faces_device_.get(),
687+
diffuse_intensity_map_device_.get(),
688+
num_vertices_, num_faces_, 3, 1, H, W, N);
689+
CHECK_CUDA(cudaGetLastError(),
690+
"[FoundationPoseRenderer] interpolate failed!!!");
691+
692+
foundationpose_render::refine_color(cuda_stream, color_device_.get(),
693+
diffuse_intensity_map_device_.get(),
694+
rast_out_device_.get(),
695+
color_device_.get(),
696+
poses.size(), 0.8, 0.5, H, W);
697+
CHECK_CUDA(cudaGetLastError(),
698+
"[FoundationPoseRenderer] refine_color failed!!!");
699+
622700
float min_value = 0.0;
623701
float max_value = 1.0;
624702
foundationpose_render::clamp(cuda_stream, color_device_.get(), min_value, max_value, N * H * W * kNumChannels);

detection_6d_foundationpose/src/foundationpose_render.cu

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ void rasterize(
196196

197197
void interpolate(
198198
cudaStream_t stream, float* attr_ptr, float* rast_ptr, int32_t* tri_ptr, float* out, int num_vertices,
199-
int num_triangles, int attr_dim, int H, int W, int C) {
200-
int instance_mode = attr_dim > 2 ? 1 : 0;
199+
int num_triangles, int attr_shape_dim, int attr_dim, int H, int W, int C) {
200+
int instance_mode = attr_shape_dim > 2 ? 1 : 0;
201201

202202
InterpolateKernelParams p = {}; // Initialize all fields to zero.
203203
p.instance_mode = instance_mode;
@@ -339,4 +339,88 @@ void generate_pose_clip(cudaStream_t stream, const float* transform_matrixs, con
339339
transform_matrixs, bbox2d_matrix, M, points_vectors, N, transformed_points_vectors, rgb_H, rgb_W);
340340
}
341341

342+
343+
/**
 * @brief For every (pose, normal) pair, rotate the normal by the pose and store a clamped
 *        "facing" term derived from the z component of the normalized result.
 *
 * Launch layout: 2D; x indexes normals (0..N-1), y indexes poses (0..M-1). Threads outside
 * that range exit immediately.
 *
 * @param transform_matrixs           M matrices of 16 floats each, indexed as column-major 4x4.
 * @param M                           number of matrices.
 * @param normals_vectors             N normals, 3 floats each.
 * @param N                           number of normals.
 * @param transformed_normal_vectors  output, M * N floats; element [pose * N + normal].
 */
__global__ void transform_normals_kernel(
    const float* transform_matrixs, int M, const float* normals_vectors,
    int N, float* transformed_normal_vectors)
{
  const int pose_idx   = threadIdx.y + blockIdx.y * blockDim.y;
  const int normal_idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (pose_idx >= M || normal_idx >= N) {
    return;
  }

  // Current 4x4 transform and current normal vector.
  const float* pose = transform_matrixs + pose_idx * 16;
  const float* src  = normals_vectors + normal_idx * 3;

  const float x = src[0];
  const float y = src[1];
  const float z = src[2];

  // Column-major access: only the upper-left 3x3 part of the matrix is applied.
  const float tx = pose[0] * x + pose[4] * y + pose[8] * z;
  const float ty = pose[1] * x + pose[5] * y + pose[9] * z;
  const float tz = pose[2] * x + pose[6] * y + pose[10] * z;

  // Keep only the z component of the normalized vector, negated.
  // A zero-length result yields 0.
  const float length = sqrt(tx*tx + ty*ty + tz*tz);
  float intensity;
  if (length == 0) {
    intensity = 0;
  } else {
    intensity = - tz / length;
  }
  intensity = clamp_func(intensity, 0, 1);

  transformed_normal_vectors[pose_idx * N + normal_idx] = intensity;
}
366+
367+
/**
 * @brief Host-side launcher for `transform_normals_kernel`.
 *
 * Uses 32x32 thread blocks; the grid covers N normals along x and M matrices along y.
 * The launch is enqueued on `stream` and is not synchronized or error-checked here.
 *
 * @param stream                      CUDA stream for the launch.
 * @param transform_matrixs           device pointer, M matrices of 16 floats.
 * @param M                           number of matrices.
 * @param normals_vectors             device pointer, N normals of 3 floats.
 * @param N                           number of normals.
 * @param transformed_normal_vectors  device pointer receiving M * N floats.
 */
void transform_normals(cudaStream_t stream, const float* transform_matrixs, int M, const float* normals_vectors,
                       int N, float* transformed_normal_vectors)
{
  dim3 threads_per_block = {32, 32};
  dim3 blocks_per_grid   = {ceil_div(N, 32), ceil_div(M, 32)};

  transform_normals_kernel<<<blocks_per_grid, threads_per_block, 0, stream>>>(
      transform_matrixs,
      M,
      normals_vectors,
      N,
      transformed_normal_vectors);
}
376+
377+
378+
/**
 * @brief Shade every pixel of every rendered pose:
 *        out = clamp(rgb * (w_ambient + diffuse * w_diffuse) * foreground, 0, 1).
 *
 * Launch layout: 2D; y indexes image rows (0..rgb_H-1), x indexes the columns of all poses
 * laid side by side (0..rgb_W * poses_num - 1). Threads outside that range exit immediately.
 *
 * `color` and `output` may be the same buffer (the caller in NvdiffrastRender passes
 * `color_device_` for both): each thread reads its own pixel before writing it and touches
 * no other pixel. For that reason the pointers must not be marked `__restrict__`.
 *
 * @param color                  input colors, 3 floats per pixel.
 * @param diffuse_intensity_map  1 float per pixel.
 * @param rast_out               4 floats per pixel; channel 3, clamped to [0, 1], is used as
 *                               the foreground mask. Presumably the rasterizer's triangle id
 *                               with 0 meaning "no triangle" - TODO confirm.
 * @param output                 output colors, 3 floats per pixel.
 * @param poses_num              number of rendered poses.
 * @param w_ambient              ambient weight.
 * @param w_diffuse              diffuse weight.
 * @param rgb_H                  image height.
 * @param rgb_W                  image width.
 */
__global__ void renfine_color_kernel(
    const float* color, const float* diffuse_intensity_map, const float* rast_out, float* output, int poses_num, float w_ambient,
    float w_diffuse, int rgb_H, int rgb_W)
{
  const int row_idx = threadIdx.y + blockIdx.y * blockDim.y;
  const int col_idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (row_idx >= rgb_H || col_idx >= rgb_W * poses_num) return;

  const int color_idx     = col_idx / rgb_W;                // which pose image
  const int color_row_idx = row_idx;
  const int color_col_idx = col_idx - color_idx * rgb_W;    // column inside that image

  // Widen to size_t BEFORE multiplying: the original computed these products in 32-bit int
  // and only converted the (possibly already overflowed) result.
  const size_t pixel_idx    = static_cast<size_t>(color_row_idx) * rgb_W + color_col_idx;
  const size_t pixel_offset = static_cast<size_t>(color_idx) * rgb_H * rgb_W + pixel_idx;

  const float* rgb     = color + pixel_offset * 3;
  const float* diffuse = diffuse_intensity_map + pixel_offset;
  const float* rast    = rast_out + pixel_offset * 4;
  float* out           = output + pixel_offset * 3;

  const float diff = diffuse[0];

  const float is_foreground = clamp_func(rast[3], 0, 1);

  // Same shading factor for all three channels; hoisted, evaluation order unchanged.
  const float shade = w_ambient + diff*w_diffuse;

  float r = rgb[0] * shade * is_foreground;
  float g = rgb[1] * shade * is_foreground;
  float b = rgb[2] * shade * is_foreground;

  r = clamp_func(r, 0, 1);
  g = clamp_func(g, 0, 1);
  b = clamp_func(b, 0, 1);

  out[0] = r;
  out[1] = g;
  out[2] = b;
}
414+
415+
/**
 * @brief Host-side launcher for `renfine_color_kernel`.
 *
 * Uses 32x32 thread blocks; the grid covers `rgb_W * poses_num` columns along x and `rgb_H`
 * rows along y. The launch is enqueued on `stream` and is not synchronized or error-checked
 * here.
 *
 * @param stream                 CUDA stream for the launch.
 * @param color                  device pointer to the input colors.
 * @param diffuse_intensity_map  device pointer to the per-pixel diffuse intensities.
 * @param rast_out               device pointer to the rasterizer output.
 * @param output                 device pointer receiving the shaded colors.
 * @param poses_num              number of rendered poses.
 * @param w_ambient              ambient weight.
 * @param w_diffuse              diffuse weight.
 * @param rgb_H                  image height.
 * @param rgb_W                  image width.
 */
void refine_color(cudaStream_t stream, const float* color, const float* diffuse_intensity_map, const float* rast_out, float* output,
                  int poses_num, float w_ambient, float w_diffuse, int rgb_H, int rgb_W)
{
  dim3 threads_per_block = {32, 32};
  dim3 blocks_per_grid   = {ceil_div(rgb_W * poses_num, 32), ceil_div(rgb_H, 32)};

  renfine_color_kernel<<<blocks_per_grid, threads_per_block, 0, stream>>>(
      color,
      diffuse_intensity_map,
      rast_out,
      output,
      poses_num,
      w_ambient,
      w_diffuse,
      rgb_H,
      rgb_W);
}
425+
342426
} // namespace foundationpose_render

detection_6d_foundationpose/src/foundationpose_render.cu.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ void rasterize(
4848

4949
void interpolate(
5050
cudaStream_t stream, float* attr_ptr, float* rast_ptr, int32_t* tri_ptr, float* out, int num_vertices,
51-
int num_triangles, int attr_dim, int H, int W, int C);
51+
int num_triangles, int attr_shape_dim, int attr_dim, int H, int W, int C);
5252

5353
void texture(
5454
cudaStream_t stream, float* tex_ptr, float* uv_ptr, float* out, int tex_height, int tex_width, int tex_channel,
@@ -71,6 +71,17 @@ void transform_points(cudaStream_t stream, const float* transform_matrixs, int t
7171
void generate_pose_clip(cudaStream_t stream, const float* transform_matrixs, const float* bbox2d_matrix, int transform_num, const float* points_vectors,
7272
int points_num, float* transformed_points_vectors, int rgb_H, int rgb_W);
7373

74+
/**
75+
* @param transform_matrixs 应当是`Col-Major`的transform_num个4x4矩阵
76+
* @param normals_vectors 应当是`normals_num`个3x1向量
77+
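* @param transformed_normal_vectors Output: for each (transform, normal) pair, the negated z component of the normalized transformed normal, clamped to [0, 1]; `transform_num x normals_num` floats in total, stored as [transform][normal]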
78+
*/
79+
void transform_normals(cudaStream_t stream, const float* transform_matrixs, int transform_num, const float* normals_vectors,
80+
int normals_num, float* transformed_normal_vectors);
81+
82+
void refine_color(cudaStream_t stream, const float* color, const float* diffuse_intensity_map, const float* rast, float* output,
83+
int poses_num, float w_ambient, float w_diffuse, int rgb_H, int rgb_W);
84+
7485
} // namespace foundationpose_render
7586

7687
#endif // NVIDIA_ISAAC_ROS_EXTENSIONS_FOUNDATIONPOSE_RENDER_CUDA_HPP_

detection_6d_foundationpose/src/foundationpose_render.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ class FoundationPoseRenderer {
6464
const std::vector<Eigen::MatrixXf>& tfs,
6565
float* output_buffer) ;
6666

67+
bool TransformVertexNormalsOnCUDA(cudaStream_t stream,
68+
const std::vector<Eigen::MatrixXf>& tfs,
69+
float* output_buffer);
70+
6771
bool GeneratePoseClipOnCUDA(cudaStream_t stream,
6872
float* output_buffer,
6973
const std::vector<Eigen::MatrixXf>& poses,
@@ -121,13 +125,16 @@ class FoundationPoseRenderer {
121125
using DeviceBufferUniquePtrType = std::unique_ptr<T, std::function<void(T*)>>;
122126

123127
DeviceBufferUniquePtrType<float> vertices_device_ {nullptr};
128+
DeviceBufferUniquePtrType<float> vertex_normals_device_ {nullptr};
124129
DeviceBufferUniquePtrType<float> texcoords_device_ {nullptr};
125130
DeviceBufferUniquePtrType<int32_t> mesh_faces_device_ {nullptr};
126131
DeviceBufferUniquePtrType<uint8_t> texture_map_device_ {nullptr};
127132
// nvdiffrast render时相关缓存
128133
DeviceBufferUniquePtrType<float> pose_clip_device_ {nullptr};
129134
DeviceBufferUniquePtrType<float> rast_out_device_ {nullptr};
130135
DeviceBufferUniquePtrType<float> pts_cam_device_ {nullptr};
136+
DeviceBufferUniquePtrType<float> diffuse_intensity_device_ {nullptr};
137+
DeviceBufferUniquePtrType<float> diffuse_intensity_map_device_ {nullptr};
131138
DeviceBufferUniquePtrType<float> texcoords_out_device_ {nullptr};
132139
DeviceBufferUniquePtrType<float> color_device_ {nullptr};
133140
DeviceBufferUniquePtrType<float> xyz_map_device_ {nullptr};

detection_6d_foundationpose/src/foundationpose_sampling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ std::vector<Eigen::Matrix4f> MakeRotationGrid(unsigned int n_views = 40, int inp
200200
auto R_inplane = Eigen::Affine3f::Identity();
201201
R_inplane.rotate(Eigen::AngleAxisf(0, Eigen::Vector3f::UnitX()))
202202
.rotate(Eigen::AngleAxisf(0, Eigen::Vector3f::UnitY()))
203-
.rotate(Eigen::AngleAxisf(inplane_rot, Eigen::Vector3f::UnitZ()));
203+
.rotate(Eigen::AngleAxisf(inplane_rot * M_PI / 180.0f, Eigen::Vector3f::UnitZ()));
204204

205205
cam_in_ob = cam_in_ob * R_inplane.matrix();
206206
Eigen::Matrix4f ob_in_cam = cam_in_ob.inverse();

detection_6d_foundationpose/src/foundationpose_utils.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ TexturedMeshLoader::TexturedMeshLoader(const std::string& mesh_file_path,
148148
// Walk through each of the mesh's vertices
149149
for (unsigned int v = 0; v < mesh->mNumVertices; v++) {
150150
vertices_.push_back(mesh->mVertices[v]);
151+
vertex_normals_.push_back(mesh->mNormals[v]);
151152
}
152153
for (unsigned int i = 0 ; i < AI_MAX_NUMBER_OF_TEXTURECOORDS ; ++ i) {
153154
if (mesh->mTextureCoords[i] != nullptr) {
@@ -167,12 +168,12 @@ TexturedMeshLoader::TexturedMeshLoader(const std::string& mesh_file_path,
167168
LOG(INFO) << "Loading textured map file: " << textured_file_path;
168169
texture_map_ = cv::imread(textured_file_path);
169170
if (texture_map_.empty()) {
170-
throw std::runtime_error("[TexturedMeshLoader] Failed to read textured image: "
171-
+ textured_file_path);
171+
// throw std::runtime_error("[TexturedMeshLoader] Failed to read textured image: "
172+
// + textured_file_path);
173+
texture_map_ = cv::Mat(2, 2, CV_8UC3, {100, 100, 100});
172174
}
173175
cv::cvtColor(texture_map_, texture_map_, cv::COLOR_BGR2RGB);
174176

175-
176177
LOG(INFO) << "Successfully Loaded textured mesh file!!!";
177178
LOG(INFO) << "Mesh has vertices_num: " << vertices_.size()
178179
<< ", diameter: " << mesh_diamter_
@@ -217,6 +218,17 @@ TexturedMeshLoader::GetMeshVertices() const noexcept
217218
return vertices_;
218219
}
219220

221+
/**
222+
* @brief 获取mesh模型顶点的法向量
223+
*
224+
* @return const std::vector<aiVector3D> &
225+
*/
226+
const std::vector<aiVector3D> &
227+
TexturedMeshLoader::GetMeshVertexNormals() const noexcept
228+
{
229+
return vertex_normals_;
230+
}
231+
220232
/**
221233
* @brief 获取mesh模型的外观坐标系
222234
*

0 commit comments

Comments
 (0)