
Add half-precision (bfloat16, float16) support to train & validate scripts #2397

Merged 2 commits on Jan 7, 2025

26 changes: 13 additions & 13 deletions timm/data/loader.py
@@ -77,18 +77,18 @@ class PrefetchLoader:

def __init__(
self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
channels=3,
device=torch.device('cuda'),
img_dtype=torch.float32,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):

loader: torch.utils.data.DataLoader,
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
channels: int = 3,
device: torch.device = torch.device('cuda'),
img_dtype: Optional[torch.dtype] = None,
fp16: bool = False,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_num_splits: int = 0,
):
mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
normalization_shape = (1, channels, 1, 1)
@@ -98,7 +98,7 @@ def __init__(
if fp16:
# fp16 arg is deprecated, but will override dtype arg if set for bwd compat
img_dtype = torch.float16
self.img_dtype = img_dtype
self.img_dtype = img_dtype or torch.float32
self.mean = torch.tensor(
[x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
self.std = torch.tensor(
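A note on the loader change above: the deprecated `fp16` flag still wins and forces `torch.float16`, while an unset `img_dtype` now falls back to `torch.float32`. The sketch below restates that resolution order with a hypothetical `resolve_img_dtype` helper written purely for illustration; it is not part of timm.

```python
import torch

def resolve_img_dtype(img_dtype=None, fp16=False):
    # Deprecated fp16 flag overrides any explicit dtype, for backward compat.
    if fp16:
        img_dtype = torch.float16
    # An unset dtype falls back to float32, matching PrefetchLoader above.
    return img_dtype or torch.float32

assert resolve_img_dtype() == torch.float32
assert resolve_img_dtype(img_dtype=torch.bfloat16) == torch.bfloat16
assert resolve_img_dtype(img_dtype=torch.bfloat16, fp16=True) == torch.float16
```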
36 changes: 28 additions & 8 deletions train.py
@@ -178,6 +178,8 @@
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--model-dtype', default=None, type=str,
help='Model dtype override (non-AMP) (default: float32)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -434,10 +436,18 @@ def main():
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0

model_dtype = None
if args.model_dtype:
assert args.model_dtype in ('float32', 'float16', 'bfloat16')
model_dtype = getattr(torch, args.model_dtype)
if model_dtype == torch.float16:
_logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')

# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
if args.amp:
assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
use_amp = 'apex'
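The block above turns the `--model-dtype` string into a real `torch.dtype` via `getattr(torch, ...)`, warns that bfloat16 is the recommended half precision for training, and rejects a non-float32 model dtype when AMP is enabled, since autocast expects float32 parameters. A minimal sketch of that logic as a standalone helper (the function name and signature are invented here for illustration):

```python
from typing import Optional
import torch

def resolve_model_dtype(name: Optional[str], use_amp: bool) -> Optional[torch.dtype]:
    """Illustrative helper mirroring the --model-dtype handling above."""
    if not name:
        return None  # model keeps its default float32 parameters
    assert name in ('float32', 'float16', 'bfloat16')
    dtype = getattr(torch, name)
    if use_amp:
        # AMP autocast manages low-precision compute itself, so the
        # underlying parameters must stay float32 in that case.
        assert dtype == torch.float32, 'float32 model dtype must be used with AMP'
    return dtype

print(resolve_model_dtype('bfloat16', use_amp=False))  # torch.bfloat16
```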
@@ -517,7 +527,7 @@ def main():
model = convert_splitbn_model(model, max(num_aug_splits, 2))

# move model to GPU, enable channels last layout if set
model.to(device=device)
model.to(device=device, dtype=model_dtype) # FIXME move model device & dtype into create_model
if args.channels_last:
model.to(memory_format=torch.channels_last)

@@ -587,7 +597,7 @@ def main():
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.')
_logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')

# optionally resume from a checkpoint
resume_epoch = None
@@ -732,6 +742,7 @@ def main():
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
img_dtype=model_dtype,
device=device,
use_prefetcher=args.prefetcher,
use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -756,6 +767,7 @@ def main():
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
img_dtype=model_dtype,
device=device,
use_prefetcher=args.prefetcher,
)
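Both `create_loader` calls now receive `img_dtype=model_dtype`, so with the prefetcher enabled the batches already arrive in the model's dtype. A rough usage sketch follows; the dataset spec, input size, and batch size are placeholders, and `use_prefetcher=True` assumes a CUDA device:

```python
import torch
from timm.data import create_dataset, create_loader

dataset_eval = create_dataset('torch/cifar10', root='./data', split='validation', download=True)
loader_eval = create_loader(
    dataset_eval,
    input_size=(3, 224, 224),
    batch_size=64,
    is_training=False,
    use_prefetcher=True,        # PrefetchLoader performs the dtype conversion
    img_dtype=torch.bfloat16,   # batches come off the loader already in bfloat16
    device=torch.device('cuda'),
)
```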
@@ -823,9 +835,13 @@ def main():
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
assert not args.wandb_resume_id or args.resume
wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
resume='must' if args.wandb_resume_id else None,
id=args.wandb_resume_id if args.wandb_resume_id else None)
wandb.init(
project=args.experiment,
config=args,
tags=args.wandb_tags,
resume='must' if args.wandb_resume_id else None,
id=args.wandb_resume_id if args.wandb_resume_id else None,
)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
@@ -879,6 +895,7 @@ def main():
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_dtype=model_dtype,
model_ema=model_ema,
mixup_fn=mixup_fn,
num_updates_total=num_epochs * updates_per_epoch,
@@ -897,6 +914,7 @@ def main():
args,
device=device,
amp_autocast=amp_autocast,
model_dtype=model_dtype,
)

if model_ema is not None and not args.model_ema_force_cpu:
@@ -979,6 +997,7 @@ def train_one_epoch(
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_dtype=None,
model_ema=None,
mixup_fn=None,
num_updates_total=None,
@@ -1015,7 +1034,7 @@
accum_steps = last_accum_steps

if not args.prefetcher:
input, target = input.to(device), target.to(device)
input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
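With the prefetcher disabled, the cast instead happens on the host-to-device copy: images are converted to `model_dtype` (a `dtype=None` leaves float32 inputs untouched) while the integer class targets keep their dtype. A small self-contained check of that pattern:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = torch.bfloat16  # what --model-dtype bfloat16 resolves to

# Stand-in for one (input, target) batch from a plain DataLoader.
input = torch.randn(8, 3, 224, 224)
target = torch.randint(0, 1000, (8,))

input = input.to(device=device, dtype=model_dtype)  # images cast to the model dtype
target = target.to(device=device)                   # labels stay int64

assert input.dtype == torch.bfloat16 and target.dtype == torch.int64
```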
@@ -1142,6 +1161,7 @@ def validate(
args,
device=torch.device('cuda'),
amp_autocast=suppress,
model_dtype=None,
log_suffix=''
):
batch_time_m = utils.AverageMeter()
@@ -1157,8 +1177,8 @@
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.to(device)
target = target.to(device)
input = input.to(device=device, dtype=model_dtype)
target = target.to(device=device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)

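Taken together, the train.py changes enable a plain low-precision path with no autocast and no loss scaler: model, inputs, and gradients all share one dtype. The toy step below illustrates that path with a tiny linear model and random data; it is purely illustrative and uses bfloat16, the dtype the warning above recommends for training:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = torch.bfloat16

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
model.to(device=device, dtype=model_dtype)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

input = torch.randn(4, 3, 32, 32, device=device, dtype=model_dtype)
target = torch.randint(0, 10, (4,), device=device)

loss = F.cross_entropy(model(input), target)
loss.backward()        # gradients are bfloat16 too; no GradScaler involved
optimizer.step()
optimizer.zero_grad()
print(loss.item())
```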
19 changes: 14 additions & 5 deletions validate.py
@@ -123,6 +123,8 @@
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
parser.add_argument('--model-dtype', default=None, type=str,
help='Model dtype override (non-AMP) (default: float32)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -168,10 +170,16 @@ def validate(args):

device = torch.device(args.device)

model_dtype = None
if args.model_dtype:
assert args.model_dtype in ('float32', 'float16', 'bfloat16')
model_dtype = getattr(torch, args.model_dtype)

# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_autocast = suppress
if args.amp:
assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
assert args.amp_dtype == 'float16'
@@ -184,7 +192,7 @@ def validate(args):
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
_logger.info('Validating in mixed precision with native PyTorch AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
_logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')

if args.fuser:
set_jit_fuser(args.fuser)
@@ -231,7 +239,7 @@ def validate(args):
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config)

model = model.to(device)
model = model.to(device=device, dtype=model_dtype) # FIXME move model device & dtype into create_model
if args.channels_last:
model = model.to(memory_format=torch.channels_last)

@@ -299,6 +307,7 @@ def validate(args):
crop_border_pixels=args.crop_border_pixels,
pin_memory=args.pin_mem,
device=device,
img_dtype=model_dtype,
tf_preprocessing=args.tf_preprocessing,
)

@@ -310,7 +319,7 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
@@ -319,8 +328,8 @@ def validate(args):
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.to(device)
input = input.to(device)
target = target.to(device=device)
input = input.to(device=device, dtype=model_dtype)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)

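The validate.py side follows the same pattern end to end: build the model, cast it once to the requested dtype, and keep every input tensor in that dtype. A rough standalone sketch; the model name and shapes are illustrative, and it assumes a PyTorch build with bfloat16 kernels for the ops involved:

```python
import timm
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = torch.bfloat16  # equivalent of passing --model-dtype bfloat16

model = timm.create_model('resnet18', pretrained=False, num_classes=1000)
model = model.to(device=device, dtype=model_dtype).eval()

with torch.no_grad():
    input = torch.randn(2, 3, 224, 224, device=device, dtype=model_dtype)
    logits = model(input)  # runs entirely in bfloat16, no autocast context

print(logits.dtype)  # torch.bfloat16
```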