Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend HE-Lora test with Z3 support + Fix/add guard in HE for Z3 #3883

Merged
merged 5 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions deepspeed/runtime/hybrid_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,9 @@ def step(self, lr_kwargs=None):
super().step(lr_kwargs=lr_kwargs)

if len(self._inference_containers) > 0:
for inference_container in self._inference_containers:
inference_container.reset_params()
if not self.Z3_enabled:
for inference_container in self._inference_containers:
inference_container.reset_params()

if self._training_start_time is not None:
self._training_latency += (time.time() - self._training_start_time)
Expand Down
44 changes: 36 additions & 8 deletions tests/unit/hybrid_engine/test_he_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch.nn.functional as F
import pytest
import deepspeed
from deepspeed.runtime.zero import GatheredParameters
from deepspeed.ops.op_builder import OpBuilder
from deepspeed.ops.adam import FusedAdam
from deepspeed.utils import safe_get_full_grad
import numpy.testing as npt
from unit.common import DistributedTest
Expand Down Expand Up @@ -109,7 +109,9 @@ def only_optimize_lora_parameters(model):

@pytest.mark.seq_inference
@pytest.mark.parametrize("batch_size", [1], ids=["bsz=1"])
@pytest.mark.parametrize("zero_stage", [2, 3], ids=["zero_stage=2", "zero_stage=3"])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neo-125m", "facebook/opt-350m", "bigscience/bloom-560m"])
@pytest.mark.parametrize("offload_device", ["none", "cpu"])
class TestHybridEngineLoRA(DistributedTest):
world_size = 1

Expand Down Expand Up @@ -139,20 +141,42 @@ def get_train_sentences(self, batch_size):
else:
raise NotImplementedError(f"batch_size {batch_size} not implemented")

def test_lora(self, batch_size, model_name):
def test_lora(self, batch_size, model_name, zero_stage, offload_device):
local_rank = int(os.getenv("LOCAL_RANK", "0"))

model = self.get_model(model_name)
tokenizer = self.get_tokenizer(model_name)
train_sentences = self.get_train_sentences(batch_size)

# Inject LoRA
model = convert_linear_layer_to_lora(model, "", 8)
model = only_optimize_lora_parameters(model)
optim = FusedAdam([p for p in model.parameters() if p.requires_grad], lr=1.0, betas=(0.9, 0.95))
ds_config = {"train_batch_size": batch_size, "bfp16": {"enabled": True}, "hybrid_engine": {"enabled": True}}

model, *_ = deepspeed.initialize(model=model, optimizer=optim, config=ds_config)
ds_config = {
"optimizer": {
"type": "Adam",
"params": {
"lr": 1.0,
"betas": [0.9, 0.95]
}
},
"train_batch_size": batch_size,
"fp16": {
"enabled": True,
"initial_scale_power": 12
},
"hybrid_engine": {
"enabled": True,
"pin_parameters": True
},
"zero_optimization": {
"stage": zero_stage,
"offload_optimizer": {
"device": offload_device
}
}
}

model, *_ = deepspeed.initialize(model=model, config=ds_config)

# Verify gradient norm is larger than 0
before_grad_update_layer0_params = [
Expand Down Expand Up @@ -187,7 +211,9 @@ def test_lora(self, batch_size, model_name):

# Verify fuse will mutate layer_params
model.eval()
model.fuse_lora_weight()
with GatheredParameters(model.parameters()):
model.fuse_lora_weight()

after_grad_update_layer0_params_lora_fused = [
ele.detach().cpu().float().numpy() for ele in model.layer_params[0]
if ele is not None and len(ele.shape) > 1
Expand All @@ -196,4 +222,6 @@ def test_lora(self, batch_size, model_name):
for lhs, rhs in zip(before_grad_update_layer0_params, after_grad_update_layer0_params_lora_fused):
with pytest.raises(AssertionError):
npt.assert_allclose(lhs, rhs, 1E-5, 1E-5)
model.unfuse_lora_weight()

with GatheredParameters(model.parameters()):
model.unfuse_lora_weight()