@@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 7e-2, 7e-2
+        if torch.version.hip:
+            atol *= 2
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(batch_size, dstate, device=device)
+    C = torch.randn(batch_size, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
+
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("tie_hdim", [False, True])
+@pytest.mark.parametrize("ngroups", [1, 2, 4])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+def test_selective_state_update_with_heads_with_batch_indices(
+        dim, dstate, ngroups, has_z, tie_hdim, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-1, 1e-1
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+    headdim = 64
+    nheads = dim // headdim
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries,
+                        nheads,
+                        headdim,
+                        dstate,
+                        dtype=itype,
+                        device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
+    if not tie_hdim:
+        dt = torch.randn(batch_size,
+                         nheads,
+                         headdim,
+                         device=device,
+                         dtype=itype)
+        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
+        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
+        D = torch.randn(nheads, headdim, device=device)
+    else:
+        dt = repeat(torch.randn(batch_size, nheads, device=device,
+                                dtype=itype),
+                    "b h -> b h p",
+                    p=headdim)
+        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
+                         "h -> h p",
+                         p=headdim)
+        A = repeat(-torch.rand(nheads, device=device) - 1.0,
+                   "h -> h p n",
+                   p=headdim,
+                   n=dstate)
+        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
+    B = torch.randn(batch_size, ngroups, dstate, device=device)
+    C = torch.randn(batch_size, ngroups, dstate, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
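For readers skimming the diff: the behavioral addition these tests exercise is the new `state_batch_indices` argument, which selects the rows of a larger state cache that the kernel reads and updates in place, so the call is expected to match running `selective_state_update_ref` on `state[state_indices]`. Below is a minimal sketch of that assumed semantics for the no-heads layout of the first test (state `(b, d, n)`, x/dt `(b, d)`, A `(d, n)`, B/C `(b, n)`); the function name `indexed_state_update_sketch` and every implementation detail are assumptions for illustration, not the project's kernel or its reference helper.

import torch
import torch.nn.functional as F
from einops import rearrange


def indexed_state_update_sketch(state, x, dt, A, B, C, D=None, z=None,
                                dt_bias=None, dt_softplus=False,
                                state_batch_indices=None):
    # Hypothetical reference semantics: gather the per-request rows of the
    # full state cache (the real kernel indexes in place on the GPU).
    idx = state_batch_indices.long() if state_batch_indices is not None else None
    s = state[idx].float() if idx is not None else state.float()
    dt = dt.float() + dt_bias if dt_bias is not None else dt.float()
    if dt_softplus:
        dt = F.softplus(dt)
    # Discretize (dA = exp(dt * A), dB = dt * B) and advance the state one step.
    dA = torch.exp(rearrange(dt, "b d -> b d 1") * A)                  # (b, d, n)
    dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n")  # (b, d, n)
    s_new = s * dA + dB * rearrange(x.float(), "b d -> b d 1")
    out = torch.einsum("bdn,bn->bd", s_new, C)
    if D is not None:
        out = out + D * x.float()
    if z is not None:
        out = out * F.silu(z.float())
    # Write the updated rows back into the cache so later steps see them.
    if idx is not None:
        state[idx] = s_new.to(state.dtype)
    else:
        state.copy_(s_new.to(state.dtype))
    return out.to(x.dtype)

The gather/scatter pair mirrors what the tests compare: `state_ref = state[state_indices, :].detach().clone()` is advanced by the reference path and then checked against the rows the batch-indexed call wrote back into the cache.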