diff --git a/python/popxl/python_files/ops/repeat.py b/python/popxl/python_files/ops/repeat.py
index ad9dab9a3..ef8e48aaa 100644
--- a/python/popxl/python_files/ops/repeat.py
+++ b/python/popxl/python_files/ops/repeat.py
@@ -54,8 +54,6 @@ def build(self, x):
         # repeat 8 times
         y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})
 
-    See also `PyTorch Tensor.repeat <https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html>`__, `NumPy repeat <https://numpy.org/doc/stable/reference/generated/numpy.repeat.html>`__.
-
     Args:
         graph (Graph): User defined graph to repeat `repeat_count` times.
         repeat_count (int): Number of times to repeat calling the graph.
@@ -74,7 +72,10 @@ def build(self, x):
         Tuple[Tensor, ...]:
             Tuple of the output tensors of the call in the parent graph.
     """
-    loop_info = repeat_with_info(graph, repeat_count, *inputs, inputs_dict=inputs_dict)
+    loop_info = repeat_with_info(graph,
+                                 repeat_count,
+                                 *inputs,
+                                 inputs_dict=inputs_dict)
 
     out_tensors = loop_info.outputs
     return out_tensors
@@ -193,8 +194,7 @@ def build(self, x):
     if total_inputs < total_outputs:
         raise ValueError(
             f"To repeat the subgraph ({graph.id}) the number of inputs must be greater than or equal to the number of outputs."
-            f" {total_inputs} < {total_outputs}"
-        )
+            f" {total_inputs} < {total_outputs}")
 
     # For clarity, we rename our graphs:
     # - Bottom: The user provided bottom level graph. We call this with a call op. This has gone
@@ -215,16 +215,14 @@ def build(self, x):
 
     # Create the middle graph, call and loop ops
     pb_middle_graph, pb_callop, pb_loop_op = _setup_call_and_repeat(
-        pb_ir, pb_top_graph, pb_bottom_graph
-    )
+        pb_ir, pb_top_graph, pb_bottom_graph)
 
     # set the number of times to loop
     pb_loop_op.setTripCountValue(repeat_count)
 
     # Prep and validate inputs
-    inputs_all = _prep_and_validate_inputs(
-        check_inputs, top_graph, graph, "called", inputs, inputs_dict
-    )
+    inputs_all = _prep_and_validate_inputs(check_inputs, top_graph, graph,
+                                           "called", inputs, inputs_dict)
 
     # 1, 2. Connect inputs.
     _setup_inputs(
@@ -236,9 +234,8 @@ def build(self, x):
     )
 
     # 3. Connect outputs.
-    _ = _setup_outputs(
-        pb_top_graph, pb_bottom_graph, pb_middle_graph, pb_callop, pb_loop_op
-    )
+    _ = _setup_outputs(pb_top_graph, pb_bottom_graph, pb_middle_graph,
+                       pb_callop, pb_loop_op)
 
     pb_callop.setup()
     pb_loop_op.setup()
@@ -250,13 +247,14 @@ def build(self, x):
     loop_carried_inputs = pb_loop_op.getNumExplicitInputs()
     for bottom_t in bottom_graph._by_ref_inputs:
         middle_t = c_info.graph_to_parent(bottom_t)
-        loop_carried = pb_middle_graph.getInputIndex(middle_t.id) < loop_carried_inputs
+        loop_carried = pb_middle_graph.getInputIndex(
+            middle_t.id) < loop_carried_inputs
         # If a tensor was set as a by_ref_input, we should also do the same for the looped subgraph.
         c_info.set_parent_input_modified(
-            middle_t, infer_modified_regions=not loop_carried
-        )
+            middle_t, infer_modified_regions=not loop_carried)
         top_t = r_info.graph_to_parent(middle_t)
-        r_info.set_parent_input_modified(top_t, infer_modified_regions=not loop_carried)
+        r_info.set_parent_input_modified(
+            top_t, infer_modified_regions=not loop_carried)
         r_info.called_graph._by_ref_inputs.add(middle_t)
 
     return r_info
@@ -280,34 +278,35 @@ def _setup_call_and_repeat(
     # This is the graph we will repeat.
     pb_middle_graph = pb_ir.createGraph(
         _ir.GraphId(
-            pb_ir.createUniqueSubgraphId(f"{pb_bottom_graph.id.str()}__loop_wrapper")
-        )
-    )
+            pb_ir.createUniqueSubgraphId(
+                f"{pb_bottom_graph.id.str()}__loop_wrapper")))
 
-    opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(), 0)
+    opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(),
+                                  0)
     op_name = pb_middle_graph.id.str() + "__call__" + pb_bottom_graph.id.str()
     ctx = get_current_context()
     # Call the bottom_graph
-    pb_callop = pb_middle_graph.createOp_CallOp(
-        opid, pb_bottom_graph, ctx._get_op_settings(op_name)
-    )
+    pb_callop = pb_middle_graph.createOp_CallOp(opid, pb_bottom_graph,
+                                                ctx._get_op_settings(op_name))
 
     opid = _ir.OperatorIdentifier("ai.onnx", "Loop", 11, _ir.NumInputs(), 0)
     op_name = pb_top_graph.id.str() + "__loop__" + pb_middle_graph.id.str()
     # Loop the middle_graph
-    pb_loop_op = pb_top_graph.createOp_LoopOp(
-        opid, ctx._get_op_settings(op_name), pb_middle_graph
-    )
+    pb_loop_op = pb_top_graph.createOp_LoopOp(opid,
+                                              ctx._get_op_settings(op_name),
+                                              pb_middle_graph)
 
     # Add mandatory loop iterator tensor to graph (is not an output)
    repeatIterId = _ir.addScope(pb_middle_graph, "Iterator___")
-    pb_middle_graph.addInput(repeatIterId, _ir.TensorInfo(_ir.DataType.INT32, ()))
+    pb_middle_graph.addInput(repeatIterId,
+                             _ir.TensorInfo(_ir.DataType.INT32, ()))
 
     # Add mandatory loop condition tensor to graph (is also an output)
     repeatCondId = _ir.addScope(pb_middle_graph, "LoopCond___")
-    pb_middle_graph.addInput(repeatCondId, _ir.TensorInfo(_ir.DataType.BOOL, ()))
+    pb_middle_graph.addInput(repeatCondId,
+                             _ir.TensorInfo(_ir.DataType.BOOL, ()))
     pb_middle_graph.markAsOutput(repeatCondId)
 
     return pb_middle_graph, pb_callop, pb_loop_op
@@ -354,8 +353,7 @@ def _setup_inputs(
                 False,
             )
             pb_callop.connectInTensor(
-                call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name)
-            )
+                call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name))
 
 
 def _setup_outputs(
@@ -385,21 +383,19 @@
 
     for pb_subgraph_out_id in pb_bottom_graph.getOutputIds():
         top_tensor_id = _ir.addScope(
-            pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
-        )
+            pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))
         # Already has scope added
         middle_tensor_id = _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
         bottom_tensor_id = _ir.addScope(
-            pb_bottom_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
-        )
+            pb_bottom_graph,
+            _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))
 
         sgOutIdx = pb_bottom_graph.getOutputIndex(bottom_tensor_id)
         callOutIdx = pb_callop.subgraphOutToOpOutIndex(sgOutIdx)
 
         # Avoid tensor name collisions
         middle_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
-            middle_tensor_id
-        )
+            middle_tensor_id)
 
         pb_callop.createAndConnectOutTensor(callOutIdx, middle_tensor_id)
         pb_middle_graph.markAsOutput(middle_tensor_id)
@@ -407,11 +403,11 @@ def _setup_outputs(
         repeatOutIdx = pb_loop_op.subgraphOutToOpOutIndex(sgOutIdx)
         # Avoid tensor name collisions
         top_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
-            top_tensor_id
-        )
+            top_tensor_id)
         # We overwrite here as we added the middle_tensor_id as an output above, but we want to make
         # sure the loop op is setup correctly.
-        pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id, True)
+        pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id,
+                                 True)
         outnames.append(top_tensor_id)
 
     return outnames
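For reference, a minimal sketch of how the reformatted `ops.repeat` is driven from user code, expanding the docstring example touched by the first hunk. This is not part of the patch: the `AddWeight` module and the variable values are illustrative assumptions; only `popxl`, `popxl.ops`, `ops.repeat`, and `inputs_dict` come from the file above.

    # Hypothetical usage sketch (assumed names marked in comments).
    import popxl
    import popxl.ops as ops


    class AddWeight(popxl.Module):  # AddWeight is an assumed example module
        def __init__(self):
            self.w: popxl.Tensor = None

        def build(self, x):
            # Subgraph input, later bound to a parent tensor via inputs_dict.
            self.w = popxl.graph_input(x.shape, x.dtype, "w")
            return self.w + x, self.w


    ir = popxl.Ir()
    with ir.main_graph:
        x0 = popxl.variable(1.0, popxl.float32)  # example values
        w0 = popxl.variable(2.0, popxl.float32)
        add_weight0 = AddWeight()
        add_weight_graph0 = ir.create_graph(add_weight0, x0)
        # Repeat the subgraph 8 times. Per the diff, this lowers to a CallOp
        # on the bottom graph wrapped in a LoopOp whose trip count is set via
        # setTripCountValue(repeat_count).
        y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})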