This repository was archived by the owner on Mar 31, 2025. It is now read-only.

Verify connectome regularization gradients #166

@qsimeon

Description


To determine whether the gradients of ols_weights are actually being used to update your model parameters, especially given your composite loss function, there are several key points to check:

1. Gradient Flow Verification

Your loss function combines a reconstruction loss with L1 regularization on all parameters and a connectome-specific regularization that directly involves ols_weights. Correct gradient computation hinges on each part of the loss contributing to the gradients as intended:

  • Reconstruction Loss: Gradients from this term are standard and flow to every parameter involved in producing the model's output.
  • L1 Regularization: This affects all parameters (self.parameters()), including ols_weights if it is properly registered as a parameter; it contributes a term to the gradient of every parameter.
  • Connectome Regularization: Directly involves ols_weights, so it is the primary source of gradients specific to ols_weights.
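
Put together, a hedged sketch of such a loss might look like the following. The names combined_loss, lambda_l1, and lambda_conn are illustrative, not your actual API; the connectome term mirrors the expression visible in the traceback later in this issue (ols_weights**2 minus chem_weights plus elec_weights):

```python
import torch
import torch.nn.functional as F

def combined_loss(model, y_pred, y_true, lambda_l1=1e-4, lambda_conn=1e-3):
    # Reconstruction term: standard MSE between prediction and target.
    recon = F.mse_loss(y_pred, y_true)
    # L1 term: sums |p| over every registered parameter, including ols_weights.
    l1 = sum(p.abs().sum() for p in model.parameters())
    # Connectome term: penalizes deviation of ols_weights**2 from the
    # combined chemical + electrical connectome weights.
    conn = torch.norm(
        model.ols_weights**2 - (model.chem_weights + model.elec_weights)
    )
    return recon + lambda_l1 * l1 + lambda_conn * conn
```

Because all three terms are summed into one scalar, a single loss.backward() accumulates gradients from each term into the relevant .grad attributes.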

2. Checking Gradients for ols_weights

After running a backward pass (loss.backward()), check the .grad attribute of ols_weights:

print("Gradient for ols_weights: ", model.ols_weights.grad)

A non-zero .grad after the backward pass confirms that gradients are being computed for this parameter and will be applied in the optimization step (optimizer.step()).
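To make this check fail loudly rather than relying on eyeballing the printout, a small helper (illustrative, not part of your codebase) can distinguish "no gradient at all" from "gradient is all zeros":

```python
import torch

def assert_grad_flows(param, name="ols_weights"):
    # Fails if backward() never reached this parameter at all...
    if param.grad is None:
        raise RuntimeError(f"{name} has no gradient; it is detached from the loss")
    # ...or if every gradient entry is exactly zero.
    if not torch.any(param.grad != 0):
        raise RuntimeError(f"{name} gradient is all zeros")
```

Call it as assert_grad_flows(model.ols_weights) right after loss.backward().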

3. Ensure ols_weights Is a Trainable Parameter

Since ols_weights is registered as a parameter using self.register_parameter(), ensure it's done correctly:

  • It should have requires_grad set to True.
  • It should be listed in model.parameters(), which is used by your optimizer.

You can confirm this by:

print("Is 'ols_weights' a parameter?: ", any(p is model.ols_weights for p in model.parameters()))
print("Does 'ols_weights' require gradients?: ", model.ols_weights.requires_grad)
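
For reference, a correct registration might look like the sketch below (ConnectomeModel and n_neurons are hypothetical names). A tensor wrapped in nn.Parameter has requires_grad=True by default and is picked up by model.parameters() automatically:

```python
import torch
import torch.nn as nn

class ConnectomeModel(nn.Module):
    def __init__(self, n_neurons):
        super().__init__()
        # register_parameter makes the tensor a trainable parameter:
        # it gets requires_grad=True and appears in model.parameters().
        self.register_parameter(
            "ols_weights", nn.Parameter(torch.zeros(n_neurons, n_neurons))
        )
```

Assigning `self.ols_weights = nn.Parameter(...)` directly is equivalent; what matters is that the tensor is an nn.Parameter attribute of the module, not a plain tensor or a buffer.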

4. Optimizer and Parameter Update Check

Ensure that the optimizer is set up to include ols_weights. Since you're using model.parameters() in the optimizer setup, it should naturally include ols_weights if it’s registered correctly:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Example setup
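
To verify this directly rather than assume it, you can scan the optimizer's parameter groups for the tensor (optimizer_covers is an illustrative helper, not a PyTorch API):

```python
def optimizer_covers(optimizer, param):
    # True if `param` appears in any of the optimizer's parameter groups,
    # i.e. optimizer.step() can actually update it.
    return any(
        p is param
        for group in optimizer.param_groups
        for p in group["params"]
    )
```

Usage: `print(optimizer_covers(optimizer, model.ols_weights))` should print True; if it prints False, the optimizer was constructed before ols_weights was registered, or from a filtered parameter list.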

5. Debugging Tips

If gradients are not behaving as expected:

  • Gradient Checking: Use torch.autograd.gradcheck to manually verify the gradients for a small input set.
  • Isolate the Issue: Temporarily modify your loss function to only include the connectome loss part and check if ols_weights updates as expected.
  • Logging and Monitoring: Log the values of ols_weights before and after an update step to confirm changes are being made.
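
A caution on the gradcheck tip: gradcheck compares analytical against numerical Jacobians, so it expects double-precision leaf inputs with requires_grad=True, and it is called on a *function* with the inputs passed separately, i.e. gradcheck(fn, inputs). Running it on a float16 CUDA tensor (as in the debug output below) will generally fail the numerical comparison regardless of whether your gradients are correct. A minimal sketch, where conn_penalty, the 6×6 shape, and target are illustrative stand-ins for your connectome term:

```python
import torch
from torch.autograd import gradcheck

# Double precision is required for the finite-difference Jacobian
# to be accurate enough to match the analytical one.
x = torch.randn(6, 6, dtype=torch.float64, requires_grad=True)
target = torch.randn(6, 6, dtype=torch.float64)

def conn_penalty(w):
    # Same shape of expression as the connectome regularizer.
    return torch.norm(w**2 - target)

# Returns True if analytical and numerical Jacobians agree.
assert gradcheck(conn_penalty, (x,), eps=1e-6, atol=1e-4)
```

In other words, run gradcheck on a small, detached, float64 copy of the computation as an offline test, not inside the training-time loss on half-precision CUDA tensors.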

6. Run a Simple Test

A practical test is to initialize ols_weights to a known value, run one training iteration on a fixed input, and check how ols_weights changes. This tells you whether the updates align with what you expect from the computed gradients.
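
That test can be sketched as a helper like this (one_step_update_check is a hypothetical name; it assumes the model's forward pass involves ols_weights):

```python
import torch

def one_step_update_check(model, optimizer, loss_fn, x, y):
    # Snapshot the parameter, take exactly one optimization step,
    # and report whether ols_weights actually moved.
    before = model.ols_weights.detach().clone()
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()
    return not torch.equal(before, model.ols_weights.detach())
```

If this returns False, either the gradient is zero/None (see the checks above) or the optimizer does not cover the parameter.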

By following these steps, you should be able to conclusively determine if ols_weights is correctly receiving and using gradients derived from your custom loss function during training.


Some real results from debugging code:

DEBUG train_model:

(after backward) Y_pred: (<ViewBackward0 object at 0x14572fe15b70>, False, True)

(after backward) model.ols_weights.grad: tensor([[-1.3481e-02, -3.2692e-03, -3.7134e-05,  ..., -4.4823e-05,
          3.5942e-05, -9.2983e-05],
        [-3.6478e-05,  9.1195e-05, -3.6240e-05,  ..., -3.7193e-05,
          3.5942e-05,  3.9697e-05],
        [-2.5868e-05,  3.6359e-05, -1.1606e-03,  ...,  3.6120e-05,
          3.5942e-05, -3.7730e-05],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0', dtype=torch.float16)

model.ols_weights: (None, True, True)

(before backward) Y_pred: (<ViewBackward0 object at 0x14572fe15b70>, False, True)

(before backward) model.ols_weights.grad: None

DEBUG train_model:

(after backward) Y_pred: (<ViewBackward0 object at 0x14572fe15cf0>, False, True)

(after backward) model.ols_weights.grad: tensor([[-9.1019e-03, -4.1466e-03, -4.2081e-05,  ..., -1.1563e-04,
          6.5207e-05, -2.0838e-04],
        [-9.0301e-05,  2.9993e-04, -3.6240e-05,  ...,  4.4346e-05,
         -3.5942e-05, -3.6061e-05],
        [-3.4618e-04,  4.2140e-05, -9.2363e-04,  ..., -5.0187e-05,
          5.0962e-05, -4.1246e-05],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0', dtype=torch.float16)

model.ols_weights: (None, True, True)
  File "/net/vast-storage/scratch/vast/yanglab/qsimeon/worm-graph/train/_main.py", line 286, in <module>
    model, metric = train_model(
                    ^^^^^^^^^^^^
  File "/net/vast-storage/scratch/vast/yanglab/qsimeon/worm-graph/train/_main.py", line 145, in train_model
    train_baseline = criterion(output=Y_base, target=Y_train, mask=mask_train)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/vast-storage/scratch/vast/yanglab/qsimeon/worm-graph/models/_utils.py", line 1043, in loss
    ols_grad = gradcheck(torch.norm, self.ols_weights**2 - (self.chem_weights + self.elec_weights))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1476, in gradcheck
    return _gradcheck_helper(**args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1490, in _gradcheck_helper
    _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps,
  File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1113, in _gradcheck_real_imag
    gradcheck_fn(func, func_out, tupled_inputs, outputs, eps,
  File "/home/qsimeon/om2/miniconda3/envs/worm-graph/lib/python3.11/site-packages/torch/autograd/gradcheck.py", line 1170, in _slow_gradcheck
    raise GradcheckError(_get_notallclose_msg(a, n, i, j, complex_indices, test_imag))
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0.9537],
        [0.0000],
        [0.0000],
        ...,
        [0.0000],
        [0.0000],
        [0.0000]], device='cuda:0')
analytical:tensor([[ 0.0417],
        [-0.0149],
        [ 0.0000],
        ...,
        [ 0.0000],
        [-0.0178],
        [ 0.0446]], device='cuda:0')
