Skip to content

Commit

Permalink
Fix some minor issues related to NHWC on CUDA with more than 4 dims.
Browse files: browse the repository at this point in the history.
  • Loading branch information
liuliu committed Oct 25, 2024
1 parent c6ce5b1 commit 34d496a
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions lib/nnc/gpu/ccv_nnc_compat.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -1129,7 +1129,15 @@ ccv_nnc_cudnn_tensor_view_descriptor_t ccv_nnc_cudnn_get_tensor_view_descriptor(
stride[0] = stride[2] * inc[1];
break;
default:
assert(0);
dim[0] = tensor->info.dim[0];
dim[1] = tensor->info.dim[axis_count - 1];
stride[1] = 1;
for (i = axis_count - 3; i >= 0; i--)
{
dim[i + 2] = tensor->info.dim[i + 1];
stride[i + 2] = (i == axis_count - 3) ? inc[i + 2] : stride[i + 3] * inc[i + 2];
}
stride[0] = stride[2] * inc[1];
}
} else if (tensor->info.format == CCV_TENSOR_FORMAT_CHWN) {
switch (axis_count)
Expand Down Expand Up @@ -1263,7 +1271,15 @@ ccv_nnc_cudnn_tensor_view_descriptor_t ccv_nnc_cudnn_get_tensor_view_descriptor(
stride[0] = tensor_stride[0];
break;
default:
assert(0);
dim[0] = tensor->info.dim[0];
dim[1] = tensor->info.dim[axis_count - 1];
stride[1] = tensor_stride[axis_count - 1];
for (i = axis_count - 3; i >= 0; i--)
{
dim[i + 2] = tensor->info.dim[i + 1];
stride[i + 2] = tensor_stride[i + 1];
}
stride[0] = tensor_stride[0];
}
} else if (tensor->info.format == CCV_TENSOR_FORMAT_CHWN) {
switch (axis_count)
Expand Down

0 comments on commit 34d496a

Please sign in to comment.