Description
🐛 Bug
Hi, thanks so much for the quick responses to previous issues! I've recently run into a small bug with Expanded Weights: my model parameters can't handle some built-in Python functions, namely `__repr__` and `__len__`. My Opacus is updated to the most recent version, including the new Expanded Weights Goodness update.
In this code example, I'm replicating the error with the `LinearWithExtraParam()` class that I found in the test directory, and I'm attempting to print out or find the length of `self.extra_param`, an `nn.Parameter()` instance defined in the `LinearWithExtraParam()` class:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from opacus import PrivacyEngine


class LinearWithExtraParam(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.extra_param = nn.Parameter(torch.randn(out_features, 2))

    def forward(self, x):
        x = self.fc(x)
        x = x.matmul(self.extra_param)
        return x


def _init(module, size, batch_size=10):
    optim = torch.optim.SGD(module.parameters(), lr=0.1)
    dl = DataLoader(
        dataset=[torch.randn(*size) for _ in range(100)],
        batch_size=batch_size,
    )
    return module, optim, dl


privacy_engine = PrivacyEngine()
module, optim, dl = _init(LinearWithExtraParam(5, 8), size=(16, 5))
module, optim, dl = privacy_engine.make_private(
    module=module,
    optimizer=optim,
    data_loader=dl,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ew",  # Expanded Weights mode
)

for x in dl:
    module(x)
```
If I add `len(self.extra_param)` to the `LinearWithExtraParam.forward()` method before `x.matmul(self.extra_param)`, I receive the following error:
RuntimeError: Expanded Weights encountered but cannot handle function __len__
If I add `print(self.extra_param)` to the `LinearWithExtraParam.forward()` method before `x.matmul(self.extra_param)`, I receive the following error:
RuntimeError: Expanded Weights encountered but cannot handle function __repr__
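As far as I can tell, the message comes from a dispatch hook that rejects any function it has no registered handler for, and `len()` and `print()` reach that hook through `__len__` and `__repr__` respectively. The sketch below is only a pure-Python analogy of that pattern, not PyTorch or Opacus code: `FakeExpandedWeight`, `HANDLED` and `error_from` are hypothetical names, and a plain list stands in for the weight tensor.

```python
# Pure-Python analogy of a tensor wrapper whose dunder methods all route
# through a single dispatch hook (PyTorch tensors do this via __torch_function__).
# FakeExpandedWeight, HANDLED and error_from are hypothetical names.

HANDLED = {}  # functions this hypothetical wrapper knows how to handle


class FakeExpandedWeight:
    def __init__(self, values):
        self.values = list(values)

    @classmethod
    def _dispatch(cls, func, *args):
        # Registered functions run; anything else is rejected by name,
        # producing the same style of message as in this report.
        if func in HANDLED:
            return HANDLED[func](*args)
        raise RuntimeError(
            "Expanded Weights encountered but cannot handle function "
            f"{func.__name__}"
        )

    def __len__(self):
        return self._dispatch(FakeExpandedWeight.__len__, self)

    def __repr__(self):
        return self._dispatch(FakeExpandedWeight.__repr__, self)


def error_from(call):
    """Run call() and return the RuntimeError message, or None if it succeeds."""
    try:
        call()
    except RuntimeError as exc:
        return str(exc)
    return None


w = FakeExpandedWeight([1.0, 2.0, 3.0])
print(error_from(lambda: len(w)))    # ...cannot handle function __len__
print(error_from(lambda: print(w)))  # ...cannot handle function __repr__

# Registering a handler is what makes a function "supported".
HANDLED[FakeExpandedWeight.__len__] = lambda self: len(self.values)
print(len(w))  # 3
```

Note that `print(w)` fails before anything is written, because `print` first converts `w` with `str()`, which falls back to `__repr__`.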
To Reproduce
Steps to reproduce the behavior:
- Create a module with an `nn.Parameter`.
- Set `grad_sample_mode="ew"`.
- Insert either `len(self.extra_param)` or `print(self.extra_param)` in a method, such as the `forward` method.
- Run the method that calls either `len` or `print`.
- Receive an error saying that EW cannot handle the `__len__` or `__repr__` function.
Expected behavior
For a tensor `x`, I expect `len(nn.Parameter(x))` to return the length of `x`.
For a tensor `x`, I expect `print(nn.Parameter(x))` to print `x` (i.e. use the `__repr__` function of `x`).
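A minimal pure-Python sketch of the expected behavior, assuming a wrapper that simply forwards `len()` and `repr()` to the value it wraps. `DelegatingWeight` and `orig_weight` are hypothetical names, and a nested list stands in for a tensor of shape (3, 2); for a real tensor, `len()` returns the size of the first dimension.

```python
# Hypothetical sketch: a wrapper that forwards len() and repr() to the
# value it wraps, which is the behavior this report expects.
# A nested Python list stands in for a (3, 2) tensor; no PyTorch is involved.


class DelegatingWeight:
    def __init__(self, orig_weight):
        self.orig_weight = orig_weight  # the un-expanded value being wrapped

    def __len__(self):
        # Length of the first dimension, like len() on a tensor.
        return len(self.orig_weight)

    def __repr__(self):
        # Reuse the wrapped value's own representation.
        return repr(self.orig_weight)


x = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
w = DelegatingWeight(x)
print(len(w))  # 3
print(w)       # [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
```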
Environment
- PyTorch Version (e.g., 1.0): 1.13
- OS (e.g., Linux): Linux
- How you installed PyTorch (`conda`, `pip`, source): conda
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information: