Commit 7ea1c8b

skip cpu support unimplemented error and update cpu inference workflow

Yejing-Lai committed Jun 6, 2023
1 parent d755b9d
Showing 20 changed files with 80 additions and 12 deletions.
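The pattern repeated across the test files below is twofold: modules that need a compiled op query deepspeed.ops.__compatible_ops__ for the relevant builder, and modules that need fp16 query the accelerator's new supported_dtypes() list; if either check fails, the whole module is skipped. A minimal sketch combining both checks (the builder choice and skip messages here are illustrative):

import pytest
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

# Skip every test in this module if the inference op was not built for this accelerator.
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system", allow_module_level=True)

# Skip every test in this module if the accelerator cannot run fp16.
if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtypes: {get_accelerator().supported_dtypes()}", allow_module_level=True)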
3 changes: 1 addition & 2 deletions .github/workflows/cpu-inference.yml
@@ -79,5 +79,4 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference' unit/inference/test_inference_config.py
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -k TestDistAllReduce unit/comm/test_dist.py
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' -m 'inference_ops' -m 'inference' unit/
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -156,6 +156,10 @@ def is_bf16_supported(self):
def is_fp16_supported(self):
...

@abc.abstractmethod
def supported_dtypes(self):
...

# Misc
@abc.abstractmethod
def amp(self):
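Every concrete accelerator must now implement supported_dtypes() alongside is_fp16_supported()/is_bf16_supported(). A minimal sketch of what a backend might return (the class below is a standalone illustration, not the real abstract base; the actual CPU and CUDA implementations follow in the next two files):

import torch

class HypotheticalAccelerator:
    # In practice this would subclass the abstract accelerator and implement
    # every abstract method; only the dtype-related queries are sketched here.
    def is_fp16_supported(self):
        return False

    def is_bf16_supported(self):
        return True

    def supported_dtypes(self):
        # Keep the list consistent with the answers above; tests use it to
        # decide whether to run or skip fp16/bf16 cases.
        return [torch.float, torch.bfloat16]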
5 changes: 4 additions & 1 deletion accelerator/cpu_accelerator.py
@@ -181,7 +181,10 @@ def is_bf16_supported(self):
return True

def is_fp16_supported(self):
return True
return False

def supported_dtypes(self):
return [torch.float, torch.bfloat16]

# Tensor operations

3 changes: 3 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -147,6 +147,9 @@ def is_fp16_supported(self):
else:
return False

def supported_dtypes(self):
return [torch.float, torch.half]

# Misc
def amp(self):
if hasattr(torch.cuda, 'amp'):
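With the CPU and CUDA backends both reporting their dtypes, callers can choose a working low-precision dtype instead of hard-coding torch.half. A hedged usage sketch (the fallback order is an assumption, not something this commit prescribes):

import torch
from deepspeed.accelerator import get_accelerator

dtypes = get_accelerator().supported_dtypes()
# Prefer fp16, fall back to bf16, otherwise stay in fp32.
if torch.half in dtypes:
    dtype = torch.half
elif torch.bfloat16 in dtypes:
    dtype = torch.bfloat16
else:
    dtype = torch.float
x = torch.ones(4, dtype=dtype, device=get_accelerator().device_name())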
1 change: 1 addition & 0 deletions setup.py
@@ -154,6 +154,7 @@ def op_enabled(op_name):
for op_name, builder in ALL_OPS.items():
op_compatible = builder.is_compatible()
compatible_ops[op_name] = op_compatible
compatible_ops["deepspeed_not_implemented"] = False

# If op is requested but not available, throw an error.
if op_enabled(op_name) and not op_compatible:
5 changes: 5 additions & 0 deletions tests/unit/checkpoint/test_latest_checkpoint.py
@@ -5,10 +5,15 @@

import deepspeed

import pytest
from unit.common import DistributedTest
from unit.simple_model import *

from unit.checkpoint.common import checkpoint_correctness_verification
from deepspeed.ops.op_builder import FusedAdamBuilder

if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)


class TestLatestCheckpoint(DistributedTest):
4 changes: 4 additions & 0 deletions tests/unit/comm/test_dist.py
@@ -13,6 +13,10 @@
from deepspeed.accelerator import get_accelerator

import pytest
from deepspeed.ops.op_builder import FusedAdamBuilder

if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)


class TestInit(DistributedTest):
4 changes: 4 additions & 0 deletions tests/unit/elasticity/test_elastic.py
@@ -9,6 +9,10 @@
from deepspeed.git_version_info import version as ds_version
import os
from unit.simple_model import SimpleModel
from deepspeed.ops.op_builder import FusedAdamBuilder

if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)


@pytest.fixture
8 changes: 6 additions & 2 deletions tests/unit/hybrid_engine/test_he_all.py
@@ -9,9 +9,13 @@
import deepspeed
from deepspeed.ops.op_builder import OpBuilder
from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import UtilsBuilder

from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM)

if not deepspeed.ops.__compatible_ops__[UtilsBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
pytest.skip("skip inference tests on rocm for now", allow_module_level=True)
@@ -28,7 +32,7 @@ def _generate(self, model, tokenizer, prompt):
tokens = tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True)
for t in tokens:
if torch.is_tensor(tokens[t]):
tokens[t] = tokens[t].to(f'cuda:{local_rank}')
tokens[t] = tokens[t].to(f'{get_accelerator().device_name()}:{local_rank}')
output = model.generate(**tokens, do_sample=False, max_length=100)
outputs = tokenizer.batch_decode(output, skip_special_tokens=True)
return outputs
@@ -39,7 +43,7 @@ def get_model(self, model_name):
model_config.dropout = 0.0
model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)
model = model.half()
model = model.to(f'cuda:{local_rank}')
model = model.to(f'{get_accelerator().device_name()}:{local_rank}')
return model

def get_tokenizer(self, model_name):
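The device string is now built through the accelerator abstraction instead of a hard-coded 'cuda' prefix, so the same test can run on the CPU backend. A minimal sketch of the pattern outside the test class (the local_rank value is illustrative):

import torch
from deepspeed.accelerator import get_accelerator

local_rank = 0
device = f'{get_accelerator().device_name()}:{local_rank}'  # e.g. 'cuda:0' on NVIDIA GPUs, 'cpu:0' on the CPU backend
tokens = torch.ones(2, 8, dtype=torch.long)
tokens = tokens.to(device)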
8 changes: 6 additions & 2 deletions tests/unit/hybrid_engine/test_he_llama.py
@@ -9,9 +9,13 @@
import deepspeed
from deepspeed.ops.op_builder import OpBuilder
from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import UtilsBuilder

from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM)

if not deepspeed.ops.__compatible_ops__[UtilsBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
pytest.skip("skip inference tests on rocm for now", allow_module_level=True)
@@ -28,7 +32,7 @@ def _generate(self, model, tokenizer, prompt):
tokens = tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True)
for t in tokens:
if torch.is_tensor(tokens[t]):
tokens[t] = tokens[t].to(f'cuda:{local_rank}')
tokens[t] = tokens[t].to(f'{get_accelerator().device_name()}:{local_rank}')
#output = model.generate(**tokens, do_sample=False, max_length=100)
output = model.generate(tokens.input_ids, do_sample=False, max_length=100)
outputs = tokenizer.batch_decode(output, skip_special_tokens=True)
@@ -42,7 +46,7 @@ def get_model(self, model_name):
# Make the model smaller so we can run it on a single GPU in CI
_ = [model.model.layers.pop(-1) for _ in range(8)]
model = model.half()
model = model.to(f'cuda:{local_rank}')
model = model.to(f'{get_accelerator().device_name()}:{local_rank}')
return model

def get_tokenizer(self, model_name):
4 changes: 4 additions & 0 deletions tests/unit/inference/test_checkpoint_sharding.py
@@ -13,6 +13,10 @@
import deepspeed.comm as dist
from huggingface_hub import snapshot_download
from transformers.utils import is_offline_mode
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)


def check_dtype(model, expected_dtype):
4 changes: 4 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -20,6 +20,10 @@
from deepspeed.model_implementations import DeepSpeedTransformerInference
from torch import nn
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
4 changes: 4 additions & 0 deletions tests/unit/inference/test_model_profiling.py
@@ -11,6 +11,10 @@
from transformers import pipeline
from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)


@pytest.fixture
2 changes: 2 additions & 0 deletions tests/unit/ops/accelerators/test_accelerator_backward.py
@@ -19,6 +19,8 @@
#pytest.skip(
# "transformer kernels are temporarily disabled because of unexplained failures",
# allow_module_level=True)
if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):
3 changes: 3 additions & 0 deletions tests/unit/ops/accelerators/test_accelerator_forward.py
@@ -15,6 +15,9 @@
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):
if verbose:
3 changes: 3 additions & 0 deletions tests/unit/ops/adam/test_adamw.py
@@ -11,7 +11,10 @@
from deepspeed.ops.adam import DeepSpeedCPUAdam
from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.accelerator import get_accelerator

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
# yapf: disable
#'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer
adam_configs = [["AdamW", False, False, False, (FusedAdam, True)],
8 changes: 6 additions & 2 deletions tests/unit/ops/quantizer/test_fake_quantization.py
@@ -5,8 +5,12 @@

import torch
import pytest
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.ops import op_builder
from deepspeed.ops.op_builder import QuantizerBuilder

if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

quantizer_cuda_module = None

@@ -36,7 +40,7 @@ def run_quant_dequant(inputs, groups, bits):
global quantizer_cuda_module

if quantizer_cuda_module is None:
quantizer_cuda_module = op_builder.QuantizerBuilder().load()
quantizer_cuda_module = QuantizerBuilder().load()
return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits)


10 changes: 7 additions & 3 deletions tests/unit/ops/quantizer/test_quantize.py
@@ -5,16 +5,20 @@

import pytest
import torch
from deepspeed.ops import op_builder
import deepspeed
from deepspeed.ops.op_builder import QuantizerBuilder
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

inference_module = None


def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant):
global inference_module
if inference_module is None:
inference_module = op_builder.QuantizerBuilder().load()
inference_module = QuantizerBuilder().load()

return inference_module.quantize(activations, num_groups, q_bits,
inference_module.Symmetric if is_symmetric_quant else inference_module.Asymmetric)
@@ -23,7 +27,7 @@ def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant):
def run_dequantize_ds(activations, params, num_groups, q_bits, is_symmetric_quant):
global inference_module
if inference_module is None:
inference_module = op_builder.QuantizerBuilder().load()
inference_module = QuantizerBuilder().load()
return inference_module.dequantize(
activations,
params,
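Both quantizer test files load the extension through the builder and cache it in a module-level global, so the JIT build happens at most once per process. A minimal sketch of that lazy-load pattern (the helper function is illustrative; the tests inline the same check):

from deepspeed.ops.op_builder import QuantizerBuilder

quantizer_module = None

def get_quantizer_module():
    # Build (or load a prebuilt copy of) the quantizer extension on first use.
    global quantizer_module
    if quantizer_module is None:
        quantizer_module = QuantizerBuilder().load()
    return quantizer_module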
5 changes: 5 additions & 0 deletions tests/unit/ops/spatial/test_nhwc_bias_add.py
@@ -5,9 +5,14 @@

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import SpatialInferenceBuilder
from deepspeed.ops.transformer.inference.bias_add import nhwc_bias_add
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[SpatialInferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)


def allclose(x, y):
assert x.dtype == y.dtype
4 changes: 4 additions & 0 deletions tests/unit/profiling/flops_profiler/test_flops_profiler.py
@@ -10,6 +10,10 @@
from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest
from unit.util import required_minimum_torch_version
from deepspeed.accelerator import get_accelerator

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)

pytestmark = pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=3),
reason='requires Pytorch version 1.3 or above')
