@@ -53,18 +53,18 @@ def g(*args3):
53
53
54
54
def broadcasting_map (func , * args ):
55
55
"""Like tree_map, but scalar arguments are broadcast to all leaves."""
56
- static_argnums = [i for i , x in enumerate (args ) if not isinstance (x , Vector )]
56
+ static_argnums = [i for i , x in enumerate (args ) if not isinstance (x , VectorMixin )]
57
57
func2 , vector_args = _argnums_partial (func , args , static_argnums )
58
58
for arg in args :
59
- if not isinstance (arg , Vector ):
59
+ if not isinstance (arg , VectorMixin ):
60
60
shape = jnp .shape (arg )
61
61
if shape :
62
62
raise TypeError (
63
- f"non-tree_math.Vector argument is not a scalar: { arg !r} "
63
+ f"non-tree_math.VectorMixin argument is not a scalar: { arg !r} "
64
64
)
65
65
if not vector_args :
66
66
return func2 () # result is a scalar
67
- _flatten_together (* [arg . tree for arg in vector_args ]) # check shapes
67
+ _flatten_together (* [arg for arg in vector_args ]) # check shapes
68
68
return tree_util .tree_map (func2 , * vector_args )
69
69
70
70
@@ -112,43 +112,22 @@ def dot(left, right, *, precision="highest"):
112
112
Returns:
113
113
Resulting dot product (scalar).
114
114
"""
115
- if not isinstance (left , Vector ) or not isinstance (right , Vector ):
116
- raise TypeError ("matmul arguments must both be tree_math.Vector objects" )
115
+ if not isinstance (left , VectorMixin ) or not isinstance (right , VectorMixin ):
116
+ raise TypeError ("matmul arguments must both be tree_math.VectorMixin objects" )
117
117
118
118
def _vector_dot (a , b ):
119
119
return jnp .dot (jnp .ravel (a ), jnp .ravel (b ), precision = precision )
120
120
121
- (left_values , right_values ), _ = _flatten_together (left . tree , right . tree )
121
+ (left_values , right_values ), _ = _flatten_together (left , right )
122
122
parts = map (_vector_dot , left_values , right_values )
123
123
return functools .reduce (operator .add , parts )
124
124
125
-
126
- @tree_util .register_pytree_node_class
127
- class Vector :
128
- """A wrapper for treating an arbitrary pytree as a 1D vector."""
129
-
130
- def __init__ (self , tree ):
131
- self ._tree = tree
132
-
133
- @property
134
- def tree (self ):
135
- return self ._tree
136
-
137
- # TODO(shoyer): consider casting to a common dtype?
138
-
139
- def __repr__ (self ):
140
- return f"tree_math.Vector({ self ._tree !r} )"
141
-
142
- def tree_flatten (self ):
143
- return (self .tree ,), None
144
-
145
- @classmethod
146
- def tree_unflatten (cls , _ , args ):
147
- return cls (* args )
125
+ class VectorMixin :
126
+ """A mixin class that adds a 1D vector-like behaviour to any custom pytree class."""
148
127
149
128
@property
150
129
def size (self ):
151
- values = tree_util .tree_leaves (self . tree )
130
+ values = tree_util .tree_leaves (self )
152
131
return sum (jnp .size (value ) for value in values )
153
132
154
133
def __len__ (self ):
@@ -164,7 +143,7 @@ def ndim(self):
164
143
165
144
@property
166
145
def dtype (self ):
167
- values = tree_util .tree_leaves (self . tree )
146
+ values = tree_util .tree_leaves (self )
168
147
return jnp .result_type (* values )
169
148
170
149
# comparison
@@ -225,3 +204,28 @@ def min(self):
225
204
def max (self ):
226
205
parts = map (jnp .max , tree_util .tree_leaves (self ))
227
206
return jnp .asarray (list (parts )).max ()
207
+
208
+ @tree_util .register_pytree_node_class
209
+ class Vector (VectorMixin ):
210
+ """A wrapper for treating an arbitrary pytree as a 1D vector."""
211
+
212
+ def __init__ (self , tree ):
213
+ self ._tree = tree
214
+
215
+ @property
216
+ def tree (self ):
217
+ return self ._tree
218
+
219
+ # TODO(shoyer): consider casting to a common dtype?
220
+
221
+ def __repr__ (self ):
222
+ return f"tree_math.Vector({ self ._tree !r} )"
223
+
224
+ def tree_flatten (self ):
225
+ return (self ._tree ,), None
226
+
227
+ @classmethod
228
+ def tree_unflatten (cls , _ , args ):
229
+ return cls (* args )
230
+
231
+
0 commit comments