Closed
Description
Issue Type
Usability
Modules Involved
SPU runtime
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
spu 0.9.0b1
OS Platform and Distribution
Linux
Python Version
3.10
Compiler Version
No response
Current Behavior?
Why plaintext calculation is ok while SPU simulation failed? I don't think I'm performing any sophisticated calculation.
Standalone code to reproduce the issue
A = np.array([[-4.16757847e-01, -5.62668272e-02, -2.13619610e+00,
1.64027081e+00, -1.79343559e+00, -8.41747366e-01,
5.02881417e-01, -1.24528809e+00, -1.05795222e+00,
-9.09007615e-01],
[ 5.51454045e-01, 2.29220801e+00, 4.15393930e-02,
-1.11792545e+00, 5.39058321e-01, -5.96159700e-01,
-1.91304965e-02, 1.17500122e+00, -7.47870949e-01,
9.02525097e-03],
[-8.78107893e-01, -1.56434170e-01, 2.56570452e-01,
-9.88779049e-01, -3.38821966e-01, -2.36184031e-01,
-6.37655012e-01, -1.18761229e+00, -1.42121723e+00,
-1.53495196e-01],
[-2.69056960e-01, 2.23136679e+00, -2.43476758e+00,
1.12726505e-01, 3.70444537e-01, 1.35963386e+00,
5.01857207e-01, -8.44213704e-01, 9.76147160e-06,
5.42352572e-01]])
B = np.array([[-0.3135082 , 0.77101174, -1.86809065, 1.73118467, 1.46767801,
-0.33567734, 0.61134078, 0.04797059, -0.82913529, 0.08771022],
[ 1.00036589, -0.38109252, -0.37566942, -0.07447076, 0.43349633,
1.27837923, -0.63467931, 0.50839624, 0.21611601, -1.85861239],
[-0.41931648, -0.1323289 , -0.03957024, 0.32600343, -2.04032305,
0.04625552, -0.67767558, -1.43943903, 0.52429643, 0.73527958],
[-0.65325027, 0.84245628, -0.38151648, 0.06648901, -1.09873895,
1.58448706, -2.65944946, -0.09145262, 0.69511961, -2.03346655]])
C = np.array([[-0.18946926, -0.07721867, 0.82470301, 1.24821292, -0.40389227,
-1.38451867, 1.36723542, 1.21788563, -0.46200535, 0.35088849],
[ 0.38186623, 0.56627544, 0.20420798, 1.40669624, -1.7379595 ,
1.04082395, 0.38047197, -0.21713527, 1.1735315 , -2.34360319],
[ 1.16152149, 0.38607805, -1.13313327, 0.43309255, -0.30408644,
2.58529487, 1.83533272, 0.44068987, -0.71925384, -0.58341459],
[-0.32504963, -0.56023451, -0.90224607, -0.59097228, -0.27617949,
-0.51688389, -0.69858995, -0.92889192, 2.55043824, -1.47317325]])
D = np.array([[-1.02141473, 0.4323957 , -0.32358007, 0.42382471, 0.79918 ,
1.26261366, 0.75196485, -0.99376098, 1.10914328, -1.76491773],
[-0.1144213 , -0.49817419, -1.06079904, 0.59166652, -0.18325657,
1.01985473, -1.48246548, 0.84631189, 0.49794015, 0.12650418],
[-1.41881055, -0.25177412, -1.54667461, -2.08265194, 3.2797454 ,
0.97086132, 1.79259285, -0.42901332, 0.69619798, 0.69741627],
[ 0.60151581, 0.00365949, -0.22824756, -2.06961226, 0.61014409,
0.4234969 , 1.11788673, -0.27424209, 1.74181219, -0.44750088]])
A_1 = np.square(A)
B_1 = B + 0.1
C_1 = np.square(C)
D_1 = D + 0.1
s_0, s_1 = A_1.shape
I = jnp.tile(jnp.arange(s_1), (s_0,1))
Z = jnp.zeros(A_1.shape, dtype=int)
def my_compare(x1, x2):
A = x1[3]*x2[3]*(x1[0]*x2[1]-x2[0]*x1[1])+x1[1]*x2[1]*(x1[2]*x2[3]-x2[2]*x1[3])
B = x1[1]*x1[3]*x2[1]*x2[3]
A_sign = A > 0
B_sign = B > 0
z = jnp.logical_xor(A_sign, B_sign)
zz_0 = jax.lax.select(z, x2[0], x1[0])
zz_1 = jax.lax.select(z, x2[1], x1[1])
zz_2 = jax.lax.select(z, x2[2], x1[2])
zz_3 = jax.lax.select(z, x2[3], x1[3])
zz_4 = jax.lax.select(z, x2[4], x1[4])
zz_5 = jax.lax.select(z, x2[5], x1[5])
return [zz_0,zz_1,zz_2,zz_3,zz_4,zz_5]
fn = lambda a,b,c,d,e,f: jax.lax.reduce([a,b,c,d,e,f], [0.1,10000.0,0.1,10000.0,0,0], my_compare, [0])
### plaintext calculation
res = fn(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])
res
### SPU simulation
config = spu.RuntimeConfig(
protocol=spu.spu_pb2.ProtocolKind.CHEETAH,
field=spu.spu_pb2.FieldType.FM128,
fxp_fraction_bits=40,
)
simulator = pps.Simulator(2, config)
spu_argmax = pps.sim_jax(simulator, fn)
res = spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])
res
Relevant log output
### plaintext calculation
[Array(1.5507424, dtype=float32),
Array(0.14797059, dtype=float32),
Array(1.4832454, dtype=float32),
Array(-0.893761, dtype=float32),
Array(7, dtype=int32),
Array(0, dtype=int32)]
### SPU simulation
RuntimeError Traceback (most recent call last)
Cell In[49], line 13
9 simulator = pps.Simulator(2, config)
10 spu_argmax = pps.sim_jax(simulator, fn)
---> 13 z = spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])
15 z
File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:168, in sim_jax.<locals>.wrapper(*args, **kwargs)
154 executable, output = spu_fe.compile(
155 spu_fe.Kind.JAX,
156 fun,
(...)
163 copts=copts,
164 )
166 wrapper.pphlo = executable.code.decode("utf-8")
--> 168 out_flat = sim(executable, *args_flat)
170 _, output_tree = jax.tree_util.tree_flatten(output)
172 return jax.tree_util.tree_unflatten(output_tree, out_flat)
File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:116, in Simulator.__call__(self, executable, *flat_args)
110 jobs = [
111 PropagatingThread(target=wrapper, args=(rank,))
112 for rank in range(self.wsize)
113 ]
115 [job.start() for job in jobs]
--> 116 parties = [job.join() for job in jobs]
118 outputs = zip(*parties)
119 return [self.io.reconstruct(out) for out in outputs]
File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:116, in <listcomp>(.0)
110 jobs = [
111 PropagatingThread(target=wrapper, args=(rank,))
112 for rank in range(self.wsize)
113 ]
115 [job.start() for job in jobs]
--> 116 parties = [job.join() for job in jobs]
118 outputs = zip(*parties)
119 return [self.io.reconstruct(out) for out in outputs]
File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:43, in PropagatingThread.join(self)
41 super(PropagatingThread, self).join()
42 if self.exc:
---> 43 raise self.exc
44 return self.ret
File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:36, in PropagatingThread.run(self)
34 self.exc = None
35 try:
---> 36 self.ret = self._target(*self._args, **self._kwargs)
37 except BaseException as e:
38 self.exc = e
File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:105, in Simulator.__call__.<locals>.wrapper(rank)
102 rt.set_var(executable.input_names[idx], param[rank])
104 # run
--> 105 rt.run(executable)
107 # do outfeed
108 return [rt.get_var(name) for name in executable.output_names]
File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/api.py:44, in Runtime.run(self, executable)
37 def run(self, executable: spu_pb2.ExecutableProto) -> None:
38 """Run an SPU executable.
39
40 Args:
41 executable (spu_pb2.ExecutableProto): executable.
42
43 """
---> 44 return self._vm.Run(executable.SerializeToString())
RuntimeError: what:
[Enforce fail at libspu/kernel/hal/polymorphic.cc:195] (x.shape() == y.shape()).
Stacktrace:
#0 spu::kernel::hlo::Greater()+0x7fe3b6c1c37b
#1 spu::device::pphlo::dispatchOp<>()+0x7fe3b6532660
#2 spu::device::pphlo::dispatchOp<>()+0x7fe3b653379a
#3 spu::device::pphlo::dispatchOp<>()+0x7fe3b6536b45
#4 spu::device::pphlo::dispatchOp<>()+0x7fe3b6537496
#5 spu::device::pphlo::dispatchOp<>()+0x7fe3b65392ed
#6 spu::device::pphlo::dispatchOp<>()+0x7fe3b653b8f4
#7 spu::device::runBlock()+0x7fe3b6693c25
#8 spu::device::runRegion()+0x7fe3b6695cb3
#9 std::_Function_handler<>::_M_invoke()+0x7fe3b651d57d
#10 spu::kernel::hlo::TreeReduce()+0x7fe3b6c400a3
#11 spu::kernel::hlo::Reduce()+0x7fe3b6c422dd
Metadata
Metadata
Assignees
Labels
No labels