
Commit

Fixing the transformer APIs to return tuple as the output (if needed)
RezaYazdaniAminabadi authored Oct 29, 2021
1 parent a4fff53 commit ee6a92c
Showing 6 changed files with 53 additions and 65 deletions.
13 changes: 6 additions & 7 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -11,7 +11,7 @@ std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});

template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
T* attn_mask_ptr,
at::Tensor& attn_mask,
bool triangular,
bool recompute,
bool local_attention,
@@ -22,9 +22,8 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
int seq_len = attn_scores_c.size(2);
int soft_len = attn_scores_c.size(3);
int heads = attn_scores_c.size(1);

launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
attn_mask_ptr,
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
triangular,
recompute,
local_attention,
@@ -42,7 +41,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
at::Tensor& query_cont,
T* attn_mask_ptr,
at::Tensor& attn_mask,
at::Tensor& prev_value_cont,
at::Tensor& output,
int& bsz,
@@ -81,8 +80,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
seq_len * soft_len,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
attn_score = ds_softmax<T>(
attn_score, attn_mask_ptr, triangular, recompute, local_attention, window_size);
attn_score =
ds_softmax<T>(attn_score, attn_mask, triangular, recompute, local_attention, window_size);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
@@ -139,7 +138,7 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
attention_unfused<T>(prev_key_cont,
query_cont,
(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont,
output,
bsz,
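
The binding-level change above replaces the raw T* attn_mask_ptr argument with an at::Tensor& and resolves the "no mask" case inside C++: the mask pointer is only dereferenced when the tensor has more than one dimension. On the Python side (see transformer_inference.py below), a disabled mask is passed as torch.empty(1). A minimal sketch of that caller-side convention; the helper name is illustrative and not part of the API:

    import torch

    def prepare_attn_mask(input_mask):
        # Kernels now always receive a tensor argument; a 1-D placeholder
        # signals "no masking", since the C++ binding only reads masks
        # with more than one dimension (attn_mask.sizes().size() > 1).
        return torch.empty(1) if input_mask is None else input_mask
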
43 changes: 21 additions & 22 deletions csrc/transformer/inference/csrc/softmax.cu
@@ -9,7 +9,7 @@
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8

#define minus_infinity (-1 * std::numeric_limits<float>::infinity())
#define minus_infinity -10000.0

void CheckCudaErrorAux(const char* file, unsigned line)
{
@@ -94,10 +94,10 @@ __global__ void attn_softmax_v2(__half* vals,
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && !triangular && recompute) {
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
@@ -114,15 +114,15 @@
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = minus_infinity;
if (mask && !triangular && recompute) {
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
// high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
@@ -155,7 +155,6 @@

max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}

float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
@@ -181,7 +180,6 @@
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;

for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);

@@ -265,7 +263,7 @@ __global__ void attn_softmax_v2(float* vals,
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && !triangular && recompute) {
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
@@ -282,7 +280,7 @@
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask && !triangular && recompute) {
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
@@ -390,19 +388,20 @@ void launch_attn_softmax_v2(T* vals,
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;

if (sequence_length <= 32768)
attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
heads,
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(
vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
(triangular ? (heads * batch_size) : heads),
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
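
The kernel changes above (a) apply the attention mask whenever one is provided and recompute is set, even in the triangular case, (b) use -10000.0 instead of negative infinity as the fill value for masked positions, and (c) scale the heads argument by batch_size when triangular masking is on. A rough eager-mode reference for the masking semantics only; this is an illustrative simplification that ignores the local-attention window, the recompute flag, and the fused data layout:

    import torch

    def reference_masked_softmax(scores, attn_mask=None, triangular=False):
        # scores: [batch, heads, seq_len, soft_len]; attn_mask is additive.
        if triangular:
            causal = torch.ones_like(scores, dtype=torch.bool).tril()
            scores = scores.masked_fill(~causal, -10000.0)
        if attn_mask is not None:
            scores = scores + attn_mask
        return torch.softmax(scores, dim=-1)
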
2 changes: 2 additions & 0 deletions deepspeed/inference/engine.py
@@ -26,6 +26,7 @@ def __init__(self,
checkpoint=None,
dtype=None,
injection_dict=None,
return_tuple=True,
replace_method='auto',
quantization_setting=None):

@@ -141,6 +142,7 @@ def _apply_injection_policy(self, client_module=None, injection_policy=None):
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
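
With the new return_tuple argument on InferenceEngine, kernel-injected layers can keep the HuggingFace-style tuple output. A hedged usage sketch, assuming deepspeed.init_inference forwards the flag through to InferenceEngine and that model and input_ids are defined elsewhere:

    import torch
    import deepspeed

    # model and input_ids are assumed to be defined elsewhere
    engine = deepspeed.init_inference(model,
                                      dtype=torch.half,
                                      replace_method='auto',
                                      return_tuple=True)
    outputs = engine(input_ids)
    hidden_states = outputs[0]  # HuggingFace-style tuple output
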
11 changes: 6 additions & 5 deletions deepspeed/module_inject/replace_module.py
@@ -102,8 +102,8 @@ def replace_transformer_layer(orig_layer_impl,
stochastic_mode=True,
training=True,
quantize=False,
encoder_decoder=False,
quantize_settings=None):
quantize_settings=None,
return_tuple=False):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
@@ -125,7 +125,8 @@ def replace_transformer_layer(orig_layer_impl,
training (bool): specifying whether kernel-injection is done for training/inference (set to false for inference-mode injection)
quantize_settings (tuple): this setting shows how we can quantize a model for running it through the inference kernels.
It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
encoder_decoder (bool): this flag needs to be set for huggingface Bert models.
return_tuple (bool): if set, transformer layer returns a tuple as the output.
Note: this flag needs to be set for huggingface models.
Returns:
Updated nn.module with replaced transformer layers
@@ -181,7 +182,7 @@ def replace_with_policy(child, policy_cls, inference=False, preln=True, layer_id
pre_layer_norm=preln,
mp_size=mp_size,
q_int8=quantize,
encoder_decoder=(True if policy_cls is HFBertLayerPolicy else False),
return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
triangular_masking=(policy_cls is not HFBertLayerPolicy),
local_attention=((config.attention_layers[layer_id] == "local")
if hasattr(config,
@@ -276,7 +277,7 @@ def transpose(data):
seed=seed,
fp16=fp16,
pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),
huggingface=encoder_decoder,
return_tuple=return_tuple,
local_rank=local_rank,
stochastic_mode=stochastic_mode,
normalize_invertible=True,
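
replace_transformer_layer now takes return_tuple in place of the old encoder_decoder flag and still forces tuple output when the policy is HFBertLayerPolicy. An illustrative call for a HuggingFace BERT model; keyword names other than those visible in this diff (notably model=) are assumptions:

    from transformers.models.bert.modeling_bert import BertLayer
    from deepspeed.module_inject import replace_transformer_layer

    # model is an already-loaded HuggingFace BERT model (assumed)
    ds_model = replace_transformer_layer(orig_layer_impl=BertLayer,
                                         model=model,
                                         config=model.config,
                                         fp16=True,
                                         training=False,
                                         return_tuple=True)
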
41 changes: 14 additions & 27 deletions deepspeed/ops/transformer/inference/transformer_inference.py
@@ -57,11 +57,8 @@ class DeepSpeedInferenceConfig(TransformerConfig):
a high accuracy level. On the other hand, for the downstream tasks, such as fine-tuning, we recommend
to turn it off in order to be able to reproduce the same result through the regular kernel execution.
encoder_decoder: DeepSpeed-Inference currently support the encoder-only architecture! We will add
the required features to support both soon!
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
"""
def __init__(self,
hidden_size=-1,
@@ -75,11 +72,11 @@ def __init__(self,
q_int8=False,
pre_layer_norm=True,
stochastic_mode=False,
encoder_decoder=False,
scale_attention=True,
triangular_masking=True,
local_attention=False,
window_size=256):
window_size=256,
return_tuple=True):
super(DeepSpeedInferenceConfig,
self).__init__(
hidden_size,
@@ -93,12 +90,12 @@
self.epsilon = layer_norm_eps
self.mp_size = mp_size
self.q_int8 = q_int8
self.encoder_decoder = encoder_decoder
self.scale_attention = scale_attention
self.specialized_mode = None
self.triangular_masking = triangular_masking
self.local_attention = local_attention
self.window_size = window_size
self.return_tuple = return_tuple

@classmethod
def from_dict(cls, json_object):
@@ -160,10 +157,8 @@ def _transpose_for_context(x):
return x.view(*new_x_layer_shape)

def compute_attention(qkv_out, input_mask):
score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16 or not config.triangular_masking) else \
score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16) else \
inference_cuda_module.softmax_context_fp16
if not config.triangular_masking:
qkv_out = qkv_out.float()

if merge_count > 0 and config.q_int8:
split_dim = (qkv_out.dim() - 1)
@@ -187,7 +182,9 @@ def compute_attention(qkv_out, input_mask):
value_layer) = torch.split(qkv_out,
(qkv_out.shape[-1] // 3),
dim=(qkv_out.dim() - 1))

no_masking = input_mask is None
if no_masking:
input_mask = torch.empty(1)
head_size = (mixed_query.shape[-1] // num_attention_heads_per_partition)

unfused_mode = not config.specialized_mode or \
Expand All @@ -210,17 +207,12 @@ def compute_attention(qkv_out, input_mask):
True) / (norm_factor if config.scale_attention else 1.0)
value_layer1 = _transpose_for_scores(value_layer, False, True)

no_masking = input_mask is None
if no_masking:
input_mask = torch.empty(1)

if layer_past is None:
attn_key_value = score_context_func(
mixed_query,
(key_layer1 if unfused_mode else key_layer),
torch.empty(1),
(input_mask
if config.triangular_masking or no_masking else input_mask.float()),
(input_mask),
(value_layer1 if unfused_mode else value_layer),
torch.empty(1),
num_attention_heads_per_partition,
@@ -235,8 +227,7 @@
mixed_query,
(key_layer1 if unfused_mode else past_key.type_as(key_layer)),
(key_layer1 if unfused_mode else key_layer),
(input_mask
if config.triangular_masking or no_masking else input_mask.float()),
(input_mask),
(value_layer1 if unfused_mode else past_value.type_as(value_layer)),
(value_layer1 if unfused_mode else value_layer),
num_attention_heads_per_partition,
Expand All @@ -246,16 +237,13 @@ def compute_attention(qkv_out, input_mask):
config.local_attention,
config.window_size,
no_masking)
#import pdb;pdb.set_trace()
if unfused_mode:
context_layer, _, _ = attn_key_value
else:
context_layer, key_layer, value_layer = attn_key_value

# Transpose Context
context_layer = _transpose_for_context(context_layer)
if (config.fp16 or config.q_int8) and not config.triangular_masking:
context_layer = context_layer.half()

return context_layer, key_layer, value_layer

@@ -270,6 +258,7 @@ def selfAttention_fp():
else:
qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
inference_cuda_module.qkv_gemm_fp32
print(input.shape)
qkv_out = qkv_func(input,
attn_qkvw,
(attn_qkvb if attn_qkvb is not None else norm_b),
@@ -552,14 +541,13 @@ def __init__(self,
self.config = config
self.config.layer_id = DeepSpeedTransformerInference.layer_id
DeepSpeedTransformerInference.layer_id += 1

self.attention = DeepSpeedSelfAttention(config,
self.attention = DeepSpeedSelfAttention(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
qkv_merging)
self.mlp = DeepSpeedMLP(config,
self.mlp = DeepSpeedMLP(self.config,
mp_group,
quantize_scales,
quantize_groups,
@@ -599,7 +587,6 @@ def forward(self,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False):

get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask

@@ -644,7 +631,7 @@ def forward(self,
if get_present:
output = (output, presents)

if self.config.encoder_decoder:
if self.config.return_tuple:
return (output, )
else:
return output
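
Taken together, the inference layer now honors config.return_tuple (default True) rather than the old encoder_decoder flag, so its forward wraps the output in a tuple the way HuggingFace layers do. A hedged construction sketch; the import path is inferred from this file's location, and hidden_states / attention_mask are placeholder inputs:

    from deepspeed.ops.transformer.inference.transformer_inference import (
        DeepSpeedInferenceConfig, DeepSpeedTransformerInference)

    config = DeepSpeedInferenceConfig(hidden_size=1024,
                                      heads=16,
                                      fp16=True,
                                      return_tuple=True)
    layer = DeepSpeedTransformerInference(config)
    # hidden_states / attention_mask are placeholder inputs
    outputs = layer(hidden_states, input_mask=attention_mask)
    hidden_states = outputs[0]  # forward returns (output,) when return_tuple is set
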
8 changes: 4 additions & 4 deletions deepspeed/ops/transformer/transformer.py
@@ -86,7 +86,7 @@ class DeepSpeedTransformerConfig(TransformerConfig):
a high accuracy level. On the other hand, for the downstream tasks, such as fine-tuning, we recommend
to turn it off in order to be able to reproduce the same result through the regular kernel execution.
huggingface: Enable if using the HuggingFace interface style for sending out the forward results.
return_tuple: Enable if using the return_tuple interface style for sending out the forward results.
training: Enable for training rather than inference.
"""
@@ -109,7 +109,7 @@ def __init__(self,
adjust_init_range=True,
attn_dropout_checkpoint=False,
stochastic_mode=False,
huggingface=False,
return_tuple=False,
training=True):
super(DeepSpeedTransformerConfig,
self).__init__(
@@ -134,7 +134,7 @@
self.is_grad_enabled = True
self.attn_dropout_checkpoint = attn_dropout_checkpoint
self.stochastic_mode = stochastic_mode
self.huggingface = huggingface
self.return_tuple = return_tuple

@classmethod
def from_dict(cls, json_object):
Expand Down Expand Up @@ -316,7 +316,7 @@ def forward(ctx,
if inp_size[1] % 16 != 0:
output = torch.narrow(output, 1, 0, inp_size[1])

if config.huggingface:
if config.return_tuple:
return (output, ) # outputs -> (output) : outputs[0] = output
else:
return output
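
The training-side config mirrors the change: the huggingface flag is renamed to return_tuple, and the layer forward returns (output,) when it is set. A hedged sketch of configuring the training kernel with tuple output; values other than return_tuple are illustrative, and the inputs are placeholders:

    from deepspeed import DeepSpeedTransformerConfig, DeepSpeedTransformerLayer

    config = DeepSpeedTransformerConfig(batch_size=8,
                                        hidden_size=1024,
                                        intermediate_size=4096,
                                        heads=16,
                                        num_hidden_layers=24,
                                        fp16=True,
                                        return_tuple=True)  # formerly huggingface=True
    layer = DeepSpeedTransformerLayer(config)
    out = layer(hidden_states, attention_mask)  # placeholder inputs
    hidden_states = out[0]  # (output,) when return_tuple is set
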
