Skip to content

Multi output fix 2 #1103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 15, 2024
Merged

Conversation

calad0i
Copy link
Contributor

@calad0i calad0i commented Oct 29, 2024

Description

Fix a couple of corner cases where model generation/compile fails. Now one can let a model output at (almost) all layers (except flatten/reshape simultaneously with its input in io_parallel).

  • Fix output node updating mechanism. The current parser assumes set(outputs)-set(inputs) are the output nodes, which is not true if some intermediate nodes are being outputted, or if any inputs is unused.
  • remove_node is updated to correctly rewire in all cases, including leaf/intermediate+leaf nodes. Deprecate arg rewire.
  • Add early stream flatten optimizer to pass original size to potential repack optimizers
  • Added 1-to-3 repack catapult template

Type of change

  • Bug fix (non-breaking change that fixes an issue)

Tests

test/pytest/test_multiout_network.py/test_multi_output_nn_2

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • [] I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@calad0i calad0i added the please test Trigger testing by creating local PR branch label Oct 29, 2024
@calad0i calad0i force-pushed the multi_output_fix_2 branch from b715dc2 to 2cdedb6 Compare October 30, 2024 00:10
@calad0i calad0i added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Oct 30, 2024
@jmitrevs jmitrevs added this to the v1.0.0 milestone Nov 7, 2024
@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Nov 7, 2024
@calad0i calad0i force-pushed the multi_output_fix_2 branch from 8ce2a28 to 0c2e90e Compare November 8, 2024 18:32
@calad0i calad0i added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Nov 8, 2024
@jmitrevs
Copy link
Contributor

jmitrevs commented Nov 9, 2024

Note, this PR includes #1119.

jmitrevs
jmitrevs previously approved these changes Nov 9, 2024
@jmitrevs jmitrevs self-requested a review November 9, 2024 03:03
@jmitrevs jmitrevs dismissed their stale review November 9, 2024 03:04

Shouldn't have yet approved while we discuss the comments.

@jmitrevs
Copy link
Contributor

jmitrevs commented Nov 9, 2024

This looks good to me, and I would approve it, but didn't want to skip the step of looking at the above comments.

@calad0i calad0i enabled auto-merge (rebase) November 9, 2024 03:33
@calad0i calad0i removed the please test Trigger testing by creating local PR branch label Nov 9, 2024
@calad0i calad0i added the please test Trigger testing by creating local PR branch label Nov 9, 2024
Copy link
Contributor

@vloncar vloncar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks very good so far, I'd just recommend minor cosmetic changes

if n_outputs == 1:
continue
if n_outputs > 3:
msg = f'ERROR: Cloning output {output} of {node.__class__.__name__} ({node.name}) more than 3 times not currently supported' # noqa: E501
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will print the node class (a backend-specific instance) which may be confusing to the user. We have a node.class_name for this purpose.

raise ValueError(msg)

out_var = node.get_output_variable(output)
attrs = {'size': np.prod(out_var.shape)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output of np.prod() will be an ndarray, which works for now, but may trip up in the future (serialization and attribute validation will fail, since we would expect an integral value). Since we're already changing the code here, we might as well sort it out now.

@@ -11,14 +11,19 @@ class InplaceParallelReshape(OptimizerPass):
"""

def match(self, node):
return isinstance(node, Reshape)
if not isinstance(node, Reshape):
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rewrite this so it returns booleans or perhaps combine the two checks?

prev_node = node.get_input_node()
assert (
prev_node.name not in model.outputs
), f"Cannot output node {prev_node.name}: reshape is a no-op in io_parallel. As a result, the previous node {prev_node.name}'s output will be used as the output. However, this node is already an output." # noqa: E501
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this trigger the bug/feature of black with formatting asserts, or is E501 just out of laziness? 😄

And a nitpick: We prefer single quotes wherever possible

return False
io_type = node.model.config.get_config_value('IOType')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not a one-liner? 😉

rewire (bool, optional): If `True`, connects the outputs of the previous node
to the inputs of the next node
node (Layer): The node to remove rewire (bool, optional):
Deprecated, no effect
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would marking this as deprecated according to the docs work? Would be nice to see

@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis', 'Catapult', 'OneAPI'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
@pytest.mark.parametrize('strategy', ['latency', 'resource'])
def test_multi_output_nn_2(model2, data2, backend: str, io_type: str, strategy: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we be more descriptive instead of appending 2 to model, data and test?

@calad0i calad0i added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Nov 13, 2024
@calad0i calad0i added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Nov 13, 2024
@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Nov 13, 2024
@calad0i calad0i mentioned this pull request Nov 15, 2024
2 tasks
@calad0i calad0i merged commit ef2e8f4 into fastmachinelearning:main Nov 15, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants