From b4aa1c729b554b0f7a9714ed7c40136f2424275a Mon Sep 17 00:00:00 2001 From: sdtblck Date: Sat, 29 Jan 2022 16:51:30 +0000 Subject: [PATCH] port fix from nvidia megatron --- .../fused_kernels/scaled_masked_softmax.cpp | 34 +- .../fused_kernels/scaled_masked_softmax.h | 27 +- .../scaled_masked_softmax_cuda.cu | 19 +- .../scaled_upper_triang_masked_softmax.h | 50 ++- .../fused_kernels/tests/test_fused_kernels.py | 296 ++++++++++++++++++ megatron/model/fused_softmax.py | 112 ++++--- megatron/model/gmlp.py | 10 +- megatron/model/transformer.py | 5 +- megatron/model/utils.py | 10 + 9 files changed, 468 insertions(+), 95 deletions(-) create mode 100644 megatron/fused_kernels/tests/test_fused_kernels.py diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp index 93b078afb..1852aee6f 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.cpp +++ b/megatron/fused_kernels/scaled_masked_softmax.cpp @@ -32,13 +32,19 @@ torch::Tensor bwd_cuda( torch::Tensor const& softmax_results, float scale_factor); +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + torch::Tensor fwd( torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), + (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); @@ -46,7 +52,7 @@ torch::Tensor fwd( } torch::Tensor bwd( - torch::Tensor const& output_grads, + torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { @@ -54,24 +60,38 @@ torch::Tensor bwd( AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), + (output_grads.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), + (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + } // end namespace scaled_masked_softmax } // end namespace fused_softmax } // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, "Self Multihead Attention scaled, time masked softmax -- Backward."); -} \ No newline at end of file + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." 
+ ); +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h index 5ad81f8ec..1f98291ca 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.h +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -1,4 +1,3 @@ - /* coding=utf-8 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. * @@ -17,6 +16,7 @@ #pragma once +#include #include #include #include @@ -112,7 +112,7 @@ __global__ void scaled_masked_softmax_warp_forward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) @@ -231,7 +231,7 @@ __global__ void scaled_masked_softmax_warp_backward( constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) // gridDim/blockIdx = (seq_len, attn_heads, batches) @@ -311,9 +311,23 @@ __global__ void scaled_masked_softmax_warp_backward( } } } - } // end of anonymous namespace +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int batch_count = batches * attn_heads * query_seq_len; + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + template void dispatch_scaled_masked_softmax_forward( output_t *dst, @@ -326,7 +340,6 @@ void dispatch_scaled_masked_softmax_forward( int attn_heads, int pad_batches) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); if (key_seq_len == 0) { return; } else { @@ -345,7 +358,6 @@ void dispatch_scaled_masked_softmax_forward( int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR @@ -415,7 +427,6 @@ void dispatch_scaled_masked_softmax_backward( int batches, int attn_heads) { - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); if (key_seq_len == 0) { return; } else { @@ -490,4 +501,4 @@ void dispatch_scaled_masked_softmax_backward( break; } } -} \ No newline at end of file +} diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu index 6ebc95882..902d36dd0 100644 --- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -28,6 +28,11 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + torch::Tensor fwd_cuda( torch::Tensor const& input, torch::Tensor const& mask, @@ -46,9 +51,9 @@ torch::Tensor fwd_cuda( TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = + torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr @@ -74,10 +79,10 @@ torch::Tensor fwd_cuda( } torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, float scale_factor) { - + auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -94,8 +99,8 @@ torch::Tensor bwd_cuda( output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), reinterpret_cast(softmax_results.data_ptr()), scale_factor, query_seq_len, diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h index c79df72f1..bffc29a0f 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -33,7 +33,7 @@ __device__ __inline__ void copy_vector(c10::BFloat16 *dst, con template <> __device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - + template <> 
__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } @@ -112,23 +112,23 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { */ template __global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; + int local_seq = blockIdx.x + 1; int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; // micro_batch_size might not be a multiple of WARP_BATCH. Check how @@ -194,7 +194,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( if (it < warp_iteration_limit) { elements[i][it] = std::exp((elements[i][it] - max_value[i])); sum[i] += elements[i][it]; - } + } } } warp_reduce(sum); @@ -211,7 +211,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( if (element_index < local_seq) { - #pragma unroll + #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < local_seq) { out[element] = elements[i][it + element] / sum[i]; @@ -224,32 +224,32 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); } else { break; - } + } } } } template __global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, + output_t *gradInput, + input_t *grad, const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, + acc_t scale, + int micro_batch_size, + int stride, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - + int local_seq = blockIdx.x + 1; + // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. 
int local_batches = micro_batch_size - first_batch; @@ -296,7 +296,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( } } } - + acc_t sum[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -340,7 +340,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); if (softmax_elements == 0) { return; } else { @@ -360,7 +359,6 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); int blocks_per_seq = attn_batches / batches_per_block; dim3 blocks(seq_len, blocks_per_seq, 1); dim3 threads(warp_size, warps_per_block, 1); @@ -430,7 +428,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); if (softmax_elements == 0) { return; } else { @@ -450,7 +447,6 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); int blocks_per_seq = attn_batches / batches_per_block; dim3 blocks(seq_len, blocks_per_seq, 1); dim3 threads(warp_size, warps_per_block, 1); diff --git a/megatron/fused_kernels/tests/test_fused_kernels.py b/megatron/fused_kernels/tests/test_fused_kernels.py new file mode 100644 index 000000000..b85618da8 --- /dev/null +++ b/megatron/fused_kernels/tests/test_fused_kernels.py @@ -0,0 +1,296 @@ +import math + +import torch +from torch.nn import LayerNorm + +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.gpt2_model import gpt2_attention_mask_func + + +def test_load_fused_kernels(): + try: + import scaled_masked_softmax_cuda + import scaled_upper_triang_masked_softmax_cuda + import torch + + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + + +def test_fused_softmax(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. 
" + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + embedding_output = bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + + # (bsz, 1, 1, seq_len) + mask = bert.get_extended_attention_mask( + attention_mask=tokens["attention_mask"].cuda(), + input_shape=tokens["input_ids"].shape, + device=bert.device, + ) + # (bsz, 1, seq_len, seq_len) + mask = mask.repeat(1, 1, mask.size()[-1], 1) + + attention = bert.encoder.layer[0].attention.self + key_layer = attention.transpose_for_scores(attention.key(embedding_output)) + query_layer = attention.transpose_for_scores(attention.query(embedding_output)) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores /= math.sqrt(key_layer.size()[-1]) + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attention_scores, + (mask != 0), + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attention_scores, + (mask != 0), + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_fused_upper_triangle_mask_softmax(): + gpt = GPT2Model.from_pretrained("gpt2").cuda().half() + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. 
" + "hi hi hi hi hi hi hi" # 24 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + attention_mask = tokens["attention_mask"].cuda() + attention_mask = attention_mask.view(attention_mask.size(0), -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * -10000.0 + attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1) + attn = gpt.h[0] + + hidden_states = gpt.wte(tokens["input_ids"].cuda()) + q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1) + q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim) + k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim) + attn_weights = torch.matmul(q, k.transpose(-1, -2)) + + sq, sk = q.size(-2), k.size(-2) + causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool() + total_mask = ~(causal_mask & (attention_mask == 0)) + """ + tensor([[[[False, True, True, ..., True, True, True], + [False, False, True, ..., True, True, True], + [False, False, False, ..., True, True, True], + ..., + [False, False, False, ..., False, True, True], + [False, False, False, ..., False, False, True], + [False, False, False, ..., False, False, False]]] + """ + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attn_weights, + total_mask, + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attn_weights, + total_mask, + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_layer_norm(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. 
" + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + # [bsz, seq_len, d_model] + embedding_output = ( + bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + .cuda() + .half() + ) + + fused_layernorm_layer = ( + MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + torch_layernorm_layer = ( + LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + fused_output = fused_layernorm_layer(embedding_output) + torch_output = torch_layernorm_layer(embedding_output) + test_result = (fused_output - torch_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_layer_norm" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}" + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_layer_norm" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + + +if __name__ == "__main__": + try: + from transformers import BertTokenizer, GPT2Tokenizer + from transformers.models.bert.modeling_bert import BertModel + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + import transformers + + transformers.logging.set_verbosity( + transformers.logging.FATAL, + ) + + except: + print("\n[Fail] Please install `transformers` package to test fused kernels\n") + exit(-1) + + test_load_fused_kernels() + test_fused_softmax() + test_fused_upper_triangle_mask_softmax() diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index ea121f27d..04175ff27 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -14,7 +14,8 @@ # limitations under the License. import torch - +import torch.nn as nn +import enum class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): """ @@ -75,28 +76,34 @@ def backward(ctx, output_grads): scale_t[0]) return input_grads, None, None +class SoftmaxFusionTypes(enum.Enum): + upper_triang = 1 # causal mask + general = 2 # general mask + none = 3 # no fusion -class FusedScaleMaskSoftmax(torch.nn.Module): +class FusedScaleMaskSoftmax(nn.Module): """ fused operation: scaling + mask + softmax Arguments: input_in_fp16: flag to indicate if input in fp16 data format. - upper_triang_mask: if true, apply upper triangular masking. - (used in gpt family networks) + input_in_bf16: flag to indicate if input in bf16 data format. + fusion_type: type of fusion to perform, should be either upper_triang, general or none. None will perform a regular torch softmax. mask_func: mask function to be applied. softmax_in_fp32: if true, softmax in performed at fp32 precision. scale: scaling factor used in input tensor scaling. 
""" - def __init__(self, input_in_fp16, input_in_bf16, upper_triang_mask_fusion, - general_mask_fusion, mask_func, softmax_in_fp32, scale): - super(FusedScaleMaskSoftmax, self).__init__() + def __init__(self, input_in_fp16, input_in_bf16, fusion_type, mask_func, softmax_in_fp32, scale): + super().__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.upper_triang_mask_fusion = upper_triang_mask_fusion - self.general_mask_fusion = general_mask_fusion + + assert fusion_type in [SoftmaxFusionTypes.upper_triang, SoftmaxFusionTypes.general, SoftmaxFusionTypes.none], f"Invalid fusion type {fusion_type}" + self.upper_triang_mask_fusion = fusion_type == SoftmaxFusionTypes.upper_triang + self.general_mask_fusion = fusion_type == SoftmaxFusionTypes.general + self.fusion = fusion_type != SoftmaxFusionTypes.none self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale @@ -107,38 +114,65 @@ def __init__(self, input_in_fp16, input_in_bf16, upper_triang_mask_fusion, def forward(self, input, mask): # [b, np, sq, sk] assert input.dim() == 4 - data_size = input.size() - query_seq_len = data_size[-2] - key_seq_len = data_size[-1] - attn_batch_size = data_size[0] * data_size[1] - - # constraints on various tensor dimensions to enable warp based - # optimization and upper triangular optimization (for causal mask) - - custom_kernel_constraint = 16 < key_seq_len <= 2048 and query_seq_len % 4 == 0 and attn_batch_size % 4 == 0 - - # invoke custom kernel - if self.input_in_float16 and data_size[-1] <= 2048 and mask is not None and custom_kernel_constraint and (self.upper_triang_mask_fusion or self.general_mask_fusion) and query_seq_len == key_seq_len: - scale = self.scale if self.scale is not None else 1.0 - if self.upper_triang_mask_fusion: - input = input.view(-1, query_seq_len, key_seq_len) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - probs = probs.view(*data_size) - else: - probs = ScaledMaskedSoftmax.apply(input, mask, scale) + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.fusion # user wants to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sq <= 2048 # sq must be 16 ~ 2048 + and sk % 4 == 0 # sk must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.upper_triang_mask_fusion: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + if self.upper_triang_mask_fusion: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) else: - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) 
- probs = torch.nn.Softmax(dim=-1)(mask_output) + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + return probs + + @staticmethod + def get_batch_per_block(b, np, sq, sk): + import scaled_masked_softmax_cuda + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) \ No newline at end of file diff --git a/megatron/model/gmlp.py b/megatron/model/gmlp.py index 8a6dd1f68..c48726459 100644 --- a/megatron/model/gmlp.py +++ b/megatron/model/gmlp.py @@ -5,8 +5,9 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.activations import get_activation from megatron.model.norms import get_norm -from megatron import mpu +from megatron.model.utils import get_fusion_type +from megatron import mpu class TinyAttention(nn.Module): def __init__(self, neox_args, d_attn, d_ff, mask_fn): @@ -16,11 +17,12 @@ def __init__(self, neox_args, d_attn, d_ff, mask_fn): self.proj_ffn = nn.Linear(d_attn, d_ff) self.softmax = FusedScaleMaskSoftmax( input_in_fp16=neox_args.precision == "fp16", - upper_triang_mask_fusion=neox_args.scaled_upper_triang_masked_softmax_fusion, - general_mask_fusion=neox_args.scaled_masked_softmax_fusion, + input_in_bf16=neox_args.precision == "bfloat16", + fusion_type=get_fusion_type(neox_args), mask_func=mask_fn, softmax_in_fp32=neox_args.attention_softmax_in_fp32, - scale=None) + scale=None + ) def forward(self, x, attention_mask): q, k, v = torch.chunk(self.proj_qkv(x), 3, dim=-1) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c52bbe2e6..03a32d994 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -27,7 +27,7 @@ from megatron import mpu from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.activations import get_activation -from megatron.model.utils import exists +from megatron.model.utils import exists, get_fusion_type from megatron.model.positional_embeddings import ( RotaryEmbedding, apply_rotary_pos_emb, @@ -273,8 +273,7 @@ def __init__( self.scale_mask_softmax = FusedScaleMaskSoftmax( input_in_fp16=self.fp16, input_in_bf16=self.bf16, - upper_triang_mask_fusion=neox_args.scaled_upper_triang_masked_softmax_fusion, - general_mask_fusion=neox_args.scaled_masked_softmax_fusion, + fusion_type=get_fusion_type(neox_args), mask_func=self.attention_mask_func, softmax_in_fp32=self.attention_softmax_in_fp32, scale=coeff, diff --git a/megatron/model/utils.py b/megatron/model/utils.py index ccaf808f5..06c25de86 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -20,6 +20,7 @@ import torch from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm +from megatron.model.fused_softmax import SoftmaxFusionTypes from types import GeneratorType def get_params_for_weight_decay_optimization(module, neox_args): @@ -280,3 +281,12 @@ def configure_sparse_attention(neox_args, attention_type, num_attention_heads, m attn_mask_mode="add", mpu=mpu, ) + + +def get_fusion_type(neox_args): + fusion_type = SoftmaxFusionTypes.none + if 
neox_args.scaled_upper_triang_masked_softmax_fusion: + fusion_type = SoftmaxFusionTypes.upper_triang + elif neox_args.scaled_masked_softmax_fusion: + fusion_type = SoftmaxFusionTypes.general + return fusion_type
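For reference, the divisibility constraint that the new `get_batch_per_block` binding and the `is_kernel_available` check in `fused_softmax.py` enforce before taking the fused path can be mirrored in pure Python. This is only a sketch of the arithmetic in `scaled_masked_softmax.h`: the warp size of 32 and the example shapes below are assumptions, and the authoritative value at runtime comes from `scaled_masked_softmax_cuda.get_batch_per_block`.

```python
# Pure-Python mirror of get_batch_per_block() from scaled_masked_softmax.h.
# Sketch only: assumes 32-thread warps (C10_WARP_SIZE on NVIDIA GPUs) and the
# 128-thread blocks hard-coded in the kernel dispatch.

C10_WARP_SIZE = 32  # assumed warp size


def log2_ceil(value: int) -> int:
    """Smallest n such that 2**n >= value, matching the kernel's log2_ceil()."""
    n = 0
    while (1 << n) < value:
        n += 1
    return n


def get_batch_per_block(query_seq_len: int, key_seq_len: int,
                        batches: int, attn_heads: int) -> int:
    next_power_of_two = 1 << log2_ceil(key_seq_len)
    warp_size = min(next_power_of_two, C10_WARP_SIZE)
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    threads_per_block = 128
    warps_per_block = threads_per_block // warp_size
    return warps_per_block * batches_per_warp


# Example shapes (assumed): sq = sk = 1024, b = 4, np = 16 gives
# batch_per_block = 4. Per is_kernel_available(), the general-mask path then
# needs sq % batch_per_block == 0 and the causal path needs
# (b * np) % batch_per_block == 0 before the fused kernel is used.
print(get_batch_per_block(1024, 1024, 4, 16))
```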