@@ -385,6 +385,32 @@ def test_idx_lambda_to_hlo():
385385 (b , a ))
386386
387387
388+ def test_stringify ():
389+ x = pt .make_placeholder ("x" , (10 , 4 ), np .int64 )
390+ y = pt .make_placeholder ("y" , (10 , 4 ), np .int64 )
391+
392+ assert (str (3 * x + 4 * y )
393+ == "3*x + 4*y" )
394+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
395+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
396+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
397+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
398+ assert (str (y * pt .not_equal (x , 3 ))
399+ == "y*(x != 3)" )
400+ assert (str (3 * y @ pt .sum (x , axis = 0 ))
401+ == "3*y @ sum(x, axis=0)" )
402+ assert (str (x [y [:, 2 :3 ], x [2 , :]])
403+ == "x[y[::, 2:3:], x[2]]" )
404+ assert (str (pt .stack ([x [y [:, 2 :3 ], x [2 , :]].T , y [x [:, 2 :3 ], y [2 , :]].T ]))
405+ == ("stack([transpose(x[y[::, 2:3:], x[2]]),"
406+ " transpose(y[x[::, 2:3:], y[2]])])" ))
407+ assert (str (pt .concatenate ([x [y [:, 2 :3 ], x [2 , :]],
408+ y [x [:, 2 :3 ], y [2 , :]]]))
409+ == "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])" )
410+ assert (str (pt .einsum ("ij,i->i" , 2 * x , pt .sum (y , axis = 1 )))
411+ == 'einsum("ij, i -> i", 2*x, sum(y, axis=1))' )
412+
413+
388414if __name__ == "__main__" :
389415 if len (sys .argv ) > 1 :
390416 exec (sys .argv [1 ])
0 commit comments