@@ -35,12 +35,12 @@ class Einsum(OpFromGraph):
35
35
Wrapper Op for Einsum graphs
36
36
"""
37
37
38
- __props__ = ("subscripts" , "path" , "optimized " )
38
+ __props__ = ("subscripts" , "path" , "optimize " )
39
39
40
- def __init__ (self , * args , subscripts : str , path : str , optimized : bool , ** kwargs ):
40
+ def __init__ (self , * args , subscripts : str , path : str , optimize : bool , ** kwargs ):
41
41
self .subscripts = subscripts
42
42
self .path = path
43
- self .optimized = optimized
43
+ self .optimize = optimize
44
44
super ().__init__ (* args , ** kwargs , strict = True )
45
45
46
46
@@ -223,7 +223,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
223
223
shapes = [operand .type .shape for operand in operands ]
224
224
225
225
if None in itertools .chain .from_iterable (shapes ):
226
- # We mark optimized = False, even in cases where there is no ordering optimization to be done
226
+ # We mark optimize = False, even in cases where there is no ordering optimization to be done
227
227
# because the inner graph may have to accommodate dynamic shapes.
228
228
# If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
229
229
if len (operands ) == 1 :
@@ -234,7 +234,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
234
234
# We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
235
235
path = [(1 , 0 ) for i in range (len (operands ) - 1 )]
236
236
contraction_list = contraction_list_from_path (subscripts , operands , path )
237
- optimized = (
237
+ optimize = (
238
238
len (operands ) <= 2
239
239
) # If there are only 1 or 2 operands, there is no optimization to be done?
240
240
else :
@@ -247,7 +247,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
247
247
shapes = True ,
248
248
)
249
249
path = [contraction [0 ] for contraction in contraction_list ]
250
- optimized = True
250
+ optimize = True
251
251
252
252
def sum_uniques (
253
253
operand : TensorVariable , names : str , uniques : list [str ]
@@ -412,6 +412,6 @@ def sum_repeats(
412
412
inputs = list (operands ),
413
413
outputs = [einsum_result ],
414
414
path = tuple (path ),
415
- optimized = optimized ,
415
+ optimize = optimize ,
416
416
)(* operands )
417
417
return cast (TensorVariable , out )
0 commit comments