|
| 1 | +import collections |
| 2 | +from collections.abc import Sequence |
| 3 | +from functools import reduce |
| 4 | +from itertools import pairwise |
| 5 | + |
| 6 | +from numpy.core.numeric import normalize_axis_index # type: ignore |
| 7 | + |
| 8 | +from pytensor.compile.builders import OpFromGraph |
| 9 | +from pytensor.tensor.basic import ( |
| 10 | + arange, |
| 11 | + expand_dims, |
| 12 | + get_vector_length, |
| 13 | + stack, |
| 14 | + transpose, |
| 15 | + where, |
| 16 | +) |
| 17 | +from pytensor.tensor.extra_ops import broadcast_to |
| 18 | +from pytensor.tensor.math import and_, eq, tensordot |
| 19 | +from pytensor.tensor.shape import shape_padright |
| 20 | +from pytensor.tensor.variable import TensorVariable |
| 21 | + |
| 22 | + |
class Einsum(OpFromGraph):
    """
    Wrapper Op for Einsum graphs

    The ``subscripts`` string and the ``optimize`` setting are stored on the
    instance (both are listed in ``__props__``); every other positional and
    keyword argument is forwarded unchanged to ``OpFromGraph``.
    """

    __props__ = ("subscripts", "optimize")

    def __init__(
        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
    ):
        # The props are set before delegating to the parent initializer.
        self.subscripts, self.optimize = subscripts, optimize
        super().__init__(*args, **kwargs)
| 37 | + |
def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
    """
    Create a tensor of the given shape whose values count up along ``axis``.

    Parameters
    ----------
    shape: TensorVariable
        Vector holding the shape of the tensor to create. Its length must be
        resolvable by ``get_vector_length``.
    axis: int
        Axis along which the values increase. It is normalized against the
        number of dimensions with ``normalize_axis_index``.

    Returns
    -------
    TensorVariable
        ``arange(shape[axis])`` broadcast to ``shape`` so that the values vary
        along ``axis`` and are constant along every other axis.
    """
    ndim = get_vector_length(shape)
    axis = normalize_axis_index(axis, ndim)
    values = arange(shape[axis])
    # Broadcasting right-aligns shapes, so the counting dimension only lands on
    # ``axis`` if it is followed by one broadcastable dimension for every axis
    # that comes after it (``ndim - axis - 1`` of them), not ``axis`` of them.
    return broadcast_to(shape_padright(values, ndim - axis - 1), shape)
| 42 | + |
| 43 | + |
def _delta(shape, axes: Sequence[int]) -> TensorVariable:
    """
    Create a Kronecker delta tensor.

    The result is true where the indices along all of ``axes`` are equal and
    false elsewhere, broadcast to ``shape``.

    Parameters
    ----------
    shape
        Shape of the tensor to create. Its length must be resolvable by
        ``get_vector_length``.
    axes: sequence of int
        Axes along which the indices must coincide. At least two are needed.
        NOTE(review): assumed to be non-negative, as produced by the caller in
        this file; negative axes are not normalized here.

    Returns
    -------
    TensorVariable
        Boolean tensor broadcast to ``shape``.

    Raises
    ------
    ValueError
        If fewer than two axes are given.
    """
    if len(axes) < 2:
        # Without this guard ``reduce`` would fail on an empty list with an
        # unhelpful TypeError.
        raise ValueError("Need at least two axes to create a delta tensor")

    base_shape = stack([shape[axis] for axis in axes])
    iotas = [_iota(base_shape, i) for i in range(len(axes))]
    eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
    result = reduce(and_, eyes)

    # ``result`` already has one dimension per entry of ``axes``. The new
    # broadcastable dimensions must be inserted at the *other* positions so
    # that the delta dimensions end up on ``axes`` in the full shape.
    non_axes = tuple(i for i in range(get_vector_length(shape)) if i not in axes)
    return broadcast_to(expand_dims(result, non_axes), shape)
| 51 | + |
| 52 | + |
| 53 | +def _removechars(s, chars): |
| 54 | + return s.translate(str.maketrans(dict.fromkeys(chars))) |
| 55 | + |
| 56 | + |
def einsum(subscripts: str, *operands):
    """
    Multiplication and summation of tensors using the Einstein summation convention.

    The contraction order is obtained from ``opt_einsum.contract_path`` using the
    static shapes of the operands. Each step of that path is then expressed with
    sums, ``tensordot`` and ``transpose``, and the resulting graph is wrapped in
    an ``Einsum`` Op that is applied to the operands.

    # TODO: Write docs

    Parameters
    ----------
    subscripts: str
        Einsum specification, forwarded to ``opt_einsum.contract_path`` and
        stored on the returned ``Einsum`` Op.

    operands: sequence of TensorVariable
        Tensors to be multiplied and summed.

    Returns
    -------
    TensorVariable
        The result of the einsum operation.

    Raises
    ------
    NotImplementedError
        If a pairwise contraction step has batch dimensions (a name shared by
        both operands that is also kept in the step's result), or if a step
        involves a number of operands other than one or two.
    """
    # TODO: Is this doing something clever about unknown shapes?
    # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
    # using einsum_call=True here is an internal api for opt_einsum... sorry

    # TODO: Handle None static shapes
    # TODO: Do we need this as dependency?
    from opt_einsum import contract_path

    shapes = [operand.type.shape for operand in operands]

    # TODO: Do fast path at creation time, and optimize only in fast_run
    _, contraction_list = contract_path(
        subscripts,
        *shapes,
        einsum_call=True,
        use_blas=True,
        optimize="optimal",
        shapes=True,
    )

    def sum_uniques(
        operand: TensorVariable, names: str, uniques: list[str]
    ) -> tuple[TensorVariable, str]:
        """Sum out the axes labelled by ``uniques`` and drop their names."""
        if uniques:
            axes = [names.index(name) for name in uniques]
            operand = operand.sum(axes)
            names = _removechars(names, uniques)
        return operand, names

    def sum_repeats(
        operand: TensorVariable,
        names: str,
        counts: collections.Counter,
        keep_names: str,
    ) -> tuple[TensorVariable, str]:
        """Collapse every name that appears more than once in ``names``.

        Off-diagonal entries are zeroed with a delta mask; the repeated axes
        are then summed, keeping one of them if the name is in ``keep_names``.
        """
        for name, count in counts.items():
            if count > 1:
                axes = [i for i, n in enumerate(names) if n == name]
                eye = _delta(operand.shape, axes)
                operand = where(eye, operand, operand.zeros_like())
                if name not in keep_names:
                    operand = operand.sum(axes)
                    names = names.replace(name, "")
                else:
                    # Keep the last occurrence of the name as the surviving axis.
                    operand = operand.sum(axes[:-1])
                    names = names.replace(name, "", count - 1)
        return operand, names

    # TODO: port this helper (kept verbatim from the reference implementation).
    # def filter_singleton_dims(operand, names, other_shape, other_names):
    #     eq = core.definitely_equal
    #     keep = [
    #         not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
    #         for i, j in enumerate(map(other_names.find, names))
    #     ]
    #     sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))
    #     return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes)

    einsum_operands = list(operands)  # So we can pop
    for operand_indices, contracted_names, einstr, _, _ in contraction_list:
        contracted_names = sorted(contracted_names)
        assert len(contracted_names) == len(
            set(contracted_names)
        ), "The set was needed!"

        input_str, result_names = einstr.split("->")
        input_names = input_str.split(",")

        # switch on the number of operands to be processed in this loop iteration.
        # every case here sets 'operand' and 'names'.
        if len(operand_indices) == 1:
            operand = einsum_operands.pop(operand_indices[0])
            (names,) = input_names
            counts = collections.Counter(names)

            # sum out unique contracted indices with a single reduce-sum
            uniques = [name for name in contracted_names if counts[name] == 1]
            operand, names = sum_uniques(operand, names, uniques)

            # for every repeated index, do a contraction against an identity matrix
            operand, names = sum_repeats(operand, names, counts, result_names)

        elif len(operand_indices) == 2:
            lhs, rhs = map(einsum_operands.pop, operand_indices)
            lhs_names, rhs_names = input_names

            # handle cases where one side of a contracting or batch dimension is 1
            # but its counterpart is not.
            # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
            #                                        rhs_names)
            # rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
            #                                        lhs_names)

            lhs_counts = collections.Counter(lhs_names)
            rhs_counts = collections.Counter(rhs_names)

            # sum out unique contracted indices in lhs and rhs
            lhs_uniques = [
                name
                for name in contracted_names
                if lhs_counts[name] == 1 and rhs_counts[name] == 0
            ]
            lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)

            rhs_uniques = [
                name
                for name in contracted_names
                if rhs_counts[name] == 1 and lhs_counts[name] == 0
            ]
            rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)

            # for every repeated index, contract against an identity matrix
            lhs, lhs_names = sum_repeats(
                lhs, lhs_names, lhs_counts, result_names + rhs_names
            )
            rhs, rhs_names = sum_repeats(
                rhs, rhs_names, rhs_counts, result_names + lhs_names
            )

            lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
            contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
            lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
            batch_names = [x for x in result_names if x in lhs_and_rhs_names]

            if batch_names:
                # A name shared by both operands that also survives in the
                # result is a batch dimension, which cannot be expressed yet.
                raise NotImplementedError("Batch dimensions are not yet supported")

            # contract using dot_general
            batch_names_str = "".join(batch_names)
            if contracted_names:
                lhs_cont, rhs_cont = tuple(
                    zip(
                        *[
                            (lhs_names.index(n), rhs_names.index(n))
                            for n in contracted_names
                        ]
                    )
                )
            else:
                lhs_cont = rhs_cont = ()
            deleted_names = batch_names_str + "".join(contracted_names)
            remaining_lhs_names = _removechars(lhs_names, deleted_names)
            remaining_rhs_names = _removechars(rhs_names, deleted_names)
            # Try both orders of lhs and rhs, in the hope that one of them means we
            # don't need an explicit transpose. opt_einsum likes to contract from
            # right to left, so we expect (rhs,lhs) to have the best chance of not
            # needing a transpose.
            names = batch_names_str + remaining_rhs_names + remaining_lhs_names
            if names == result_names:
                operand = tensordot(rhs, lhs, (rhs_cont, lhs_cont))
            else:
                names = batch_names_str + remaining_lhs_names + remaining_rhs_names
                operand = tensordot(lhs, rhs, axes=(lhs_cont, rhs_cont))

        else:
            # Without this branch 'operand' and 'names' would be unbound, or
            # silently reused from the previous iteration.
            raise NotImplementedError(
                f"Cannot contract {len(operand_indices)} operands in a single step; "
                "only one or two are supported"
            )

        # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
        assert len(names) == len(result_names) == len(set(names))
        assert set(names) == set(result_names)
        if names != result_names:
            perm = tuple(names.index(name) for name in result_names)
            operand = transpose(operand, perm)
        einsum_operands.append(operand)  # used in next iteration

    [einsum_result] = einsum_operands

    return Einsum(
        subscripts=subscripts,
        inputs=list(operands),
        outputs=[einsum_result],
    )(*operands)
0 commit comments