Fix: NIRGraph error by disabling type checks #36

Conversation
Nice addition, thank you for starting the PR. Your code eliminates type checking, and I was wondering whether type checking should be the default thing to do. But in this case, we're only reading the graph and don't really care whether it type checks. That said, there may be useful information in the fact that it doesn't type check. What about doing something like this to inform the user that there may be a problem if the conversion fails?

```python
import warnings

recon_graph = nir.NIRGraph(nodes, sanitized_edges, type_check=False)
is_well_typed = False
error = None
try:
    is_well_typed = recon_graph._check_types()
except Exception as e:
    error = e
finally:
    if not is_well_typed:
        warnings.warn(
            f"The graph failed to type check ({error}). "
            "If you run into problems, consider validating the graph"
        )
```

That would give the user a chance to react to the fact that something in the graph flow may be wrong. Of course, we can wrap that in an if statement if people want to fully ignore the error. For instance, by adding an argument to the …
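(A minimal sketch of that opt-out, wrapping the check behind a hypothetical `validate_types` flag; the flag and function names are invented here for illustration and are not part of the PR:)

```python
import warnings

import nir

def reconstruct_graph(nodes, sanitized_edges, validate_types=True):
    # `validate_types` is a hypothetical opt-out flag, not an existing argument.
    recon_graph = nir.NIRGraph(nodes, sanitized_edges, type_check=False)
    if validate_types:
        is_well_typed = False
        error = None
        try:
            is_well_typed = recon_graph._check_types()
        except Exception as e:
            error = e
        if not is_well_typed:
            warnings.warn(
                f"The graph failed to type check ({error}). "
                "If you run into problems, consider validating the graph"
            )
    return recon_graph
```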
I have two thoughts on this.

First, I believe this might not be the right place to raise an exception, since we expect the type check to fail—after all, the nodes are torch.nn.Modules, not NIRNodes. So in that sense, the failure is somewhat expected.

Second, if we look at this from a framework design perspective, it's critical that the NIRGraph is constructed correctly. Users of the framework shouldn't have to second-guess whether the graph is consistent with its intended design. This was something I struggled with when I first started working with Norse and NIR—it wasn't always clear whether things had been set up properly, and the lack of guarantees made debugging harder.

So my opinion is that for this case, where the NIRGraph holds torch.nn.Modules, it suffices to just deactivate the type check. That said, I'm not entirely sure I understand the broader use cases you had in mind for this. Could you clarify in which other scenarios you're planning to use this logic? That might help me better understand whether relaxing the check is appropriate more generally.
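(To make the scenario concrete, a hedged sketch of the kind of graph at issue, where NIR Input/Output nodes surround plain torch.nn.Modules; it assumes the behavior this PR introduces and is not code from the PR itself:)

```python
import numpy as np
import torch
import nir

# The intermediate graph holds torch.nn.Modules, not NIRNodes, so NIR's
# type check is expected to fail on it and is therefore disabled here.
nodes = {
    "input": nir.Input(input_type=np.array([2])),
    "linear": torch.nn.Linear(2, 3),  # a torch module, not an NIR node
    "output": nir.Output(output_type=np.array([3])),
}
edges = [("input", "linear"), ("linear", "output")]
graph = nir.NIRGraph(nodes, edges, type_check=False)
```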
Agreed. But the code doesn't raise an exception; it just reports the error as a warning.
I fully agree that it's a monumental headache to have incorrect/inconsistent graphs. The problem is that some of those problems can be hidden, and my core motivation is basically to ensure that any such problems propagate to the surface.

My main scenario involves different versions of NIR graphs. The early specification didn't include Input/Output nodes, which meant we had no way to (type) check the graphs. That, in turn, made it almost impossible to validate a graph without manual inspection. My thinking is that (1) future changes will cause things to break, and (2) that breakage should be very visible to the user, particularly when working with NIRTorch, which is a kind of high-level framework where graph errors could easily disappear.

That said, if you still don't think we should inform the user, we can try to implement this as-is and see how it goes. Maybe it won't be a problem at all because the tracing catches some of the issues?
Ah yes, my fault. Your comment really made me think—thank you for that! It made me consider aspects I hadn't before. That said, I believe this goes a bit beyond the scope of this pull request.

In this specific case, the issue is that even if we assume NIRTorch receives a correctly defined NIRGraph, it still breaks when the graph contains torch.nn.Module instances as nodes. This PR addresses that problem in a straightforward way.

Regarding your point about the user potentially not being informed: I thought that was exactly the reason the type check was introduced in the first place? I'm not entirely sure I understand what you mean—sorry if I'm missing your point. Maybe we can go ahead and merge this, and open a follow-up issue in the NIR repository to discuss your broader concerns?
Agreed! This is much better than the status quo. I agree that we should merge it, but can I ask you to take a look at the tests? They seem to be failing. (Note that I bumped the NIR dependency to the most recent version, which includes the type-checking argument in the NIRGraph constructor.)
Yes, the reason the type check was introduced was to prevent the user from doing something weird with a graph that isn't well-formed; in particular, where there isn't enough information to make the parsing meaningful (e.g., in the case of a missing input node, where the "beginning" becomes ambiguous). I think that information is valuable in itself, and it's strictly independent of much of the parsing in NIRTorch, which is PyTorch-related. My only point was that there may be reasons to include that type check by default. But I also agree it's a separate issue, which I've raised in #37.

Finally, I wanted to let you know that I really appreciate your efforts here, both in discussing this with us (despite the sometimes long history) and in seeing the code changes through. I've invited you as a contributor on both the NIR and NIRTorch repos, so you can make the changes in local branches instead, if you prefer. That also makes it easier for us to contribute/assist.
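(A small illustration of that ambiguity, sketched here for context and not taken from the PR:)

```python
import numpy as np
import nir

# Without an nir.Input node, either node could be the "beginning":
# a parser has no declared entry point to start from.
nodes = {
    "a": nir.Affine(weight=np.eye(2), bias=np.zeros(2)),
    "b": nir.Affine(weight=np.eye(2), bias=np.zeros(2)),
}
edges = [("a", "b"), ("b", "a")]  # a cycle with no declared input or output
```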
Yes, let me take a look at the tests. :) Awesome, thanks for adding me as a collaborator! 🙌
Okay, so the tests are working now. As expected, the type checking of the NIRGraph was the reason for the failing tests. |
I think that's a great compromise for now.
Yes, I'd probably leave it around for a while. It's deprecated, but we're issuing loud warnings.
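(For context, a "loud" deprecation warning of the kind referred to might look like the sketch below; the function name is hypothetical and this is not NIRTorch's actual code:)

```python
import warnings

def legacy_entry_point(*args, **kwargs):
    # Hypothetical name; illustrates a loud, explicit deprecation warning.
    warnings.warn(
        "This entry point is deprecated and will be removed in a future "
        "release; please migrate to the new API.",
        DeprecationWarning,
        stacklevel=2,
    )
```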
Hmm, if I recall correctly, these are used as placeholder nodes just before the PyTorch graph is traced. If that's the case, I don't think the shape matters, because PyTorch doesn't care about NIR types and the NIR nodes themselves won't be type-checked in the torch.fx construction.
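(A quick, self-contained torch.fx sketch of that claim; this is illustrative and not NIRTorch code:)

```python
import torch
from torch import fx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 3)

    def forward(self, x):
        return self.lin(x)

# torch.fx records only the call structure; any NIR shape metadata attached
# to placeholder nodes never enters the traced graph.
traced = fx.symbolic_trace(Model())
print(traced.graph)  # placeholder -> call_module lin -> output
```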
I think @benkroehs is talking about this code, after the torch model has been traced: NIRTorch/nirtorch/torch_tracer.py, line 220 (commit c5b1015).
The shapes of the Input/Output nodes matter because they are validated in _check_types(), which is called by the NIRGraph constructor in the code above.
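(As an illustration of the failure mode; a hedged sketch, assuming nir.Affine's input/output types are derived from its weight shape:)

```python
import numpy as np
import nir

nodes = {
    "input": nir.Input(input_type=np.array([4])),  # deliberately wrong: follower expects [2]
    "affine": nir.Affine(weight=np.zeros((3, 2)), bias=np.zeros(3)),
    "output": nir.Output(output_type=np.array([3])),
}
edges = [("input", "affine"), ("affine", "output")]
# With type checking enabled (the default), _check_types() is expected to
# reject the [4] vs. [2] mismatch during construction.
graph = nir.NIRGraph(nodes, edges)
```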
We could disable the type checking like we do in other parts of the pull request, but at least for me, inferring the types of the Input and Output nodes from their follower/predecessor nodes seems to work. Having a type check on the imported graph is certainly a good idea. Something like this, placed directly before creating the graph, seems to work:

```python
import numpy as np
import nir

# Infer input types from follower nodes
for node_name, node in nodes.items():
    if isinstance(node, nir.Input):
        follower_edges = [edge for edge in edges if edge[0] == node_name]
        follower_nodes = [nodes[edge[1]] for edge in follower_edges]
        # Check that all follower nodes have the same input_type
        first_input_type = None
        for follower_node in follower_nodes:
            if first_input_type is None:
                first_input_type = follower_node.input_type
            else:
                # Verify they match (code adapted from NIRGraph._check_types)
                if len(first_input_type) != len(follower_node.input_type):
                    raise ValueError(f"Input type length mismatch for followers of {node_name}")
                if len(first_input_type.keys()) == 1:
                    first_type = list(first_input_type.values())[0]
                    follower_type = list(follower_node.input_type.values())[0]
                    if not np.array_equal(first_type, follower_type):
                        raise ValueError(
                            f"Input type mismatch for followers of {node_name}: "
                            f"{first_type} vs {follower_type}"
                        )
                else:
                    raise NotImplementedError("Multiple input/output types not supported yet")
        # Update the input node's input_type if we found a valid type
        if first_input_type is not None:
            nodes[node_name] = nir.Input(first_input_type)

# Similar logic for Output nodes: infer from predecessor nodes
for node_name, node in nodes.items():
    if isinstance(node, nir.Output):
        # Find all edges where this output node is the target
        predecessor_edges = [edge for edge in edges if edge[1] == node_name]
        predecessor_nodes = [nodes[edge[0]] for edge in predecessor_edges]
        # Check that all predecessor nodes have the same output_type
        first_output_type = None
        for predecessor_node in predecessor_nodes:
            if first_output_type is None:
                first_output_type = predecessor_node.output_type
            else:
                # Verify they match (code adapted from NIRGraph._check_types)
                if len(first_output_type) != len(predecessor_node.output_type):
                    raise ValueError(f"Output type length mismatch for predecessors of {node_name}")
                if len(first_output_type.keys()) == 1:
                    first_type = list(first_output_type.values())[0]
                    predecessor_type = list(predecessor_node.output_type.values())[0]
                    if not np.array_equal(first_type, predecessor_type):
                        raise ValueError(
                            f"Output type mismatch for predecessors of {node_name}: "
                            f"{first_type} vs {predecessor_type}"
                        )
                else:
                    raise NotImplementedError("Multiple input/output types not supported yet")
        # Update the output node's output_type if we found a valid type
        if first_output_type is not None:
            nodes[node_name] = nir.Output(first_output_type)
```

I am not sure whether we actually support multiple successor nodes from one input (or multiple outputs), but that case is handled anyway.
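(To make the inference concrete, a small worked example under the same assumptions as above, i.e. that nir.Affine's types are derived from its weight shape:)

```python
import numpy as np
import nir

nodes = {
    "input": nir.Input(input_type=np.array([9])),    # stale shape from an old graph
    "affine": nir.Affine(weight=np.zeros((3, 2)), bias=np.zeros(3)),
    "output": nir.Output(output_type=np.array([9])),  # stale as well
}
edges = [("input", "affine"), ("affine", "output")]
# Running the two loops above replaces the stale nodes:
#   nodes["input"]  -> nir.Input(nir.Affine's input_type,   i.e. shape [2])
#   nodes["output"] -> nir.Output(nir.Affine's output_type, i.e. shape [3])
# after which nir.NIRGraph(nodes, edges) should pass the type check.
```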
Follow-up from the above after some more investigation: I believe the correct approach would be: …

Edit: Bad idea. The whole point of the forward_type_inference is that the input type of the graph is correct, and we use that to update the rest of the graph. My current approach: …
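(For readers unfamiliar with the idea, here is a schematic sketch of forward type inference in general; it is neither NIR's implementation nor the approach referred to above, and the `infer_output` callback is hypothetical:)

```python
from collections import defaultdict, deque

def forward_propagate_types(edges, input_node, input_type, infer_output):
    """Push a trusted graph input type forward along the edges (schematic).

    Returns a mapping from node name to its inferred incoming type.
    `infer_output(name, incoming_type)` is a hypothetical callback that maps
    a node's incoming type to its outgoing type.
    """
    successors = defaultdict(list)
    for src, dst in edges:
        successors[src].append(dst)

    types = {input_node: input_type}
    queue = deque([input_node])
    while queue:
        name = queue.popleft()
        out_type = infer_output(name, types[name])
        for nxt in successors[name]:
            if nxt not in types:  # visit each node once (assumes a DAG)
                types[nxt] = out_type
                queue.append(nxt)
    return types
```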
Thank you @fabio-innatera for addressing the follow-up issues concerning the …
This fixes issue #35, originally addressed in a NIR issue.