DeepSpeed MoE #1310
Merged: 22 commits, Aug 17, 2021
207 changes: 135 additions & 72 deletions csrc/adam/cpu_adam.cpp

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions csrc/adam/custom_cuda_kernel.cu
100755 → 100644
@@ -18,3 +18,24 @@ void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)

param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}

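// Note: `input` is assumed to already hold fp16 values packed two-per-float
// (written by the cvtps_ph stores on the CPU side), so each thread below copies
// one packed pair directly into a __half2 slot of `output`.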
__global__ void param_update_kernel_half(const float* input, __half* output, int size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
__half2* output_cast = reinterpret_cast<__half2*>(output);
if (id < size) {
float input_f = input[id];
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
output_cast[id] = *input_h;
}
}

void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 1024;
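// each thread writes one __half2 (two fp16 values), so the element count is halved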
size /= 2;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);

param_update_kernel_half<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
29 changes: 26 additions & 3 deletions csrc/includes/cpu_adam.h
100755 → 100644
@@ -34,6 +34,14 @@
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16

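// fp16-aware load/store: when `h` is true, `x` is treated as packed fp16 and
// converted to/from fp32 in-register (cvtph_ps / cvtps_ph); otherwise a plain
// fp32 load/store.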
#define SIMD_LOAD2(x, h) \
((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm512_storeu_ps(x, d))

#define INTV __m256i
#else
#if defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
@@ -44,6 +52,15 @@
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8
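// AVX256 equivalents of the fp16-aware load/store above (8 fp32 / 8 fp16 values per op).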
#define SIMD_LOAD2(x, h) \
((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))

#define SIMD_STORE2(x, d, h) \
((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm256_storeu_ps(x, d))

#define INTV __m128i

#endif
#endif

@@ -82,19 +99,25 @@ class Adam_Optimizer {
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr);
__half* dev_param = nullptr,
bool half_precision = false);

void Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr);
__half* dev_param = nullptr,
bool half_precision = false);

void Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params = nullptr);
__half* dev_params = nullptr,
bool half_precision = false);
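// When half_precision is true, _params and grads are assumed to point at fp16
// buffers viewed through float*, converted in-register via SIMD_LOAD2/SIMD_STORE2.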

inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
1 change: 1 addition & 0 deletions csrc/includes/custom_cuda_layers.h
@@ -281,3 +281,4 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
cudaStream_t stream);

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);
Empty file added deepspeed/moe/__init__.py
33 changes: 33 additions & 0 deletions deepspeed/moe/experts.py
@@ -0,0 +1,33 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import torch
import copy


class Experts(torch.nn.Module):
def __init__(self, expert, num_local_experts=1):
super(Experts, self).__init__()

self.deepspeed_experts = torch.nn.ModuleList(
[copy.deepcopy(expert) for i in range(num_local_experts)])
self.num_local_experts = num_local_experts

# TODO: revisit allreduce for moe.gate...
for expert in self.deepspeed_experts:
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
for name, param in expert.named_parameters():
param.allreduce = False

def forward(self, inputs):
chunks = inputs.chunk(self.num_local_experts, dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.deepspeed_experts):
out = expert(chunk)
if type(out) is tuple:
out = out[0] # Ignore the bias term for now
expert_outputs += [out]

expert_output = torch.cat(expert_outputs, dim=1)
return expert_output
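
For orientation, here is a minimal, hypothetical sketch (not part of the PR) of how Experts partitions its input: the tensor is chunked along dim=1, one chunk per local expert, and the per-expert outputs are concatenated back along the same dimension. The expert definition and shapes below are assumptions chosen only to make the example run.

import torch
from deepspeed.moe.experts import Experts

# A toy expert: an MLP over a model dimension of 8 (shapes are illustrative).
expert = torch.nn.Sequential(torch.nn.Linear(8, 32),
                             torch.nn.ReLU(),
                             torch.nn.Linear(32, 8))
experts = Experts(expert, num_local_experts=2)

# dim=1 must be divisible by num_local_experts, since that is the chunked dimension.
x = torch.randn(4, 2, 16, 8)
y = experts(x)
assert y.shape == x.shape  # each expert preserves its chunk's shape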
96 changes: 96 additions & 0 deletions deepspeed/moe/layer.py
@@ -0,0 +1,96 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import torch.nn.init as init
import torch
import torch.distributed as dist

from deepspeed.utils import logger, log_dist

import deepspeed.utils.groups as groups
from .sharded_moe import MOELayer, TopKGate
from .experts import Experts
import copy
import typing


class MoE(torch.nn.Module):
'''
DeepSpeed MOE API: This defines a simple API that can be used from client-side code.
'''
def __init__(self,
hidden_size,
expert,
num_experts=1,
k=1,
output_dropout_prob=0.0,
capacity_factor=1.,
eval_capacity_factor=1.,
min_capacity=4,
noisy_gate_policy: typing.Optional[str] = None):
"""Initialize an MoE layer.
TODO: add details about input/output dimension assumptions

Arguments:
hidden_size (int): the hidden dimension of the model.

expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.nn.Linear).

num_experts (int, optional): default=1, the total number of experts per layer.

k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.

output_dropout_prob (float, optional): default=0.0, output dropout probability.

capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.

eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.

min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.

noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
"""

super(MoE, self).__init__()

assert groups.is_initialized(), \
'Please call deepspeed.utils.groups.initialize() before using MoE layers'
assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
'Unsupported noisy_gate_policy: ' + noisy_gate_policy

num_local_experts = num_experts // groups.get_expert_parallel_world_size()

log_dist(
f'num_experts: {num_experts} | num_local_experts: {num_local_experts} | expert_parallel_size: {groups.get_expert_parallel_world_size()}',
[0])

self.num_experts = num_experts
experts = Experts(expert, num_local_experts)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size,
num_experts,
k,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy),
experts,
num_local_experts,
group=groups.get_expert_parallel_group())

self.dropout = torch.nn.Dropout(output_dropout_prob)

def forward(self, hidden_states, used_token=None):
"""
Arguments:
hidden_states (Tensor): input to the layer
used_token (Tensor, optional): default: None, mask only used tokens

Returns:
output (Tensor): output of the model
l_aux (Tensor): gate loss value
exp_counts (int): expert count
"""
output = self.deepspeed_moe(hidden_states, used_token)
output = self.dropout(output)
return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
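
For reference, a minimal usage sketch (not part of the PR, with the distributed setup elided): groups.initialize() is expected to run once after torch.distributed is initialized, after which an MoE layer wraps a user-defined expert module. The hidden size, expert definition, and gate settings below are assumptions for illustration.

import torch
import deepspeed.utils.groups as groups
from deepspeed.moe.layer import MoE

# Assumes torch.distributed is already initialized; the exact arguments accepted
# by groups.initialize() should be checked against deepspeed/utils/groups.py.
groups.initialize()

hidden = 512
expert = torch.nn.Sequential(torch.nn.Linear(hidden, 4 * hidden),
                             torch.nn.ReLU(),
                             torch.nn.Linear(4 * hidden, hidden))

moe = MoE(hidden_size=hidden,
          expert=expert,
          num_experts=8,
          k=1,
          noisy_gate_policy='RSample')

x = torch.randn(2, 16, hidden)       # (batch, seq_len, hidden)
out, l_aux, exp_counts = moe(x)      # l_aux is the gate loss, added to the training loss by the caller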