update precision interface
sangkeun00 committed May 25, 2023
1 parent 334bb18 commit e2a7651
Showing 7 changed files with 44 additions and 48 deletions.
8 changes: 5 additions & 3 deletions betty/configs/problem_dataclass.py
@@ -19,8 +19,8 @@ class Config:
     # gradient clipping
     gradient_clipping: float = 0.0
 
-    # fp16 training
-    fp16: bool = False
+    # precision
+    precision: str = "fp32"
     initial_dynamic_scale: float = 4096.0
     scale_factor: float = 2.0
 
@@ -33,7 +33,9 @@ class Config:
 
     # darts
     darts_alpha: float = 0.01
-    darts_adam_alpha: float = 1.0
+
+    # sama
+    sama_adam_alpha: float = 1.0
 
     # neumann
     neumann_iterations: int = 1
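
For context, a minimal sketch of how the renamed option might be passed when constructing a problem config (illustrative only, not part of the commit; the `type` value mirrors the darts example further down):

    # Old interface (removed in this commit):
    #   config = Config(type="darts", fp16=True)
    # New interface: precision is selected by name.
    from betty.configs import Config

    config = Config(type="darts", precision="fp16")  # "fp32" (default), "fp16", or "bf16"
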
2 changes: 1 addition & 1 deletion betty/hypergradient/sama.py
@@ -20,7 +20,7 @@ def sama(vector, curr, prev, sync):
     :rtype: Sequence of Tensor
     """
     config = curr.config
-    R = config.darts_adam_alpha
+    R = config.sama_adam_alpha
     vector = precondition(vector, curr)
     eps = R / to_vec(vector).norm().add_(1e-12).item()
 
9 changes: 3 additions & 6 deletions betty/problems/implicit_problem.py
@@ -37,19 +37,16 @@ def __init__(
 
     def optimizer_step(self, *args, **kwargs):
         if self.is_implemented("custom_optimizer_step"):
-            assert (
-                not self._is_default_fp16()
-            ), "[!] FP16 training is not supported for custom optimizer step."
             if self.gradient_clipping > 0.0:
                 self.clip_grad()
             self.custom_optimizer_step(*args, **kwargs)
         else:
-            if self._is_default_fp16():
+            if self.scaler is not None:
                 self.scaler.unscale_(self.optimizer)
                 if self.gradient_clipping > 0.0:
                     self.clip_grad()
                 self.scaler.step(self.optimizer)
-                if self.config.type == "darts_adam":
+                if self.config.type == "sama":
                     for param in self.trainable_parameters():
                         state = self.get_opt_state_for_param(param)
                         if param.grad is not None and len(state) != 0:
@@ -59,7 +56,7 @@ def optimizer_step(self, *args, **kwargs):
                 if self.gradient_clipping > 0.0:
                     self.clip_grad()
                 self.optimizer.step()
-                if self.config.type == "darts_adam":
+                if self.config.type == "sama":
                     for param in self.trainable_parameters():
                         state = self.get_opt_state_for_param(param)
                         if param.grad is not None and len(state) != 0:
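
As a rough sketch of the branching above (stand-in names, not the library code; `scaler.update()` follows the standard torch.cuda.amp pattern and does not appear in this diff):

    import torch

    def step_with_optional_scaler(optimizer, scaler, clip_fn=None):
        # fp16 path: gradients were scaled, so unscale before clipping and stepping
        if scaler is not None:
            scaler.unscale_(optimizer)
            if clip_fn is not None:
                clip_fn()
            scaler.step(optimizer)
            scaler.update()
        # fp32/bf16 path: plain clip + step
        else:
            if clip_fn is not None:
                clip_fn()
            optimizer.step()
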
36 changes: 13 additions & 23 deletions betty/problems/problem.py
@@ -14,7 +14,7 @@
 from betty.patch.scheduler import patch_scheduler
 from betty.configs import Config
 from betty.hypergradient import get_grads
-from betty.utils import convert_tensor, log_from_loss_dict
+from betty.utils import convert_tensor, log_from_loss_dict, get_dtype
 
 
 class Problem:
@@ -74,10 +74,11 @@ def __init__(
         # environment
         self.env = None
 
-        # fp16 scaler
-        self._fp16 = config.fp16
+        # precision
+        self.precision = config.precision
+        self.dtype = get_dtype(self.precision)
         self.scaler = None
-        if self._fp16:
+        if self.precision == "fp16":
             self.initial_dynamic_scale = config.initial_dynamic_scale
             self.scale_factor = config.scale_factor
 
@@ -161,7 +162,7 @@ def initialize(self):
             self.scheduler = self.configure_scheduler()
 
         # set up fp16 training
-        if self._is_default_fp16():
+        if self.precision == "fp16" and self._strategy != "accelerate":
             assert torch.cuda.is_available()
             scaler_cls = torch.cuda.amp.GradScaler
             if self._strategy == "fsdp":
@@ -311,8 +312,8 @@ def training_step(self, batch):
         raise NotImplementedError
 
     def training_step_exec(self, batch):
-        if self._is_default_fp16():
-            with torch.cuda.amp.autocast():
+        if self.precision in ["fp16", "bf16"] and self._strategy != "accelerate":
+            with torch.cuda.amp.autocast(dtype=self.dtype):
                 return self.training_step(batch)
         else:
             return self.training_step(batch)
@@ -472,13 +473,10 @@ def get_batch_single_loader(self, idx):
             self.train_data_iterator[idx] = iter(train_data_loader)
         batch = next(self.train_data_iterator[idx])
         if not isinstance(batch, dict):
-            batch = tuple(
-                convert_tensor(value, self.device, self._is_default_fp16())
-                for value in batch
-            )
+            batch = tuple(convert_tensor(value, self.device) for value in batch)
         else:
             for key, value in batch.items():
-                batch[key] = convert_tensor(value, self.device, self._is_default_fp16())
+                batch[key] = convert_tensor(value, self.device)
 
         return batch
 
@@ -494,7 +492,7 @@ def get_loss(self, batch):
         is_dict = isinstance(maybe_loss_dict, dict)
         loss = maybe_loss_dict["loss"] if is_dict else maybe_loss_dict
         loss_no_scale = loss.item()
-        if self._is_default_fp16():
+        if self.scaler is not None:
             loss = self.scaler.scale(loss)
         loss = loss / self.gas
 
@@ -632,7 +630,7 @@ def state_dict(self):
         state_dict["optimizer"] = self.optimizer.state_dict()
         if self.scheduler is not None:
             state_dict["scheduler"] = self.scheduler.state_dict()
-        if self._is_default_fp16():
+        if self.scaler is not None:
             state_dict["scaler"] = self.scaler.state_dict()
 
         return state_dict
@@ -647,7 +645,7 @@ def load_state_dict(self, state_dict):
         self.optimizer.load_state_dict(state_dict["optimizer"])
         if self.scheduler is not None and "scheduler" in state_dict:
             self.scheduler.load_state_dict(state_dict["scheduler"])
-        if self._is_default_fp16() and "scaler" in state_dict:
+        if self.scaler is not None and "scaler" in state_dict:
             self.scaler.load_state_dict(state_dict["scaler"])
 
     def configure_distributed_training(self, dictionary):
@@ -732,14 +730,6 @@ def gradient_accumulation_boundary(self):
         """
         return bool(self._count % self.gas == 0)
 
-    def _is_default_fp16(self):
-        """
-        Check whether to use PyTorch native fp16 (mixed-precision) feature
-        """
-        if not self._fp16 or self._strategy in ["accelerate"]:
-            return False
-        return True
-
     def is_implemented(self, fn_name):
         """
         Check if ``fn_name`` method is implemented in the class
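
Condensed, the Problem-side change maps the precision string to a dtype once, creates a GradScaler only for fp16, and wraps the forward pass in autocast; a simplified sketch under those assumptions (stand-in function names, not the actual class code):

    import torch
    from betty.utils import get_dtype

    def setup_precision(config, strategy="default"):
        # Only fp16 needs loss scaling; bf16 runs under autocast without a scaler.
        dtype = get_dtype(config.precision)
        scaler = None
        if config.precision == "fp16" and strategy != "accelerate":
            scaler = torch.cuda.amp.GradScaler()
        return dtype, scaler

    def run_training_step(training_step, batch, precision, dtype, strategy="default"):
        # fp16/bf16 forward passes run under autocast with the chosen dtype.
        if precision in ["fp16", "bf16"] and strategy != "accelerate":
            with torch.cuda.amp.autocast(dtype=dtype):
                return training_step(batch)
        return training_step(batch)
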
11 changes: 10 additions & 1 deletion betty/utils.py
@@ -1,12 +1,21 @@
 import torch
 
 
-def convert_tensor(item, device=None, fp16=False):
+def convert_tensor(item, device=None):
     if not isinstance(item, torch.Tensor):
         return item
     return item.to(device)
 
 
+def get_dtype(precision):
+    if precision == "fp16":
+        return torch.float16
+    elif precision == "bf16":
+        return torch.bfloat16
+    else:
+        return torch.float32
+
+
 def get_grad_norm(parameters):
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
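
A quick usage check of the new helper (example only, not part of the commit):

    import torch
    from betty.utils import get_dtype

    assert get_dtype("fp16") is torch.float16
    assert get_dtype("bf16") is torch.bfloat16
    assert get_dtype("fp32") is torch.float32  # any other string also falls back to float32
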
7 changes: 3 additions & 4 deletions examples/bert_data_reweighting/main.py
@@ -17,7 +17,7 @@
 
 parser = argparse.ArgumentParser(description="Meta_Weight_Net")
 parser.add_argument("--baseline", action="store_true")
-parser.add_argument("--fp16", action="store_true")
+parser.add_argument("--precision", type=str, default="fp32")
 parser.add_argument("--strategy", type=str, default="default")
 parser.add_argument("--local_rank", type=int, default=0)
 parser.add_argument("--rollback", action="store_true")
@@ -165,14 +165,13 @@ def validation(self):
 )
 finetune_config = Config(
     type="darts",
-    fp16=args.fp16,
+    precision=args.precision,
     retain_graph=True,
     gradient_clipping=5.0,
     log_step=args.valid_step,
     unroll_steps=5,
     darts_preconditioned=False,
 )
-reweight_config = Config(type="darts", fp16=args.fp16)
+reweight_config = Config(type="darts", precision=args.precision)
 
 finetune = Finetune(
     name="finetune",
19 changes: 9 additions & 10 deletions examples/bert_data_reweighting/run.sh
@@ -1,15 +1,14 @@
 # Baseline
-python main.py --batch_size 16 --baseline --fp16 --seed 0
-python main.py --batch_size 16 --baseline --fp16 --seed 1
-python main.py --batch_size 16 --baseline --fp16 --seed 2
-
+python main.py --batch_size 16 --baseline --precision fp16 --seed 0
+python main.py --batch_size 16 --baseline --precision fp16 --seed 1
+python main.py --batch_size 16 --baseline --precision fp16 --seed 2
 # Meta-Weight-Net (single GPU fp 16)
-python main.py --batch_size 16 --fp16 --seed 0
-python main.py --batch_size 16 --fp16 --seed 1
-python main.py --batch_size 16 --fp16 --seed 2
+python main.py --batch_size 16 --precision fp16 --seed 0
+python main.py --batch_size 16 --precision fp16 --seed 1
+python main.py --batch_size 16 --precision fp16 --seed 2
 
 # Meta-Weight-Net (Multi GPU + ZeRO)
-torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py --batch_size 16 --fp16 --strategy zero --seed 0
-torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py --batch_size 16 --fp16 --strategy zero --seed 1
-torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py --batch_size 16 --fp16 --strategy zero --seed 2
+torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py --batch_size 8 --precision fp16 --strategy zero --seed 0
+torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py --batch_size 8 --precision fp16 --strategy zero --seed 1
+torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py --batch_size 8 --precision fp16 --strategy zero --seed 2
