Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,7 @@ using namespace fbgemm_gpu;


{#-/* Set the weights row accessor */#}
{%- if is_rocm %}
const auto weights_row = rocm::WeightRowAccessorVec2
{%- else %}
const auto weights_row = WeightRowAccessor
{%- endif %}
<
{{ 'cache_t' if from_cache else 'emb_t' }},
cache_t
Expand Down Expand Up @@ -182,11 +178,7 @@ using namespace fbgemm_gpu;
{%- endif %}

{#-/* Set the weights row accessor */#}
{%- if is_rocm %}
const auto weights_row = rocm::WeightRowAccessorVec2
{%- else %}
const auto weights_row = WeightRowAccessor
{%- endif %}
<
{{ 'cache_t' if from_cache else 'emb_t' }},
cache_t
Expand Down Expand Up @@ -319,7 +311,7 @@ using namespace fbgemm_gpu;

{%- if is_rocm %}
{%- if not nobag %}
rocm::Vec2T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
Vec4T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
{%- endif %}
// Iterate over kThreadGroupSize indices
for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength)
Expand Down Expand Up @@ -633,12 +625,7 @@ batch_index_select_dim0_codegen_forward_kernel(
#endif

// Elements are processed 4 at a time through fbgemm_gpu::Vec4 (CUDA float4, 16 bytes)
// for CUDA devices and 2 at a time for ROCm
{%- if is_rocm %}
constexpr int VEC_WIDTH = 2;
{%- else %}
constexpr int VEC_WIDTH = 4;
{%- endif %}
{%- if is_rocm %}
// Unroll factor for ROCm devices
constexpr int kManualUnrollLength = 4;
Expand Down Expand Up @@ -743,12 +730,8 @@ batch_index_select_dim0_codegen_forward_kernel(
const float inv_L = (mean_pooling && L != 0) ? static_cast<float>(1.0) / L: static_cast<float>(1.0);

// Set up the accumulator buffer
{%- if is_rocm %}
rocm::Vec2T<cache_t> accumulators[kMaxVecsPerThread];
{%- else %}
Vec4T<cache_t> accumulators[kMaxVecsPerThread];
{%- endif %}
{%- endif %}

{%- if dense %}
{{ embedding_pool_or_store("NULL") }}
Expand Down Expand Up @@ -930,7 +913,7 @@ batch_index_select_dim0_codegen_forward_kernel
{%- endmacro %}

{%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %}
{%- set max_vecs_per_thread = 2 * kMaxVecsPerThread if is_rocm else kMaxVecsPerThread %}
{%- set max_vecs_per_thread = kMaxVecsPerThread %}
{%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,12 +716,7 @@ batch_index_select_dim0_codegen_forward_cuda(
// kFixedMaxVecsPerThread instead of kMaxVecsPerThread. But
// kMaxVecsPerThread and kFixedMaxVecsPerThread are the same
// forward
{%- if is_rocm %}
// Account for Vec2 load for ROCm
constexpr auto kMaxVecsPerThread = 2 * kFixedMaxVecsPerThread;
{%- else %}
constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread;
{%- endif %}

const auto grid = min(
div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
Expand Down