Conversation

@benkroehs
Collaborator

benkroehs commented May 21, 2025

This fixes issue #35, originally addressed in a NIR issue.

@Jegp
Collaborator

Jegp commented May 26, 2025

Nice addition, thank you for starting the PR.

Your code eliminates type checking, and I was wondering whether type checking should be the default thing to do. But in this case, we're only reading the graph and don't really care if it type checks. That said, there may be useful information in the fact that it doesn't type check.

What about doing something like this to help inform the user that there may be a problem, if the conversion fails?

import warnings

recon_graph = nir.NIRGraph(nodes, sanitized_edges, type_check=False)
is_well_typed = False
error = None
try:
    is_well_typed = recon_graph._check_types()
except Exception as e:
    error = e
finally:
    if not is_well_typed:
        # warnings.warn takes a message (and optionally a Warning category),
        # so the caught exception goes into the message rather than as a second argument
        warnings.warn(
            "The graph failed to type check"
            + (f" ({error})" if error else "")
            + ". If you run into problems, consider validating the graph."
        )
...

That would give the user a chance to react to the fact that something in the graph flow may be wrong.

Of course, we can wrap that in an if statement if people want to fully ignore the error, for instance by adding an argument to the load function (e.g. warn_on_type_error: bool = True) that propagates to _map_graph_to_torch.
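
A rough sketch of what that flag could gate (the helper name and signature below are just placeholders for illustration, not existing nirtorch API):

    import warnings

    def _warn_if_ill_typed(recon_graph, warn_on_type_error: bool = True):
        # Hypothetical helper wrapping the snippet above: warn, but never raise,
        # when the reconstructed graph fails its type check.
        if not warn_on_type_error:
            return
        try:
            well_typed = recon_graph._check_types()
            error = None
        except Exception as e:
            well_typed, error = False, e
        if not well_typed:
            warnings.warn(
                "The graph failed to type check"
                + (f" ({error})" if error else "")
                + ". If you run into problems, consider validating the graph."
            )

The load function would then accept warn_on_type_error and pass it on to _map_graph_to_torch, which could call such a helper right after constructing the graph with type_check=False.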

@benkroehs
Collaborator Author

I have two thoughts on this:

First, I believe this might not be the right place to raise an exception, since we expect the type check to fail—after all, the nodes are torch.nn.Modules, not NIRNodes. So in that sense, the failure is somewhat expected.

Second, if we look at this from a framework design perspective, it's critical that the NIRGraph is constructed correctly. Users of the framework shouldn't have to second-guess whether the graph is consistent with its intended design. This was something I struggled with when I first started working with Norse and NIR—it wasn't always clear whether things had been set up properly, and the lack of guarantees made debugging harder.

So my opinion is that for this case, where the NIRGraph holds torch.nn.modules, it suffices to just deactivate the type check. That said, I'm not entirely sure I understand the broader use cases you had in mind for this. Could you clarify in which other scenarios you’re planning to use this logic? That might help me better understand whether relaxing the check is appropriate more generally.

@Jegp
Collaborator

Jegp commented May 26, 2025

I have two thoughts on this:

First, I believe this might not be the right place to raise an exception, since we expect the type check to fail—after all, the nodes are torch.nn.Modules, not NIRNodes. So in that sense, the failure is somewhat expected.

Agreed. But the code doesn't raise an exception; it just reports an error.

Second, if we look at this from a framework design perspective, it's critical that the NIRGraph is constructed correctly. Users of the framework shouldn't have to second-guess whether the graph is consistent with its intended design. This was something I struggled with when I first started working with Norse and NIR—it wasn't always clear whether things had been set up properly, and the lack of guarantees made debugging harder.

So my opinion is that for this case, where the NIRGraph holds torch.nn.modules, it suffices to just deactivate the type check. That said, I'm not entirely sure I understand the broader use cases you had in mind for this. Could you clarify in which other scenarios you’re planning to use this logic? That might help me better understand whether relaxing the check is appropriate more generally.

I fully agree that it's a monumental headache to have incorrect/inconsistent graphs. The problem is that some of those problems can be hidden, and my core motivation is basically to ensure that any such problems propagate to the surface. My main scenario involves different versions of NIR graphs. The early specification didn't include Input/Output nodes, which meant we had no way to (type) check the graphs, which in turn made it almost impossible to validate a graph without manual inspection. My thinking is that (1) future changes will cause things to break, and (2) that breakage should be very visible to the user, particularly when working with NIRTorch, which is a fairly high-level framework where graph errors could easily disappear.

That said, if you still don't think we should inform the user, we can try to implement this as-is and see how it goes. Maybe it won't be a problem at all because the tracing catches some of the issues?

@benkroehs
Collaborator Author

Agreed. But the code doesn't raise an exception; it just reports an error.

Ah yes, my fault.

Your comment really made me think—thank you for that! It made me consider aspects I hadn't before. That said, I believe this goes a bit beyond the scope of this pull request.

In this specific case, the issue is that even if we assume NIRTorch receives a correctly defined NIRGraph, it still breaks when the graph contains torch.nn.Module instances as nodes. This PR addresses that problem in a straightforward way.

Regarding your point about the user potentially not being informed: I thought that was exactly the reason the type check was introduced in the first place? I’m not entirely sure I understand what you mean—sorry if I’m missing your point.

Maybe we can go ahead and merge this, and open a follow-up issue in the NIR repository to discuss your broader concerns?

@Jegp
Collaborator

Jegp commented May 27, 2025

In this specific case, the issue is that even if we assume NIRTorch receives a correctly defined NIRGraph, it still breaks when the graph contains torch.nn.Module instances as nodes. This PR addresses that problem in a straightforward way.
Maybe we can go ahead and merge this, and open a follow-up issue in the NIR repository to discuss your broader concerns?

Agreed! This is much better than the status quo. I agree that we should merge it, but can I ask you to take a look at the tests? They seem to be failing. (Note that I bumped the NIR dependency to the most recent version so that the type checking argument is available in the NIRGraph constructor.)

Regarding your point about the user potentially not being informed: I thought that was exactly the reason the type check was introduced in the first place? I’m not entirely sure I understand what you mean—sorry if I’m missing your point.

Yes, the reason the type check was introduced was to prevent the user from doing something weird with a graph that wasn't well-formed. In particular, where there isn't enough information to make the parsing meaningful (e.g. in the case of a missing input node, where the "beginning" becomes ambiguous). I think that information is valuable in itself, and it's strictly independent of much of the parsing in NIRTorch, which is PyTorch-related. My only point was that there may be reasons to include that type check by default. But I also agree it's a separate issue, which I've raised in #37.

Finally, I wanted to let you know that I really appreciate your efforts here. Both in discussing this with us (despite the sometimes long history) and seeing the code changes through. I've invited you as a contributor on both the NIR and NIRTorch repos, so you can make the changes in local branches instead, if you prefer. That also makes it easier for us to contribute/assist.

@benkroehs
Collaborator Author

Yes, let me take a look at the tests. :)

Awesome, thanks for adding me as a collaborator! 🙌

@benkroehs
Collaborator Author

Okay, so the tests are working now. As expected, the type checking of the NIRGraph was the reason for the failing tests.
I had to deactivate type checking in general for the extract_nir_graph, since it doesn't return nir.Input or nir.Output nodes.
Although this function is deprecated, we might want to keep those tests for now since e.g. Norse also still uses this. The new torch_to_nir function kind of fixes this by adding a nir.Input and a nir.Output node, but both with hard-coded shapes, so that doesn't fully solve the problem either. I'm not sure yet how to fix this properly. Any thoughts, @Jegp?
Should we merge this and open a follow-up PR or keep this open and fix here later? I'm fine with either option. :)

@Jegp
Collaborator

Jegp commented Jun 2, 2025

Okay, so the tests are working now. As expected, the type checking of the NIRGraph was the reason for the failing tests. I had to deactivate type checking in general for the extract_nir_graph, since it doesn't return nir.Input or nir.Output nodes.

I think that's a great compromise for now.

Although this function is deprecated, we might want to keep those tests for now since e.g. Norse also still uses this.

Yes, I'd probably leave it around for a while. It's deprecated, but we're issuing loud warnings.

The new torch_to_nir function kind of fixes this by adding a nir.Input and a nir.Output node, but both with hard-coded shapes, so that doesn't fully solve the problem either. I'm not sure yet how to fix this properly. Any thoughts, @Jegp? Should we merge this and open a follow-up PR or keep this open and fix here later? I'm fine with either option. :)

Hmm, if I recall correctly, these are used as placeholder nodes just before the PyTorch graph is traced. If that's the case, I don't think the shape matters because PyTorch doesn't care about NIR types and the NIR nodes themselves won't be type-checked in the torch.fx construction.

@fabio-innatera
Collaborator

fabio-innatera commented Jun 5, 2025

The new torch_to_nir function kind of fixes this by adding a nir.Input and a nir.Output node, but both with hard-coded shapes, so that doesn't fully solve the problem either. I'm not sure yet how to fix this properly. Any thoughts, @Jegp? Should we merge this and open a follow-up PR or keep this open and fix here later? I'm fine with either option. :)

Hmm, if I recall correctly, these are used as placeholder nodes just before the PyTorch graph is traced. If that's the case, I don't think the shape matters because PyTorch doesn't care about NIR types and the NIR nodes themselves won't be type-checked in the torch.fx construction.

I think @benkroehs is talking about this code, which runs after the torch model has been traced:

graph = nir.NIRGraph(nodes=nodes, edges=edges)

The shapes of the Input/Output nodes matter because they are validated in _check_types(), which is called by the NIRGraph constructor in the code above.

We could disable the type checking like we do in other parts of the pull request, but at least for me, inferring the types of the Input and Output nodes from their follower/predecessor nodes seems to work, and having a type check on the imported graph is certainly a good idea. Something like this, placed directly before creating the graph, seems to work:

    # Infer input types from follower nodes
    for node_name, node in nodes.items():
        if isinstance(node, nir.Input):
            follower_edges = [edge for edge in edges if edge[0] == node_name]
            follower_nodes = [nodes[edge[1]] for edge in follower_edges]

            # Check that all follower nodes have the same input_type
            first_input_type = None
            for follower_node in follower_nodes:
                if first_input_type is None:
                    first_input_type = follower_node.input_type
                else:
                    # Verify they match (code taken from NIRGraph._check_types)
                    if len(first_input_type) != len(follower_node.input_type):
                        raise ValueError(f"Input type length mismatch for followers of {node_name}")

                    if len(first_input_type.keys()) == 1:
                        first_type = list(first_input_type.values())[0]
                        follower_type = list(follower_node.input_type.values())[0]
                        if not np.array_equal(first_type, follower_type):
                            raise ValueError(f"Input type mismatch for followers of {node_name}: {first_type} vs {follower_type}")
                    else:
                        raise NotImplementedError("Multiple input/output types not supported yet")

            # Update the input node's input_type if we found a valid type
            if first_input_type is not None:
                nodes[node_name] = nir.Input(first_input_type)

    # Similar logic for Output nodes - infer from predecessor nodes
    for node_name, node in nodes.items():
        if isinstance(node, nir.Output):
            # Find all edges where this output node is the target
            predecessor_edges = [edge for edge in edges if edge[1] == node_name]
            predecessor_nodes = [nodes[edge[0]] for edge in predecessor_edges]

            # Check that all predecessor nodes have the same output_type
            first_output_type = None
            for predecessor_node in predecessor_nodes:
                if first_output_type is None:
                    first_output_type = predecessor_node.output_type
                else:
                    # Verify they match (code taken from NIRGraph._check_types)
                    if len(first_output_type) != len(predecessor_node.output_type):
                        raise ValueError(f"Output type length mismatch for predecessors of {node_name}")

                    if len(first_output_type.keys()) == 1:
                        first_type = list(first_output_type.values())[0]
                        predecessor_type = list(predecessor_node.output_type.values())[0]
                        if not np.array_equal(first_type, predecessor_type):
                            raise ValueError(f"Output type mismatch for predecessors of {node_name}: {first_type} vs {predecessor_type}")
                    else:
                        raise NotImplementedError("Multiple input/output types not supported yet")

            # Update the output node's output_type if we found a valid type
            if first_output_type is not None:
                nodes[node_name] = nir.Output(first_output_type)

I am not sure if we actually support multiple successor nodes from one input (or multiple outputs), but it's handled anyway.
I can open a separate PR for this, or @benkroehs you can add this code to this one if you want.
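
For reference, the fan-out case that the input loop above handles would look like this (node names are hypothetical):

    # one Input node feeding two follower nodes
    edges = [("input", "lif_a"), ("input", "lif_b")]
    # the loop gathers both followers, checks that "lif_a" and "lif_b" declare the
    # same input_type, and only then assigns that type to the Input node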

@fabio-innatera
Collaborator

fabio-innatera commented Jun 5, 2025

Follow-up from above after some more investigation:
We do a graph.infer_types() right after constructing the graph.
infer_types does the wrong thing and overrides the input_type of the first node after input with np.array([1]), i.e. the output_type of the Input node.
The same algorithm does the right thing for the Output node.

I believe the correct approach would be:

  1. Run my code snippet above to infer input/output types
  2. Construct the graph with type_check=False
  3. Run infer_types on the graph
  4. Optional: Run _check_types() on the graph. We should already have a valid graph after infer_types, so this is probably not necessary.
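
Condensed as code, that sequence would look roughly like this (a sketch, reusing the nodes and edges variables from the snippet in my previous comment):

    # 1. run the Input/Output type inference snippet from the previous comment here
    # 2. construct the graph without the constructor's type check
    graph = nir.NIRGraph(nodes=nodes, edges=edges, type_check=False)
    # 3. let NIR propagate the remaining types through the graph
    graph.infer_types()
    # 4. optional: sanity-check the result (should already hold after infer_types)
    graph._check_types()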

Alternatively, we could special-case infer_types() for Inputs only, in which case it would get the type from the follower node. Maybe this would be the better approach, since it would avoid using my code snippet above. It depends a bit on where else infer_types is used in the NIR ecosystem, because we would be changing behavior.

Edit: Bad idea. The whole point of the forward_type_inference is that the input type of the graph is correct, and we use that to update the rest of the graph.

My current approach:

@benkroehs
Collaborator Author

benkroehs commented Jun 10, 2025

Thank you @fabio-innatera for addressing the follow-up issues concerning the torch_to_nir as well as the nir_to_torch conversions.
Since those are handled elsewhere now, the fix here for the deprecated conversion functions and the corresponding tests can be merged.

@benkroehs benkroehs merged commit ceaca02 into neuromorphs:main Jun 12, 2025
15 checks passed