Commit ab216bb

cleanup code comments analytical Jacobian as vjp projection (pytorch#117483)

Cleanup code comments for `_compute_analytical_jacobian_rows` to make clear that the Jacobian is computed by projecting standard basis vectors through the vector-Jacobian-product operation.

Pull Request resolved: pytorch#117483
Approved by: https://github.com/soulitzer
redwrasse authored and pytorchmergebot committed Jan 19, 2024
1 parent 40dbd56 commit ab216bb
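
For context, the reworded comments describe the standard slow-mode gradcheck technique: for f: R^n -> R^m, the vector-Jacobian product v^T J evaluated at a standard basis vector e_j returns the jth row of the Jacobian. A minimal standalone sketch of that idea (not code from this commit; the toy function `f` is assumed purely for illustration):

import torch

def f(x):
    # Toy map from R^2 to R^3, used only for illustration.
    return torch.stack([x[0] * x[1], x[0] ** 2, x[1].sin()])

x = torch.randn(2, requires_grad=True)
y = f(x)

rows = []
for j in range(y.numel()):
    e_j = torch.zeros_like(y)
    e_j[j] = 1.0  # standard basis vector selecting row j
    # VJP: e_j^T J is the jth row of the Jacobian
    (row,) = torch.autograd.grad(y, x, grad_outputs=e_j, retain_graph=True)
    rows.append(row)

J = torch.stack(rows)  # full 3 x 2 Jacobian
assert torch.allclose(J, torch.autograd.functional.jacobian(f, x))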
Showing 1 changed file with 5 additions and 4 deletions.

torch/autograd/gradcheck.py
@@ -754,7 +754,7 @@ def _check_analytical_jacobian_attributes(
     inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
 ) -> Tuple[torch.Tensor, ...]:
     # This is used by both fast and slow mode:
-    # - For slow mode, vjps[i][j] is the jth row the Jacobian wrt the ith
+    # - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
     #   input.
     # - For fast mode, vjps[i][0] is a linear combination of the rows
     #   of the Jacobian wrt the ith input
@@ -873,19 +873,20 @@ def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
 def _compute_analytical_jacobian_rows(
     vjp_fn, sample_output
 ) -> List[List[Optional[torch.Tensor]]]:
-    # Computes Jacobian row-by-row using backward function `vjp_fn` = v^T J
+    # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis
+    # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian.
     # NB: this function does not assume vjp_fn(v) to return tensors with the same
     # number of elements for different v. This is checked when we later combine the
     # rows into a single tensor.
     grad_out_base = torch.zeros_like(
         sample_output, memory_format=torch.legacy_contiguous_format
     )
     flat_grad_out = grad_out_base.view(-1)
-    # jacobians_rows[i][j] represents the jth row of the ith input
+    # jacobians_rows[i][j] is the Jacobian jth row for the ith input
     jacobians_rows: List[List[Optional[torch.Tensor]]] = []
     for j in range(flat_grad_out.numel()):
         flat_grad_out.zero_()
-        flat_grad_out[j] = 1.0
+        flat_grad_out[j] = 1.0  # projection for jth row of Jacobian
         grad_inputs = vjp_fn(grad_out_base)
         for i, d_x in enumerate(grad_inputs):
             if j == 0:
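The slow/fast distinction in the first hunk's comment can also be seen directly: slow mode projects every standard basis vector through the VJP to recover each Jacobian row, while fast mode feeds a single random v and receives one linear combination v^T J of the rows. A standalone sketch under the same assumptions (toy function, torch.autograd.grad standing in for vjp_fn; not code from gradcheck.py):

import torch

def f(x):
    # Assumed toy map from R^3 to R^2.
    return torch.stack([x.prod(), x.sum() ** 2])

x = torch.randn(3, requires_grad=True)
y = f(x)

# Slow mode: each standard basis vector e recovers one row of J.
J = torch.stack(
    [
        torch.autograd.grad(y, x, grad_outputs=e, retain_graph=True)[0]
        for e in torch.eye(y.numel())
    ]
)

# Fast mode: one random v yields a single linear combination v^T J of the rows.
v = torch.randn(y.numel())
(vjp,) = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)
assert torch.allclose(vjp, v @ J)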
