Dropout support #875


Open
wants to merge 73 commits into base: dev

Commits (73), showing changes from all commits
d672b84
add jit-friendly dropout w rate in call
priyakasimbeg Apr 17, 2025
aa25e20
remove nan_to_num conversion
priyakasimbeg May 5, 2025
85a3578
update models with custom dropout layer
priyakasimbeg May 5, 2025
9354079
add functional dropout for criteo, fastmri, and vit
priyakasimbeg May 5, 2025
feb9cc5
add functional dropout for ogbg
priyakasimbeg May 5, 2025
9bba078
modify wmt model for dropout passing
priyakasimbeg May 15, 2025
31f6019
modify wmt model for dropout passing
priyakasimbeg May 15, 2025
e36d294
reformatting and dropout fixes to fastmri and vit
priyakasimbeg May 29, 2025
363da8a
dropout fix for criteo1tb jax
priyakasimbeg May 29, 2025
341bf89
dropout fix for criteo1tb jax
priyakasimbeg May 29, 2025
f0c385b
remove aux dropout option from conformer and from init_model_fn signa…
priyakasimbeg May 29, 2025
7af5c94
add dropout piping for conformer and deepspeech
priyakasimbeg May 31, 2025
cbd065b
pipe dropout through model_fn
priyakasimbeg May 31, 2025
31babfd
fix syntax
priyakasimbeg Jun 4, 2025
95d67db
dropout changes wmt jax
priyakasimbeg Jun 4, 2025
2c96b88
modify dockerfile
priyakasimbeg Jun 4, 2025
54786a6
modify docker build script
priyakasimbeg Jun 4, 2025
246d68e
small fixes
priyakasimbeg Jun 5, 2025
0c8dd14
change docker base image to 12.1.1
priyakasimbeg Jun 6, 2025
a78fa66
update base image
priyakasimbeg Jun 6, 2025
f0019ac
small fix
priyakasimbeg Jun 7, 2025
3cb012e
remove aux_dropout from submission_runner.py
priyakasimbeg Jun 7, 2025
b306076
dropout fix criteo, fastmri, vit, conf
Niccolo-Ajroldi Jun 10, 2025
3e7a396
dropout fix deepspeech, ogbg
Niccolo-Ajroldi Jun 11, 2025
e80add4
remove attention_dropout_rate from wmt
Niccolo-Ajroldi Jun 11, 2025
84b1bd1
dropout fix on wmt
Niccolo-Ajroldi Jun 11, 2025
af08bb9
fix dropout, ALL tested
Niccolo-Ajroldi Jun 11, 2025
7a6651a
add dropout equivalence tests
Niccolo-Ajroldi Jun 11, 2025
a7ff3d1
moved custom dropout to pytorch_utils
Niccolo-Ajroldi Jun 11, 2025
f26ab02
remove aux_dropout from pytorch workloads
Niccolo-Ajroldi Jun 11, 2025
e0a0e62
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
1e2f379
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
f10e3dc
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
027b053
criteo rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
74c43aa
fastmri rm dropout from init
Niccolo-Ajroldi Jun 12, 2025
64276ef
vit rm dropout at init
Niccolo-Ajroldi Jun 12, 2025
44029d2
vit rm dropout at init
Niccolo-Ajroldi Jun 12, 2025
44ffec1
add default dropout test
Niccolo-Ajroldi Jun 12, 2025
9d12fa6
add default dropout test
Niccolo-Ajroldi Jun 12, 2025
ac45a9f
conformer: rm dropout_rate from init
Niccolo-Ajroldi Jun 12, 2025
31d64f6
rm dropout_rate at init from all workloads
Niccolo-Ajroldi Jun 12, 2025
5e192dd
remove dropout_rate from init_model_fn for all jax workloads
priyakasimbeg Jun 12, 2025
23828cd
remove dropout from model initialization call in submission_runner.py
priyakasimbeg Jun 12, 2025
86b8624
remove dropout check for None and use default instead if not passed
priyakasimbeg Jun 12, 2025
0128c9f
pipe dropout to model_fn, set default in workload
Niccolo-Ajroldi Jun 13, 2025
a7cba1a
remove aux_dropout from pytorch workloads
Niccolo-Ajroldi Jun 13, 2025
05bff91
fix to model_fn default dropout value
priyakasimbeg Jun 13, 2025
d8e39b0
fix to model_fn default dropout_rate
Niccolo-Ajroldi Jun 15, 2025
7a00158
rm models_dropout torch files
Niccolo-Ajroldi Jun 15, 2025
f7d99a6
fixes
priyakasimbeg Jun 17, 2025
4f9a4b3
Merge branch 'dev' into dropout_jax
priyakasimbeg Jun 17, 2025
3a41559
fix reference_algorithm_tests.py
priyakasimbeg Jun 18, 2025
6b6f2a6
Merge pull request #873 from Niccolo-Ajroldi/dropout_pytorch
priyakasimbeg Jun 18, 2025
7c43022
fixes to ogbg and fastmri
priyakasimbeg Jun 18, 2025
894f4fb
fixes to fastmri and deepspeech
priyakasimbeg Jun 18, 2025
0bcf484
fixes to conformer vit
priyakasimbeg Jun 18, 2025
73c2276
conformer and vit fix for dropout refactor
priyakasimbeg Jun 18, 2025
5ff94d2
wmt fixes
priyakasimbeg Jun 18, 2025
9090e43
fix linting
priyakasimbeg Jun 18, 2025
4e69255
formatting
priyakasimbeg Jun 18, 2025
3ac97ae
fix formatting
priyakasimbeg Jun 18, 2025
badf124
fix test
priyakasimbeg Jun 18, 2025
eff3ea1
fix lint errors
priyakasimbeg Jun 18, 2025
f7fd6c7
formatting
priyakasimbeg Jun 18, 2025
8fc4cc5
fix spacing issues
priyakasimbeg Jun 18, 2025
99c3111
formatting
priyakasimbeg Jun 18, 2025
c2f4ed0
formatting
priyakasimbeg Jun 18, 2025
ae8ca68
Merge pull request #864 from mlcommons/dropout_jax
priyakasimbeg Jun 18, 2025
b20f49d
formatting
priyakasimbeg Jun 19, 2025
0ea37ee
fix
priyakasimbeg Jun 19, 2025
594f285
pylint fixes
priyakasimbeg Jun 19, 2025
f14ff8f
isort fixes
priyakasimbeg Jun 19, 2025
2a8586a
pylint fixes
priyakasimbeg Jun 19, 2025
130 changes: 130 additions & 0 deletions algoperf/jax_utils.py
@@ -0,0 +1,130 @@
from collections.abc import Sequence

import flax.linen as nn
from flax.linen.module import compact
from flax.linen.module import merge_param
from flax.linen.module import Module
from flax.typing import PRNGKey
import jax
from jax import lax
from jax import random
import jax.numpy as jnp


# Custom Layers
class Dropout(Module):
# pylint: disable=line-too-long
"""Create a dropout layer.
Forked from
https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
The reference dropout implementation is modified to support changes
to the dropout rate during training by:
1) adding a rate argument to the __call__ method, and
2) removing the if-else conditions that check for edge cases, which
would otherwise trigger a recompile of jitted code.

.. note::
When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
variable initialization.

Example usage::

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class MLP(nn.Module):
... @nn.compact
... def __call__(self, x, train):
... x = nn.Dense(4)(x)
... x = nn.Dropout(0.5, deterministic=not train)(x)
... return x

>>> model = MLP()
>>> x = jnp.ones((1, 3))
>>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
>>> model.apply(variables, x, train=False) # don't use dropout
Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
>>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)

Attributes:
rate: the dropout probability. (_not_ the keep rate!)
broadcast_dims: dimensions that will share the same dropout mask
deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
and masked, whereas if true, no mask is applied and the inputs are
returned as is.
rng_collection: the rng collection name to use when requesting an rng
key.
"""

rate: float | None = None
broadcast_dims: Sequence[int] = ()
deterministic: bool | None = None
rng_collection: str = "dropout"
legacy: bool = True

@compact
def __call__(
self,
inputs,
deterministic: bool | None = None,
rate: float | None = None,
rng: PRNGKey | None = None,
):
"""Applies a random dropout mask to the input.

Args:
inputs: the inputs that should be randomly masked.
deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
and masked, whereas if true, no mask is applied and the inputs are
returned as is.
rate: the dropout probability. (_not_ the keep rate!)
rng: an optional PRNGKey used as the random key, if not specified,
one will be generated using ``make_rng`` with the
``rng_collection`` name.

Returns:
The masked inputs reweighted to preserve mean.
"""
deterministic = merge_param("deterministic",
self.deterministic,
deterministic)

# Use self.rate unless a rate is explicitly passed to __call__.
if rate is None:
rate = self.rate

if self.legacy:
if rate == 0.0:
return inputs

# Prevent gradient NaNs in 1.0 edge-case.
if rate == 1.0:
return jnp.zeros_like(inputs)

if deterministic:
return inputs

keep_prob = 1.0 - rate
if rng is None:
rng = self.make_rng(self.rng_collection)
broadcast_shape = list(inputs.shape)
for dim in self.broadcast_dims:
broadcast_shape[dim] = 1
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs, jnp.zeros_like(inputs))


# Utilities for debugging
def print_jax_model_summary(model, fake_inputs):
"""Prints a summary of the jax module."""
tabulate_fn = nn.tabulate(
model,
jax.random.PRNGKey(0),
console_kwargs={
"force_terminal": False, "force_jupyter": False, "width": 240
},
)
print(tabulate_fn(fake_inputs, train=False))
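
For illustration, a minimal usage sketch of the custom Dropout layer above (not part of this diff; the TinyMLP module and the rates shown are hypothetical). Because the rate is supplied at call time, tuning code can change it between steps without retracing jitted functions.

import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class TinyMLP(nn.Module):
  """Hypothetical module, only to demonstrate call-time dropout rates."""

  @nn.compact
  def __call__(self, x, train, dropout_rate=0.1):
    x = nn.Dense(4)(x)
    # The rate is passed in the call rather than fixed at construction.
    x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
    return x


model = TinyMLP()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x, train=False)
out = model.apply(
    variables,
    x,
    train=True,
    dropout_rate=0.2,  # a different rate, no recompile needed
    rngs={'dropout': jax.random.PRNGKey(1)})
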
41 changes: 41 additions & 0 deletions algoperf/pytorch_utils.py
@@ -5,7 +5,10 @@
import jax
import tensorflow as tf
import torch
from torch import nn
from torch import Tensor
import torch.distributed as dist
import torch.nn.functional as F

from algoperf import spec
from algoperf.profiler import Profiler
@@ -77,3 +80,41 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup


class CustomDropout(nn.Module):
"""A module around torch.nn.functional.dropout."""

def __init__(self):
super().__init__()
self._supports_custom_dropout = True

def forward(self, x: Tensor, p: float) -> Tensor:
return F.dropout(x, p, training=self.training)


class CustomDropout2d(nn.Module):
"""A module around torch.nn.functional.dropout2d."""

def __init__(self):
super().__init__()
self._supports_custom_dropout = True

def forward(self, x: Tensor, p: float) -> Tensor:
return F.dropout2d(x, p, training=self.training)


class SequentialWithDropout(nn.Sequential):
"""Sequential of modules with dropout."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._supports_custom_dropout = True

def forward(self, x: Tensor, p: float) -> Tensor:
for module in self:
if getattr(module, '_supports_custom_dropout', False):
x = module(x, p)
else:
x = module(x)
return x
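
For illustration, a minimal sketch of how these PyTorch helpers compose (not part of this diff; the layer sizes are arbitrary). SequentialWithDropout forwards the probability only to children that set _supports_custom_dropout, so ordinary modules are called unchanged.

import torch
from torch import nn

from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout

block = SequentialWithDropout(
    nn.Linear(16, 32),
    nn.ReLU(),
    CustomDropout(),  # flagged, so it receives the probability
    nn.Linear(32, 8),  # unflagged, called without it
)

x = torch.randn(4, 16)
y = block(x, 0.1)  # dropout probability supplied per forward call
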
6 changes: 4 additions & 2 deletions algoperf/spec.py
@@ -247,7 +247,8 @@ def init_model_fn(self,
# ModelAuxiliaryState,
# ForwardPassMode,
# RandomState,
# bool],
# bool,
# float],
# Tensor]
@abc.abstractmethod
def model_fn(self,
@@ -256,7 +257,8 @@ def model_fn(self,
model_state: ModelAuxiliaryState,
mode: ForwardPassMode,
rng: RandomState,
update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]:
update_batch_norm: bool,
dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]:
"""Return logits_batch"""
# Possible side effect of updating BN.

8 changes: 1 addition & 7 deletions algoperf/workloads/cifar/cifar_jax/workload.py
@@ -79,14 +79,8 @@ def sync_batch_stats(
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state

def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
model_cls = getattr(models, 'ResNet18')
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
self._model = model
25 changes: 15 additions & 10 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -1,11 +1,14 @@
"""A JAX implementation of DLRM-Small."""

from typing import Sequence

import flax.linen as nn
from jax import nn as jnn
import jax.numpy as jnp

from algoperf.jax_utils import Dropout

DROPOUT_RATE = 0.0


class DLRMResNet(nn.Module):
"""Define a DLRMResNet model.
@@ -23,12 +26,13 @@ class DLRMResNet(nn.Module):
mlp_bottom_dims: Sequence[int] = (256, 256, 256)
mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
embed_dim: int = 128
dropout_rate: float = 0.0
dropout_rate: float = DROPOUT_RATE
use_layer_norm: bool = False # Unused.
embedding_init_multiplier: float = None # Unused

@nn.compact
def __call__(self, x, train):
def __call__(self, x, train, dropout_rate=DROPOUT_RATE):

bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

@@ -88,8 +92,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
top_mlp_input)
x = nn.relu(x)
if self.dropout_rate and layer_idx == num_layers_top - 2:
x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
if dropout_rate and layer_idx == num_layers_top - 2:
x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
top_mlp_input += x
# In the DLRM model the last layer width is always 1. We can hardcode that
# below.
@@ -151,7 +155,8 @@ class DlrmSmall(nn.Module):
embedding_init_multiplier: float = None

@nn.compact
def __call__(self, x, train):
def __call__(self, x, train, dropout_rate=DROPOUT_RATE):

bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

@@ -210,10 +215,10 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input = nn.relu(top_mlp_input)
if self.use_layer_norm:
top_mlp_input = nn.LayerNorm()(top_mlp_input)
if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
if (dropout_rate is not None and dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
top_mlp_input = nn.Dropout(
rate=self.dropout_rate, deterministic=not train)(
top_mlp_input)
top_mlp_input = Dropout(
dropout_rate, deterministic=not train)(
top_mlp_input, rate=dropout_rate)
logits = top_mlp_input
return logits
19 changes: 10 additions & 9 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -72,36 +72,34 @@ def loss_fn(
def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None,
tabulate: Optional[bool] = False,
) -> spec.ModelInitState:
"""Only dropout is used."""
del aux_dropout_rate
if self.use_resnet:
model_class = models.DLRMResNet
else:
model_class = models.DlrmSmall

self._model = model_class(
vocab_size=self.vocab_size,
num_dense_features=self.num_dense_features,
mlp_bottom_dims=self.mlp_bottom_dims,
mlp_top_dims=self.mlp_top_dims,
embed_dim=self.embed_dim,
dropout_rate=dropout_rate,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)

params_rng, dropout_rng = jax.random.split(rng)
params_rng, _ = jax.random.split(rng)
init_fake_batch_size = 2
num_categorical_features = 26
num_dense_features = 13
input_size = num_dense_features + num_categorical_features
input_shape = (init_fake_batch_size, input_size)
init_fn = functools.partial(self._model.init, train=False)
initial_variables = jax.jit(init_fn)(
{'params': params_rng, 'dropout': dropout_rng},
jnp.ones(input_shape, jnp.float32))
initial_variables = jax.jit(init_fn)({
'params': params_rng,
},
jnp.ones(input_shape, jnp.float32))
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -117,14 +115,17 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
dropout_rate: float = models.DROPOUT_RATE
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch['inputs']
train = mode == spec.ForwardPassMode.TRAIN
apply_kwargs = {'train': train}
if train:
apply_kwargs['rngs'] = {'dropout': rng}
apply_kwargs['dropout_rate'] = dropout_rate
logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
return logits_batch, None
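
For illustration, a hedged sketch of the calling side (not from this diff; the helper and argument names are assumptions): with the new signature, a training step can thread a tuned dropout rate straight through model_fn.

from algoperf import spec


def train_forward(workload, params, batch, model_state, rng, dropout_rate):
  """Run a training-mode forward pass with an explicit dropout rate."""
  logits, new_model_state = workload.model_fn(
      params,
      batch,
      model_state,
      spec.ForwardPassMode.TRAIN,
      rng,
      update_batch_norm=True,
      dropout_rate=dropout_rate)
  return logits, new_model_state
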
