Verify connectome regularization gradients #166
Description
To determine if the gradients of ols_weights are being used to update your model parameters effectively, especially in the context of your detailed loss function, there are several key points to check:
1. Gradient Flow Verification
Your loss function combines reconstruction loss with L1 regularization on all parameters and a connectome-specific regularization that directly involves ols_weights. The correct gradient computation and application hinge on ensuring that each part of your loss contributes correctly to the gradient calculations:
- **Reconstruction Loss**: Gradients from this part are standard and should affect the parameters involved in producing `output` from your model.
- **L1 Regularization**: This affects all parameters (`self.parameters()`), including `ols_weights` if it is properly registered as a parameter. The L1 loss should contribute gradients for all parameters uniformly.
- **Connectome Regularization**: Directly involves `ols_weights` and should be the primary concern for the gradients of `ols_weights`.
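As a rough sketch of how these three terms might combine (the module layout, matrix sizes, and lambda values are placeholders; `ols_weights`, `chem_weights`, and `elec_weights` follow the names that appear in the debug traceback, and the connectome penalty form mirrors the expression there):

```python
import torch
import torch.nn as nn

class ConnectomeModel(nn.Module):
    """Hypothetical skeleton, not the actual worm-graph model."""
    def __init__(self, n=4):
        super().__init__()
        self.linear = nn.Linear(n, n)
        # Trainable parameter whose gradients we want to verify.
        self.register_parameter("ols_weights", nn.Parameter(torch.randn(n, n)))
        # Fixed connectome matrices (not trained), stored as buffers.
        self.register_buffer("chem_weights", torch.rand(n, n))
        self.register_buffer("elec_weights", torch.rand(n, n))

    def forward(self, x):
        return self.linear(x @ self.ols_weights)

def loss_fn(model, output, target, l1_lambda=1e-4, conn_lambda=1e-2):
    recon = nn.functional.mse_loss(output, target)       # reconstruction loss
    l1 = sum(p.abs().sum() for p in model.parameters())  # L1 on all parameters
    conn = torch.norm(model.ols_weights**2
                      - (model.chem_weights + model.elec_weights))  # connectome term
    return recon + l1_lambda * l1 + conn_lambda * conn

model = ConnectomeModel()
x, y = torch.randn(8, 4), torch.randn(8, 4)
loss = loss_fn(model, model(x), y)
loss.backward()
print(model.ols_weights.grad is not None)  # True: all three terms flow to ols_weights
```

Because `ols_weights` appears in the forward pass, the L1 term, and the connectome term, its `.grad` accumulates contributions from all three.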
2. Checking Gradients for ols_weights
After running a backward pass (`loss.backward()`), check the `.grad` attribute of `ols_weights`:

```python
print("Gradient for ols_weights: ", model.ols_weights.grad)
```

If you see non-zero gradients for `ols_weights` after the backward pass, it confirms that gradients are indeed being calculated for this parameter, which are then used in the optimization step (`optimizer.step()`).
3. Ensure ols_weights Is a Trainable Parameter
Since `ols_weights` is registered as a parameter using `self.register_parameter()`, ensure this is done correctly:
- It should have `requires_grad` set to `True`.
- It should be listed in `model.parameters()`, which is used by your optimizer.
You can confirm this by:
```python
print("Is 'ols_weights' a parameter?: ", any(p is model.ols_weights for p in model.parameters()))
print("Does 'ols_weights' require gradients?: ", model.ols_weights.requires_grad)
```

4. Optimizer and Parameter Update Check
Ensure that the optimizer is set up to include `ols_weights`. Since you're using `model.parameters()` in the optimizer setup, it will naturally include `ols_weights` if it's registered correctly:

```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Example setup
```

5. Debugging Tips
If gradients are not behaving as expected:
- **Gradient Checking**: Use `torch.autograd.gradcheck` to manually verify the gradients for a small input set.
- **Isolate the Issue**: Temporarily modify your loss function to include only the connectome loss term and check whether `ols_weights` updates as expected.
- **Logging and Monitoring**: Log the values of `ols_weights` before and after an update step to confirm changes are being made.
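Note that `gradcheck` is strict about how it is called: it expects double-precision leaf tensors with `requires_grad=True`, passed as a tuple of inputs, and it should be run as a standalone check rather than inside the loss function. A minimal sketch checking the connectome penalty in isolation (the matrix size is a placeholder; the penalty form and the `chem`/`elec` names follow the traceback below):

```python
import torch
from torch.autograd import gradcheck

# gradcheck compares analytical and numerical Jacobians, so inputs
# must be double precision and require gradients.
ols = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
chem = torch.rand(3, 3, dtype=torch.double)
elec = torch.rand(3, 3, dtype=torch.double)

def connectome_penalty(w):
    # Frobenius norm of the mismatch between squared OLS weights
    # and the summed connectome matrices.
    return torch.norm(w**2 - (chem + elec))

# Returns True if gradients agree; raises GradcheckError on mismatch.
print(gradcheck(connectome_penalty, (ols,), eps=1e-6, atol=1e-4))
```

Running it on half-precision CUDA tensors, or passing a tensor expression instead of a tuple of inputs, can produce spurious Jacobian mismatches like the one in the traceback.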
6. Run a Simple Test
A practical test can be to initialize ols_weights with a specific value, run one training iteration with a known input, and check how ols_weights changes. This will help you understand if and how the updates align with your expectations based on the computed gradients.
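A minimal sketch of such a one-step test, using a stand-in quadratic loss rather than the real training loss:

```python
import torch

# Initialize a parameter to a known value, take a single optimizer
# step, and verify it moved opposite to the gradient.
w = torch.nn.Parameter(torch.ones(3))
before = w.detach().clone()
opt = torch.optim.Adam([w], lr=0.001)

loss = (w**2).sum()  # grad = 2*w = 2 at init, so w should decrease
loss.backward()
opt.step()

print(w.detach() - before)  # each entry ≈ -0.001 (Adam's first step ≈ -lr * sign(grad))
```

If the parameter does not move, the usual suspects are a missing `requires_grad`, the parameter not being in the optimizer's parameter list, or the loss not depending on it.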
By following these steps, you should be able to conclusively determine if ols_weights is correctly receiving and using gradients derived from your custom loss function during training.
Some real results from debugging code:
```
DEBUG train_model:
(after backward) Y_pred: (<ViewBackward0 object at 0x14572fe15b70>, False, True)
(after backward) model.ols_weights.grad: tensor([[-1.3481e-02, -3.2692e-03, -3.7134e-05, ..., -4.4823e-05,
3.5942e-05, -9.2983e-05],
[-3.6478e-05, 9.1195e-05, -3.6240e-05, ..., -3.7193e-05,
3.5942e-05, 3.9697e-05],
[-2.5868e-05, 3.6359e-05, -1.1606e-03, ..., 3.6120e-05,
3.5942e-05, -3.7730e-05],
...,
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]], device='cuda:0', dtype=torch.float16)
model.ols_weights: (None, True, True)
(before backward) Y_pred: (<ViewBackward0 object at 0x14572fe15b70>, False, True)
(before backward) model.ols_weights.grad: None
DEBUG train_model:
(after backward) Y_pred: (<ViewBackward0 object at 0x14572fe15cf0>, False, True)
(after backward) model.ols_weights.grad: tensor([[-9.1019e-03, -4.1466e-03, -4.2081e-05, ..., -1.1563e-04,
6.5207e-05, -2.0838e-04],
[-9.0301e-05, 2.9993e-04, -3.6240e-05, ..., 4.4346e-05,
-3.5942e-05, -3.6061e-05],
[-3.4618e-04, 4.2140e-05, -9.2363e-04, ..., -5.0187e-05,
5.0962e-05, -4.1246e-05],
...,
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]], device='cuda:0', dtype=torch.float16)
model.ols_weights: (None, True, True)
```
```
File "/net/vast-storage/scratch/vast/yanglab/qsimeon/worm-graph/train/_main.py", line 286, in <module>
model, metric = train_model(
^^^^^^^^^^^^
File "/net/vast-storage/scratch/vast/yanglab/qsimeon/worm-graph/train/_main.py", line 145, in train_model
train_baseline = criterion(output=Y_base, target=Y_train, mask=mask_train)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/net/vast-storage/scratch/vast/yanglab/qsimeon/worm-graph/models/_utils.py", line 1043, in loss
ols_grad = gradcheck(torch.norm, self.ols_weights**2 - (self.chem_weights + self.elec_weights))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1476, in gradcheck
return _gradcheck_helper(**args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1490, in _gradcheck_helper
_gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps,
File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1113, in _gradcheck_real_imag
gradcheck_fn(func, func_out, tupled_inputs, outputs, eps,
File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1170, in _slow_gradcheck
raise GradcheckError(_get_notallclose_msg(a, n, i, j, complex_indices, test_imag))
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0.9537],
[0.0000],
[0.0000],
...,
[0.0000],
[0.0000],
[0.0000]], device='cuda:0')
analytical:tensor([[ 0.0417],
[-0.0149],
[ 0.0000],
...,
[ 0.0000],
[-0.0178],
[ 0.0446]], device='cuda:0')
```