Automatic tensor parallelism v2 #2670

Merged: 22 commits, Jan 24, 2023
10 changes: 10 additions & 0 deletions deepspeed/inference/engine.py
@@ -21,6 +21,7 @@
from ..moe.utils import has_moe_layers
from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
from ..module_inject.policy import TransformerPolicy
from ..module_inject.auto_tp import AutoTP

DS_INFERENCE_ENABLED = False
from torch import nn
@@ -124,6 +125,15 @@ def __init__(self, model, config):
                self._apply_injection_policy(config, client_module)
        elif config.replace_method == 'auto':
            self._apply_injection_policy(config)
        else:
            # Automatic Tensor Parallelism
            parser_dict = AutoTP.tp_parser(model)
            for client_module, injection_policy in parser_dict:
                if isinstance(injection_policy, str):
                    config.injection_policy_tuple = (injection_policy, )
                else:
                    config.injection_policy_tuple = injection_policy
                self._apply_injection_policy(config, client_module)

        device = torch.cuda.current_device()
        self.module.to(device)
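For context, a minimal usage sketch of the new code path (not part of the diff): when the caller supplies no injection policy and sets replace_method to something other than 'auto' (the unit test below passes an empty string), init_inference now falls through to the AutoTP branch above and derives the tensor-slicing policy from the model itself. The model, task, and parallel degree here are illustrative and mirror the unit test added later in this PR; the script is assumed to be launched with one process per GPU (e.g. two ranks).

import torch
import deepspeed
from transformers import pipeline

# Build the pipeline on CPU; DeepSpeed shards the detected linear layers across ranks at init time.
pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-de", device=-1, framework="pt")
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=2,            # tensor-parallel degree (one rank per GPU)
                                      dtype=torch.float16,
                                      replace_method="")    # empty string: take the automatic-TP branch above
pipe.device = torch.device("cuda:0")  # in a multi-rank launch, use the process's local rank here
print(pipe("Hello, my dog is cute"))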
76 changes: 76 additions & 0 deletions deepspeed/module_inject/auto_tp.py
@@ -0,0 +1,76 @@
# Automatic Tensor Parallelism
import re

from torch import nn


class AutoTP():
    def in_module_list(module, module_list):
        for item in module_list:
            if type(item).__name__ == type(module).__name__:
                return True
        return False

    def get_module_list(model):
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    if not mlist:
                        mlist = [module]
                    elif not AutoTP.in_module_list(module, mlist):
                        mlist = mlist + [module]
            else:
                mlist = mlist + AutoTP.get_module_list(child)
        return mlist

    def supported(model):
        unsupported = ['bloom', 'codegen', 'flaubert', 'xlm']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        if key.group(1).lower() in unsupported:
            return False
        return True

    def get_layers(parent, module):
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list = layer_list + [parent + "." + key]
            elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
                layer_list = layer_list + ["ln"]
            else:
                layer_list = layer_list + AutoTP.get_layers(key, submodule)
        return layer_list

    def tp_parser(model):
        policy_list = []
        module_list = []
        layer_list = []
        gem_list = []

        assert AutoTP.supported(model), "Automatic policy not supported for model. Please provide policy."

        module_list = AutoTP.get_module_list(model)
        for module in module_list:
            for key, submodule in module._modules.items():
                if isinstance(submodule, nn.Linear):
                    layer_list = layer_list + ["." + key]
                elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
                    layer_list = layer_list + ["ln"]
                else:
                    layer_list = layer_list + AutoTP.get_layers(key, submodule)
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    if layer_list[i - 1] != 'ln':
                        gem_list = gem_list + [layer_list[i - 1]]
                elif 'out_proj' in layer:
                    gem_list = gem_list + [layer]
            if gem_list != []:
                policy_list.append(tuple([type(module), gem_list]))
                gem_list = []
        return policy_list
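An illustrative sketch of inspecting tp_parser on its own (not part of the diff; the layer names in the comment are hypothetical, not verified output): it returns a list of (transformer-block class, gem_list) tuples, where gem_list names the Linear layers whose outputs need an all-reduce, namely any layer that directly precedes a LayerNorm and any layer whose name contains 'out_proj'. engine.py then applies each gem_list as the injection_policy_tuple for the matching module class.

from transformers import AutoModelForSeq2SeqLM
from deepspeed.module_inject.auto_tp import AutoTP

# Marian is the model exercised by the new unit test; any model accepted by AutoTP.supported works.
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
for block_class, gem_list in AutoTP.tp_parser(model):
    # e.g. MarianEncoderLayer ['self_attn.out_proj', '.fc2']  (illustrative)
    print(block_class.__name__, gem_list)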
79 changes: 78 additions & 1 deletion tests/unit/inference/test_inference.py
@@ -60,6 +60,8 @@
"token-classification",
"text-generation",
"text2text-generation",
"summarization",
"translation"
]
pytest.all_models = {
task: [m.modelId for m in _all_models if m.pipeline_tag == task]
@@ -139,8 +141,19 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
@pytest.fixture
def query(model_w_task):
    model, task = model_w_task
    angle_bracket_mask_models = [
        "roberta",
        "camembert",
        "esm",
        "ibert",
        "luke",
        "mpnet",
        "yoso",
        "mpnet"
    ]

    if task == "fill-mask":
        if "roberta" in model:
        if any(map(lambda x: x in model, angle_bracket_mask_models)):
            return "Hello I'm a <mask> model."
        else:
            return "Hell I'm a [MASK] model."
@@ -157,6 +170,8 @@ def query(model_w_task):
return "DeepSpeed is the greatest"
elif task == "text2text-generation":
return "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif task == "translation" or task == "summarization":
return "Hello, my dog is cute"
else:
NotImplementedError(f'query for task "{task}" is not implemented')

@@ -199,6 +214,15 @@ def text2text_generation_assert(x, y):
                                                              for res in y)


def translation_assert(x, y):
    return set(res["translation_text"] for res in x) == set(res["translation_text"]
                                                             for res in y)


def summarization_assert(x, y):
    return set(res["summary_text"] for res in x) == set(res["summary_text"] for res in y)


@pytest.fixture
def assert_fn(model_w_task):
    model, task = model_w_task
@@ -209,6 +233,8 @@ def assert_fn(model_w_task):
"token-classification": token_classification_assert,
"text-generation": text_generation_assert,
"text2text-generation": text2text_generation_assert,
"translation": translation_assert,
"summarization": summarization_assert
    }
    assert_fn = assert_fn_dict.get(task, None)
    if assert_fn is None:
@@ -415,6 +441,57 @@ def test(
        assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize(
    "model_w_task",
    [
        ("Helsinki-NLP/opus-mt-en-de", "translation"),
    ],
    ids=[
        "marian",
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16], ids=["fp16"])
@pytest.mark.parametrize("enable_cuda_graph", [False], ids=["noCG"])
class TestAutoTensorParallelism(DistributedTest):
    world_size = [2]

    def test(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
        invalid_model_task_config,
        dtype,
        enable_cuda_graph,
    ):
        if invalid_model_task_config:
            pytest.skip(invalid_model_task_config)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        # We have to load these large models on CPU with pipeline because not
        # enough GPU memory
        pipe = pipeline(task, model=model, device=-1, framework="pt")
        bs_output = pipe(query, **inf_kwargs)

        pipe.model = deepspeed.init_inference(pipe.model,
                                              mp_size=world_size,
                                              dtype=dtype,
                                              replace_method="")
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(f"cuda:{local_rank}")
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)


@pytest.mark.nightly
@pytest.mark.parametrize(
"model_family, model_name",