Skip to content

Commit

Permalink
ns for fx: remove restriction on nodes with no args and only kwargs
Browse files Browse the repository at this point in the history
Summary:

Removes the restriction from NS for FX on handling nodes which have
no positional arguments, such as `F.linear(input=x, weight=w, bias=b).

In order to achieve this, we delete all places in the code which
were doing things like

```
node.args[0]
```

And replace them with

```
_get_normalized_nth_input(node, gm, 0)
```

The `_get_normalized_nth_input` function is a best effort way to
get the n'th normalized input.

This is needed because some FX tools output nodes normalized to
be kwargs only, and we need to be able to handle this in NS.

Test plan:

```
python test/test_quantization.py -k test_linear_kwargs_shadow
```

Pull Request resolved: pytorch#78181

Approved by: https://github.com/z-a-f, https://github.com/hx89
  • Loading branch information
vkuzo authored and pytorchmergebot committed May 25, 2022
1 parent 7050826 commit 53e05ad
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 69 deletions.
28 changes: 28 additions & 0 deletions test/quantization/fx/test_numeric_suite_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,6 +2034,34 @@ def forward(self, x):
m, (torch.randn(1, 1, 4, 4),),
results_len=0)

def test_linear_kwargs_shadow(self):

class M(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

def forward(self, x):
x = F.linear(input=x, weight=self.w1, bias=self.b1)
return x

# note: FX graph mode quantization does not have good support
# for kwargs-only right now, so we pass in two unquantized
# models
m = M().eval()
mt = torch.fx.symbolic_trace(m)
mt_copy = copy.deepcopy(mt)

mt_shadows_mt_copy = add_shadow_loggers(
'a', mt, 'b', mt_copy, OutputLogger)

mt_shadows_mt_copy(torch.randn(4, 4))
act_compare_dict = extract_shadow_logger_info(
mt_shadows_mt_copy, OutputLogger, 'b')
self.assertTrue(len(act_compare_dict) == 1)


class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
"""
Expand Down
Loading

0 comments on commit 53e05ad

Please sign in to comment.