Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Fixes for contract_between(). #421

Merged
merged 2 commits into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 59 additions & 50 deletions tensornetwork/network_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,6 +1753,10 @@ def contract_between(
) -> BaseNode:
"""Contract all of the edges between the two given nodes.

If `output_edge_order` is not set, the output axes will be ordered as:
[...free axes of `node1`..., ...free axes of `node2`...]. Within the axes
of each node, the input order is preserved.

Args:
node1: The first node.
node2: The second node.
Expand All @@ -1764,7 +1768,8 @@ def contract_between(
contain all edges belonging to, but not shared by `node1` and `node2`.
The axes of the new node will be permuted (if necessary) to match this
ordering of Edges.
axis_names: An optional list of names for the axis of the new node
axis_names: An optional list of names for the axis of the new node in order
of the output axes.
Returns:
The new node created.

Expand All @@ -1784,64 +1789,68 @@ def contract_between(
node2.backend.name))

backend = node1.backend
shared_edges = get_shared_edges(node1, node2)
# Trace edges cannot be contracted using tensordot.
if node1 is node2:
flat_edge = flatten_edges_between(node1, node2)
if not flat_edge:
raise ValueError("No trace edges found on contraction of edges between "
"node '{}' and itself.".format(node1))
return contract(flat_edge, name)

shared_edges = get_shared_edges(node1, node2)
if not shared_edges:
if allow_outer_product:
return outer_product(node1, node2, name=name, axis_names=axis_names)
raise ValueError("No edges found between nodes '{}' and '{}' "
"and allow_outer_product=False.".format(node1, node2))

# Collect the axis of each node corresponding to each edge, in order.
# This specifies the contraction for tensordot.
# NOTE: The ordering of node references in each contraction edge is ignored.
axes1 = []
axes2 = []
for edge in shared_edges:
if edge.node1 is node1:
axes1.append(edge.axis1)
axes2.append(edge.axis2)
else:
axes1.append(edge.axis2)
axes2.append(edge.axis1)

if output_edge_order:
# Determine heuristically if output transposition can be minimized by
# flipping the arguments to tensordot.
node1_output_axes = []
node2_output_axes = []
for (i, edge) in enumerate(output_edge_order):
if edge in shared_edges:
raise ValueError(
"Edge '{}' in output_edge_order is shared by the nodes to be "
"contracted: '{}' and '{}'.".format(edge, node1, node2))
edge_nodes = set(edge.get_nodes())
if node1 in edge_nodes:
node1_output_axes.append(i)
elif node2 in edge_nodes:
node2_output_axes.append(i)
new_node = contract(flat_edge, name)
elif not shared_edges:
if not allow_outer_product:
raise ValueError("No edges found between nodes '{}' and '{}' "
"and allow_outer_product=False.".format(node1, node2))
new_node = outer_product(node1, node2, name=name)
else:
# Collect the axis of each node corresponding to each edge, in order.
# This specifies the contraction for tensordot.
# NOTE: The ordering of node references in each contraction edge is ignored.
axes1 = []
axes2 = []
for edge in shared_edges:
if edge.node1 is node1:
axes1.append(edge.axis1)
axes2.append(edge.axis2)
else:
raise ValueError(
"Edge '{}' in output_edge_order is not connected to node '{}' or "
"node '{}'".format(edge, node1, node2))
if np.mean(node1_output_axes) > np.mean(node2_output_axes):
node1, node2 = node2, node1
axes1, axes2 = axes2, axes1

new_tensor = backend.tensordot(node1.tensor, node2.tensor, [axes1, axes2])
new_node = Node(
tensor=new_tensor, name=name, axis_names=axis_names, backend=backend)
# node1 and node2 get new edges in _remove_edges
_remove_edges(shared_edges, node1, node2, new_node)
axes1.append(edge.axis2)
axes2.append(edge.axis1)

if output_edge_order:
# Determine heuristically if output transposition can be minimized by
# flipping the arguments to tensordot.
node1_output_axes = []
node2_output_axes = []
for (i, edge) in enumerate(output_edge_order):
if edge in shared_edges:
raise ValueError(
"Edge '{}' in output_edge_order is shared by the nodes to be "
"contracted: '{}' and '{}'.".format(edge, node1, node2))
edge_nodes = set(edge.get_nodes())
if node1 in edge_nodes:
node1_output_axes.append(i)
elif node2 in edge_nodes:
node2_output_axes.append(i)
else:
raise ValueError(
"Edge '{}' in output_edge_order is not connected to node '{}' or "
"node '{}'".format(edge, node1, node2))
if node1_output_axes and node2_output_axes and (
np.mean(node1_output_axes) > np.mean(node2_output_axes)):
node1, node2 = node2, node1
axes1, axes2 = axes2, axes1

new_tensor = backend.tensordot(node1.tensor, node2.tensor, [axes1, axes2])
new_node = Node(
tensor=new_tensor, name=name, backend=backend)
# node1 and node2 get new edges in _remove_edges
_remove_edges(shared_edges, node1, node2, new_node)

if output_edge_order:
new_node = new_node.reorder_edges(list(output_edge_order))
if axis_names:
new_node.add_axis_names(axis_names)

return new_node


Expand Down
72 changes: 64 additions & 8 deletions tensornetwork/tests/network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,26 +483,45 @@ def test_flatten_all_edges(backend):


def test_contract_between(backend):
a_val = np.ones((2, 3, 4, 5))
b_val = np.ones((3, 5, 4, 2))
a_val = np.random.rand(2, 3, 4, 5)
b_val = np.random.rand(3, 5, 6, 2)
a = tn.Node(a_val, backend=backend)
b = tn.Node(b_val, backend=backend)
tn.connect(a[0], b[3])
tn.connect(b[1], a[3])
tn.connect(a[1], b[0])
edge_a = a[2]
edge_b = b[2]
c = tn.contract_between(a, b, name="New Node")
c.reorder_edges([edge_a, edge_b])
output_axis_names = ["a2", "b2"]
c = tn.contract_between(a, b, name="New Node", axis_names=output_axis_names)
tn.check_correct({c})
# Check expected values.
a_flat = np.reshape(np.transpose(a_val, (2, 1, 0, 3)), (4, 30))
b_flat = np.reshape(np.transpose(b_val, (2, 0, 3, 1)), (4, 30))
b_flat = np.reshape(np.transpose(b_val, (2, 0, 3, 1)), (6, 30))
final_val = np.matmul(a_flat, b_flat.T)
assert c.name == "New Node"
assert c.axis_names == output_axis_names
np.testing.assert_allclose(c.tensor, final_val)


def test_contract_between_output_edge_order(backend):
a_val = np.random.rand(2, 3, 4, 5)
b_val = np.random.rand(3, 5, 6, 2)
a = tn.Node(a_val, backend=backend)
b = tn.Node(b_val, backend=backend)
tn.connect(a[0], b[3])
tn.connect(b[1], a[3])
tn.connect(a[1], b[0])
output_axis_names = ["b2", "a2"]
c = tn.contract_between(a, b, name="New Node", axis_names=output_axis_names,
output_edge_order=[b[2], a[2]])
# Check expected values.
a_flat = np.reshape(np.transpose(a_val, (2, 1, 0, 3)), (4, 30))
b_flat = np.reshape(np.transpose(b_val, (2, 0, 3, 1)), (6, 30))
final_val = np.matmul(a_flat, b_flat.T)
assert c.name == "New Node"
assert c.axis_names == output_axis_names
np.testing.assert_allclose(c.tensor, final_val.T)


def test_contract_between_no_outer_product_value_error(backend):
a_val = np.ones((2, 3, 4))
b_val = np.ones((5, 6, 7))
Expand All @@ -517,8 +536,45 @@ def test_contract_between_outer_product_no_value_error(backend):
b_val = np.ones((5, 6, 7))
a = tn.Node(a_val, backend=backend)
b = tn.Node(b_val, backend=backend)
c = tn.contract_between(a, b, allow_outer_product=True)
output_axis_names = ["a0", "a1", "a2", "b0", "b1", "b2"]
c = tn.contract_between(a, b, allow_outer_product=True,
axis_names=output_axis_names)
assert c.shape == (2, 3, 4, 5, 6, 7)
assert c.axis_names == output_axis_names


def test_contract_between_outer_product_output_edge_order(backend):
a_val = np.ones((2, 3, 4))
b_val = np.ones((5, 6, 7))
a = tn.Node(a_val, backend=backend)
b = tn.Node(b_val, backend=backend)
output_axis_names = ["b0", "b1", "a0", "b2", "a1", "a2"]
c = tn.contract_between(
a, b,
allow_outer_product=True,
output_edge_order=[b[0], b[1], a[0], b[2], a[1], a[2]],
axis_names=output_axis_names)
assert c.shape == (5, 6, 2, 7, 3, 4)
assert c.axis_names == output_axis_names


def test_contract_between_trace(backend):
a_val = np.ones((2, 3, 2, 4))
a = tn.Node(a_val, backend=backend)
tn.connect(a[0], a[2])
c = tn.contract_between(a, a, axis_names=["1", "3"])
assert c.shape == (3, 4)
assert c.axis_names == ["1", "3"]


def test_contract_between_trace_output_edge_order(backend):
a_val = np.ones((2, 3, 2, 4))
a = tn.Node(a_val, backend=backend)
tn.connect(a[0], a[2])
c = tn.contract_between(a, a, output_edge_order=[a[3], a[1]],
axis_names=["3", "1"])
assert c.shape == (4, 3)
assert c.axis_names == ["3", "1"]


def test_contract_parallel(backend):
Expand Down