
[Bug]: unexpected SPU error #766

Closed
@linzzzzzz

Description


Issue Type

Usability

Modules Involved

SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0b1

OS Platform and Distribution

Linux

Python Version

3.10

Compiler Version

No response

Current Behavior?

Why does the plaintext calculation work while the SPU simulation fails? I don't think I'm performing any sophisticated calculation.

Standalone code to reproduce the issue

# imports needed to run this snippet standalone
import numpy as np
import jax
import jax.numpy as jnp
import spu
import spu.utils.simulation as pps

A = np.array([[-4.16757847e-01, -5.62668272e-02, -2.13619610e+00,
         1.64027081e+00, -1.79343559e+00, -8.41747366e-01,
         5.02881417e-01, -1.24528809e+00, -1.05795222e+00,
        -9.09007615e-01],
       [ 5.51454045e-01,  2.29220801e+00,  4.15393930e-02,
        -1.11792545e+00,  5.39058321e-01, -5.96159700e-01,
        -1.91304965e-02,  1.17500122e+00, -7.47870949e-01,
         9.02525097e-03],
       [-8.78107893e-01, -1.56434170e-01,  2.56570452e-01,
        -9.88779049e-01, -3.38821966e-01, -2.36184031e-01,
        -6.37655012e-01, -1.18761229e+00, -1.42121723e+00,
        -1.53495196e-01],
       [-2.69056960e-01,  2.23136679e+00, -2.43476758e+00,
         1.12726505e-01,  3.70444537e-01,  1.35963386e+00,
         5.01857207e-01, -8.44213704e-01,  9.76147160e-06,
         5.42352572e-01]])

B = np.array([[-0.3135082 ,  0.77101174, -1.86809065,  1.73118467,  1.46767801,
        -0.33567734,  0.61134078,  0.04797059, -0.82913529,  0.08771022],
       [ 1.00036589, -0.38109252, -0.37566942, -0.07447076,  0.43349633,
         1.27837923, -0.63467931,  0.50839624,  0.21611601, -1.85861239],
       [-0.41931648, -0.1323289 , -0.03957024,  0.32600343, -2.04032305,
         0.04625552, -0.67767558, -1.43943903,  0.52429643,  0.73527958],
       [-0.65325027,  0.84245628, -0.38151648,  0.06648901, -1.09873895,
         1.58448706, -2.65944946, -0.09145262,  0.69511961, -2.03346655]])

C = np.array([[-0.18946926, -0.07721867,  0.82470301,  1.24821292, -0.40389227,
        -1.38451867,  1.36723542,  1.21788563, -0.46200535,  0.35088849],
       [ 0.38186623,  0.56627544,  0.20420798,  1.40669624, -1.7379595 ,
         1.04082395,  0.38047197, -0.21713527,  1.1735315 , -2.34360319],
       [ 1.16152149,  0.38607805, -1.13313327,  0.43309255, -0.30408644,
         2.58529487,  1.83533272,  0.44068987, -0.71925384, -0.58341459],
       [-0.32504963, -0.56023451, -0.90224607, -0.59097228, -0.27617949,
        -0.51688389, -0.69858995, -0.92889192,  2.55043824, -1.47317325]])

D = np.array([[-1.02141473,  0.4323957 , -0.32358007,  0.42382471,  0.79918   ,
         1.26261366,  0.75196485, -0.99376098,  1.10914328, -1.76491773],
       [-0.1144213 , -0.49817419, -1.06079904,  0.59166652, -0.18325657,
         1.01985473, -1.48246548,  0.84631189,  0.49794015,  0.12650418],
       [-1.41881055, -0.25177412, -1.54667461, -2.08265194,  3.2797454 ,
         0.97086132,  1.79259285, -0.42901332,  0.69619798,  0.69741627],
       [ 0.60151581,  0.00365949, -0.22824756, -2.06961226,  0.61014409,
         0.4234969 ,  1.11788673, -0.27424209,  1.74181219, -0.44750088]])

A_1 = np.square(A)
B_1 = B + 0.1
C_1 = np.square(C)
D_1 = D + 0.1

s_0, s_1 = A_1.shape
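# per-row column indices (I) and an all-zeros integer array (Z),
# carried through the reduce alongside the data rows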
I = jnp.tile(jnp.arange(s_1), (s_0,1))
Z = jnp.zeros(A_1.shape, dtype=int)


def my_compare(x1, x2):
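    # x1 and x2 are 6-element lists; A and B form a cross-multiplied sign test,
    # and z decides whether x2 replaces x1 as the running result of the reduce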

    A = x1[3]*x2[3]*(x1[0]*x2[1]-x2[0]*x1[1])+x1[1]*x2[1]*(x1[2]*x2[3]-x2[2]*x1[3])
    B = x1[1]*x1[3]*x2[1]*x2[3]
    A_sign = A > 0
    B_sign = B > 0
    z = jnp.logical_xor(A_sign, B_sign)

    zz_0 = jax.lax.select(z, x2[0], x1[0])
    zz_1 = jax.lax.select(z, x2[1], x1[1])
    zz_2 = jax.lax.select(z, x2[2], x1[2])
    zz_3 = jax.lax.select(z, x2[3], x1[3])
    zz_4 = jax.lax.select(z, x2[4], x1[4])
    zz_5 = jax.lax.select(z, x2[5], x1[5])

    return [zz_0,zz_1,zz_2,zz_3,zz_4,zz_5]


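# variadic reduce over axis 0: the six 1-D operands are reduced jointly with
# my_compare, with one scalar init value per operand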
fn = lambda a,b,c,d,e,f: jax.lax.reduce([a,b,c,d,e,f], [0.1,10000.0,0.1,10000.0,0,0], my_compare, [0])


### plaintext calculation
res = fn(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])

res


### SPU simulation

config = spu.RuntimeConfig(
    protocol=spu.spu_pb2.ProtocolKind.CHEETAH,
    field=spu.spu_pb2.FieldType.FM128, 
    fxp_fraction_bits=40,
)

simulator = pps.Simulator(2, config)
spu_argmax = pps.sim_jax(simulator, fn)


res = spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])

res

Relevant log output

### plaintext calculation
[Array(1.5507424, dtype=float32),
 Array(0.14797059, dtype=float32),
 Array(1.4832454, dtype=float32),
 Array(-0.893761, dtype=float32),
 Array(7, dtype=int32),
 Array(0, dtype=int32)]



### SPU simulation
RuntimeError                              Traceback (most recent call last)
Cell In[49], line 13
      9 simulator = pps.Simulator(2, config)
     10 spu_argmax = pps.sim_jax(simulator, fn)
---> 13 z = spu_argmax(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])
     15 z

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:168, in sim_jax.<locals>.wrapper(*args, **kwargs)
    154 executable, output = spu_fe.compile(
    155     spu_fe.Kind.JAX,
    156     fun,
   (...)
    163     copts=copts,
    164 )
    166 wrapper.pphlo = executable.code.decode("utf-8")
--> 168 out_flat = sim(executable, *args_flat)
    170 _, output_tree = jax.tree_util.tree_flatten(output)
    172 return jax.tree_util.tree_unflatten(output_tree, out_flat)

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:116, in Simulator.__call__(self, executable, *flat_args)
    110 jobs = [
    111     PropagatingThread(target=wrapper, args=(rank,))
    112     for rank in range(self.wsize)
    113 ]
    115 [job.start() for job in jobs]
--> 116 parties = [job.join() for job in jobs]
    118 outputs = zip(*parties)
    119 return [self.io.reconstruct(out) for out in outputs]

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:116, in <listcomp>(.0)
    110 jobs = [
    111     PropagatingThread(target=wrapper, args=(rank,))
    112     for rank in range(self.wsize)
    113 ]
    115 [job.start() for job in jobs]
--> 116 parties = [job.join() for job in jobs]
    118 outputs = zip(*parties)
    119 return [self.io.reconstruct(out) for out in outputs]

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:43, in PropagatingThread.join(self)
     41 super(PropagatingThread, self).join()
     42 if self.exc:
---> 43     raise self.exc
     44 return self.ret

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:36, in PropagatingThread.run(self)
     34 self.exc = None
     35 try:
---> 36     self.ret = self._target(*self._args, **self._kwargs)
     37 except BaseException as e:
     38     self.exc = e

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/utils/simulation.py:105, in Simulator.__call__.<locals>.wrapper(rank)
    102     rt.set_var(executable.input_names[idx], param[rank])
    104 # run
--> 105 rt.run(executable)
    107 # do outfeed
    108 return [rt.get_var(name) for name in executable.output_names]

File ~/.conda/envs/sf310/lib/python3.10/site-packages/spu/api.py:44, in Runtime.run(self, executable)
     37 def run(self, executable: spu_pb2.ExecutableProto) -> None:
     38     """Run an SPU executable.
     39 
     40     Args:
     41         executable (spu_pb2.ExecutableProto): executable.
     42 
     43     """
---> 44     return self._vm.Run(executable.SerializeToString())

RuntimeError: what: 
	[Enforce fail at libspu/kernel/hal/polymorphic.cc:195] (x.shape() == y.shape()). 
Stacktrace:
#0 spu::kernel::hlo::Greater()+0x7fe3b6c1c37b
#1 spu::device::pphlo::dispatchOp<>()+0x7fe3b6532660
#2 spu::device::pphlo::dispatchOp<>()+0x7fe3b653379a
#3 spu::device::pphlo::dispatchOp<>()+0x7fe3b6536b45
#4 spu::device::pphlo::dispatchOp<>()+0x7fe3b6537496
#5 spu::device::pphlo::dispatchOp<>()+0x7fe3b65392ed
#6 spu::device::pphlo::dispatchOp<>()+0x7fe3b653b8f4
#7 spu::device::runBlock()+0x7fe3b6693c25
#8 spu::device::runRegion()+0x7fe3b6695cb3
#9 std::_Function_handler<>::_M_invoke()+0x7fe3b651d57d
#10 spu::kernel::hlo::TreeReduce()+0x7fe3b6c400a3
#11 spu::kernel::hlo::Reduce()+0x7fe3b6c422dd
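
For reference, a possible workaround sketch (untested, and only my guess at the cause): the stacktrace ends in spu::kernel::hlo::TreeReduce, so expressing the same reduction as an explicit sequential fold of my_compare would sidestep the Reduce lowering entirely. The names fn_fold and spu_fold below are just for illustration; I have not verified that the fold matches fn under SPU.

def fn_fold(a, b, c, d, e, f):
    # same init values as fn, promoted to JAX scalars so each slot's dtype
    # matches the per-element slices handed to my_compare
    acc = [jnp.asarray(0.1), jnp.asarray(10000.0), jnp.asarray(0.1),
           jnp.asarray(10000.0), jnp.asarray(0), jnp.asarray(0)]
    for i in range(a.shape[0]):
        acc = my_compare(acc, [a[i], b[i], c[i], d[i], e[i], f[i]])
    return acc

spu_fold = pps.sim_jax(simulator, fn_fold)
res = spu_fold(A_1[0], B_1[0], C_1[0], D_1[0], I[0], Z[0])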
