Commit
popxl.ops.repeat - remove 'see also' links (#4)
Removing "See also" links in `popxl.ops.repeat`.
jayniep-gc authored Sep 11, 2023
1 parent 1aa2053 commit d2f20d7
Showing 1 changed file with 36 additions and 40 deletions.
python/popxl/python_files/ops/repeat.py: 36 additions & 40 deletions (76 changes)
@@ -54,8 +54,6 @@ def build(self, x):
# repeat 8 times
y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})
- See also `PyTorch Tensor.repeat <https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html>`__, `NumPy repeat <https://numpy.org/doc/stable/reference/generated/numpy.repeat.html>`__.
Args:
graph (Graph): User defined graph to repeat `repeat_count` times.
repeat_count (int): Number of times to repeat calling the graph.
@@ -74,7 +72,10 @@ def build(self, x):
Tuple[Tensor, ...]:
Tuple of the output tensors of the call in the parent graph.
"""
- loop_info = repeat_with_info(graph, repeat_count, *inputs, inputs_dict=inputs_dict)
+ loop_info = repeat_with_info(graph,
+ repeat_count,
+ *inputs,
+ inputs_dict=inputs_dict)

out_tensors = loop_info.outputs
return out_tensors
@@ -193,8 +194,7 @@ def build(self, x):
if total_inputs < total_outputs:
raise ValueError(
f"To repeat the subgraph ({graph.id}) the number of inputs must be greater than or equal to the number of outputs."
f" {total_inputs} < {total_outputs}"
)
f" {total_inputs} < {total_outputs}")

# For clarity, we rename our graphs:
# - Bottom: The user provided bottom level graph. We call this with a call op. This has gone
@@ -215,16 +215,14 @@ def build(self, x):

# Create the middle graph, call and loop ops
pb_middle_graph, pb_callop, pb_loop_op = _setup_call_and_repeat(
- pb_ir, pb_top_graph, pb_bottom_graph
- )
+ pb_ir, pb_top_graph, pb_bottom_graph)

# set the number of times to loop
pb_loop_op.setTripCountValue(repeat_count)

# Prep and validate inputs
- inputs_all = _prep_and_validate_inputs(
- check_inputs, top_graph, graph, "called", inputs, inputs_dict
- )
+ inputs_all = _prep_and_validate_inputs(check_inputs, top_graph, graph,
+ "called", inputs, inputs_dict)

# 1, 2. Connect inputs.
_setup_inputs(
Expand All @@ -236,9 +234,8 @@ def build(self, x):
)

# 3. Connect outputs.
- _ = _setup_outputs(
- pb_top_graph, pb_bottom_graph, pb_middle_graph, pb_callop, pb_loop_op
- )
+ _ = _setup_outputs(pb_top_graph, pb_bottom_graph, pb_middle_graph,
+ pb_callop, pb_loop_op)

pb_callop.setup()
pb_loop_op.setup()
Expand All @@ -250,13 +247,14 @@ def build(self, x):
loop_carried_inputs = pb_loop_op.getNumExplicitInputs()
for bottom_t in bottom_graph._by_ref_inputs:
middle_t = c_info.graph_to_parent(bottom_t)
- loop_carried = pb_middle_graph.getInputIndex(middle_t.id) < loop_carried_inputs
+ loop_carried = pb_middle_graph.getInputIndex(
+ middle_t.id) < loop_carried_inputs
# If a tensor was set as a by_ref_input, we should also do the same for the looped subgraph.
c_info.set_parent_input_modified(
- middle_t, infer_modified_regions=not loop_carried
- )
+ middle_t, infer_modified_regions=not loop_carried)
top_t = r_info.graph_to_parent(middle_t)
- r_info.set_parent_input_modified(top_t, infer_modified_regions=not loop_carried)
+ r_info.set_parent_input_modified(
+ top_t, infer_modified_regions=not loop_carried)
r_info.called_graph._by_ref_inputs.add(middle_t)

return r_info
Expand All @@ -280,34 +278,35 @@ def _setup_call_and_repeat(
# This is the graph we will repeat.
pb_middle_graph = pb_ir.createGraph(
_ir.GraphId(
pb_ir.createUniqueSubgraphId(f"{pb_bottom_graph.id.str()}__loop_wrapper")
)
)
pb_ir.createUniqueSubgraphId(
f"{pb_bottom_graph.id.str()}__loop_wrapper")))

opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(), 0)
opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(),
0)
op_name = pb_middle_graph.id.str() + "__call__" + pb_bottom_graph.id.str()

ctx = get_current_context()
# Call the bottom_graph
- pb_callop = pb_middle_graph.createOp_CallOp(
- opid, pb_bottom_graph, ctx._get_op_settings(op_name)
- )
+ pb_callop = pb_middle_graph.createOp_CallOp(opid, pb_bottom_graph,
+ ctx._get_op_settings(op_name))

opid = _ir.OperatorIdentifier("ai.onnx", "Loop", 11, _ir.NumInputs(), 0)
op_name = pb_top_graph.id.str() + "__loop__" + pb_middle_graph.id.str()

# Loop the middle_graph
- pb_loop_op = pb_top_graph.createOp_LoopOp(
- opid, ctx._get_op_settings(op_name), pb_middle_graph
- )
+ pb_loop_op = pb_top_graph.createOp_LoopOp(opid,
+ ctx._get_op_settings(op_name),
+ pb_middle_graph)

# Add mandatory loop iterator tensor to graph (is not an output)
repeatIterId = _ir.addScope(pb_middle_graph, "Iterator___")
- pb_middle_graph.addInput(repeatIterId, _ir.TensorInfo(_ir.DataType.INT32, ()))
+ pb_middle_graph.addInput(repeatIterId,
+ _ir.TensorInfo(_ir.DataType.INT32, ()))

# Add mandatory loop condition tensor to graph (is also an output)
repeatCondId = _ir.addScope(pb_middle_graph, "LoopCond___")
- pb_middle_graph.addInput(repeatCondId, _ir.TensorInfo(_ir.DataType.BOOL, ()))
+ pb_middle_graph.addInput(repeatCondId,
+ _ir.TensorInfo(_ir.DataType.BOOL, ()))
pb_middle_graph.markAsOutput(repeatCondId)

return pb_middle_graph, pb_callop, pb_loop_op
@@ -354,8 +353,7 @@ def _setup_inputs(
False,
)
pb_callop.connectInTensor(
- call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name)
- )
+ call_in_idx, _ir.addScope(pb_middle_graph, parent_tensor.name))


def _setup_outputs(
@@ -385,33 +383,31 @@ def _setup_outputs(

for pb_subgraph_out_id in pb_bottom_graph.getOutputIds():
top_tensor_id = _ir.addScope(
- pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
- )
+ pb_top_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))
# Already has scope added
middle_tensor_id = _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
bottom_tensor_id = _ir.addScope(
- pb_bottom_graph, _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id)
- )
+ pb_bottom_graph,
+ _ir.removeScope(pb_bottom_graph, pb_subgraph_out_id))

sgOutIdx = pb_bottom_graph.getOutputIndex(bottom_tensor_id)
callOutIdx = pb_callop.subgraphOutToOpOutIndex(sgOutIdx)

# Avoid tensor name collisions
middle_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
- middle_tensor_id
- )
+ middle_tensor_id)
pb_callop.createAndConnectOutTensor(callOutIdx, middle_tensor_id)

pb_middle_graph.markAsOutput(middle_tensor_id)
sgOutIdx = pb_middle_graph.getOutputIndex(middle_tensor_id)
repeatOutIdx = pb_loop_op.subgraphOutToOpOutIndex(sgOutIdx)
# Avoid tensor name collisions
top_tensor_id = pb_middle_graph.getIr().createIntermediateTensorId(
- top_tensor_id
- )
+ top_tensor_id)
# We overwrite here as we added the middle_tensor_id as an output above, but we want to make
# sure the loop op is setup correctly.
- pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id, True)
+ pb_loop_op.addLoopOutput(repeatOutIdx, top_tensor_id, middle_tensor_id,
+ True)

outnames.append(top_tensor_id)
return outnames
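For reference, the usage documented in this file (and which the removed "See also" line sat beneath) repeats a user-defined graph a fixed number of times. Below is a minimal sketch reconstructed around the docstring fragment shown in the first hunk; only the ops.repeat call is taken verbatim from the diff, while the AddWeight module, tensor names and popxl helpers (popxl.Ir, popxl.Module, popxl.graph_input, popxl.variable, ir.create_graph) are assumed from the surrounding popxl API.

# Sketch based on the docstring example in repeat.py; names are illustrative.
import popxl
import popxl.ops as ops


class AddWeight(popxl.Module):
    def __init__(self):
        self.w: popxl.Tensor = None

    def build(self, x):
        # Create a graph input for the weight and return x + w and w.
        self.w = popxl.graph_input(x.shape, x.dtype, "w")
        return self.w + x, self.w


ir = popxl.Ir()
with ir.main_graph:
    x0 = popxl.variable(1.0, popxl.float32, name="x0")
    w0 = popxl.variable(0.5, popxl.float32, name="w0")

    add_weight0 = AddWeight()
    add_weight_graph0 = ir.create_graph(add_weight0, x0)

    # repeat 8 times
    y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})

Here add_weight0.w is the graph input created inside build(), and inputs_dict binds it to the parent-graph tensor w0, matching the inputs_dict argument described in the docstring's Args section.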
