-
Notifications
You must be signed in to change notification settings - Fork 451
Multi output fix 2 #1103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi output fix 2 #1103
Conversation
b715dc2
to
2cdedb6
Compare
…nto multi_output_fix_2
8ce2a28
to
0c2e90e
Compare
Note, this PR includes #1119. |
Shouldn't have yet approved while we discuss the comments.
This looks good to me, and I would approve it, but didn't want to skip the step of looking at the above comments. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks very good so far, I'd just recommend minor cosmetic changes
hls4ml/backends/fpga/passes/clone.py
Outdated
if n_outputs == 1: | ||
continue | ||
if n_outputs > 3: | ||
msg = f'ERROR: Cloning output {output} of {node.__class__.__name__} ({node.name}) more than 3 times not currently supported' # noqa: E501 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will print the node class (a backend-specific instance) which may be confusing to the user. We have a node.class_name
for this purpose.
hls4ml/backends/fpga/passes/clone.py
Outdated
raise ValueError(msg) | ||
|
||
out_var = node.get_output_variable(output) | ||
attrs = {'size': np.prod(out_var.shape)} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The output of np.prod()
will be an ndarray
, which works for now, but may trip up in the future (serialization and attribute validation will fail, since we would expect an integral value). Since we're already changing the code here, we might as well sort it out now.
@@ -11,14 +11,19 @@ class InplaceParallelReshape(OptimizerPass): | |||
""" | |||
|
|||
def match(self, node): | |||
return isinstance(node, Reshape) | |||
if not isinstance(node, Reshape): | |||
return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rewrite this so it returns booleans or perhaps combine the two checks?
prev_node = node.get_input_node() | ||
assert ( | ||
prev_node.name not in model.outputs | ||
), f"Cannot output node {prev_node.name}: reshape is a no-op in io_parallel. As a result, the previous node {prev_node.name}'s output will be used as the output. However, this node is already an output." # noqa: E501 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this trigger the bug/feature of black with formatting asserts, or is E501 just out of laziness? 😄
And a nitpick: We prefer single quotes wherever possible
return False | ||
io_type = node.model.config.get_config_value('IOType') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not a one-liner? 😉
hls4ml/model/graph.py
Outdated
rewire (bool, optional): If `True`, connects the outputs of the previous node | ||
to the inputs of the next node | ||
node (Layer): The node to remove rewire (bool, optional): | ||
Deprecated, no effect |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would marking this as deprecated according to the docs work? Would be nice to see
test/pytest/test_multiout_network.py
Outdated
@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis', 'Catapult', 'OneAPI']) | ||
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) | ||
@pytest.mark.parametrize('strategy', ['latency', 'resource']) | ||
def test_multi_output_nn_2(model2, data2, backend: str, io_type: str, strategy: str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we be more descriptive instead of appending 2
to model, data and test?
e63225a
to
d2d907b
Compare
d2d907b
to
e5023c8
Compare
b1e7349
to
45f96f6
Compare
Description
Fix a couple of corner cases where model generation/compile fails. Now one can let a model output at (almost) all layers (except flatten/reshape simultaneously with its input in
io_parallel
).set(outputs)-set(inputs)
are the output nodes, which is not true if some intermediate nodes are being outputted, or if any inputs is unused.remove_node
is updated to correctly rewire in all cases, including leaf/intermediate+leaf nodes. Deprecate argrewire
.Type of change
Tests
test/pytest/test_multiout_network.py/test_multi_output_nn_2
Checklist
pre-commit
on the files I edited or added.