@@ -323,3 +323,160 @@ def test_selective_state_update(dim, dstate, has_z, itype):
 
     assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 7e-2, 7e-2
+    if torch.version.hip:
+        atol *= 2
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
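+    # The state tensor is a pool with more entries than the batch; each
+    # sequence picks its own slot through a random permutation of pool
+    # indices, mimicking how a continuously batched serving engine indexes
+    # per-request SSM state.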
+
+    x = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(batch_size, dstate, device=device)
+    C = torch.randn(batch_size, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
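+    # The kernel reads the full pool and scatters its updates in place into
+    # the rows selected by state_batch_indices.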
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
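+    # The reference runs on a gathered copy, so comparing state[state_indices]
+    # with state_ref below checks both the update math and the indexed scatter.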
+
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("tie_hdim", [False, True])
+@pytest.mark.parametrize("ngroups", [1, 2, 4])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+def test_selective_state_update_with_heads_with_batch_indices(
+        dim, dstate, ngroups, has_z, tie_hdim, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-1, 1e-1
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+    headdim = 64
+    nheads = dim // headdim
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries,
+                        nheads,
+                        headdim,
+                        dstate,
+                        dtype=itype,
+                        device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
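+    # tie_hdim=True ties dt, dt_bias, A and D across the head dimension: each
+    # becomes a per-head scalar broadcast over headdim (and dstate for A) via
+    # einops.repeat, matching a Mamba-2-style per-head parameter layout.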
+    if not tie_hdim:
+        dt = torch.randn(batch_size,
+                         nheads,
+                         headdim,
+                         device=device,
+                         dtype=itype)
+        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
+        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
+        D = torch.randn(nheads, headdim, device=device)
+    else:
+        dt = repeat(torch.randn(batch_size, nheads, device=device,
+                                dtype=itype),
+                    "b h -> b h p",
+                    p=headdim)
+        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
+                         "h -> h p",
+                         p=headdim)
+        A = repeat(-torch.rand(nheads, device=device) - 1.0,
+                   "h -> h p n",
+                   p=headdim,
+                   n=dstate)
+        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
+    B = torch.randn(batch_size, ngroups, dstate, device=device)
+    C = torch.randn(batch_size, ngroups, dstate, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state,
+                                 x,
+                                 dt,
+                                 A,
+                                 B,
+                                 C,
+                                 D=D,
+                                 z=z,
+                                 dt_bias=dt_bias,
+                                 dt_softplus=True,
+                                 state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref,
+                                         x,
+                                         dt,
+                                         A,
+                                         B,
+                                         C,
+                                         D=D,
+                                         z=z,
+                                         dt_bias=dt_bias,
+                                         dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :],
+                          state_ref,
+                          rtol=rtol,
+                          atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
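
For reference, the recurrence these tests pin down can be sketched in a few lines of PyTorch. This is an illustrative reconstruction of the standard Mamba single-token state update for the non-head layout, not the actual selective_state_update_ref imported by the test file; the function name and the in-place update convention here are assumptions.

import torch
import torch.nn.functional as F

def selective_state_update_sketch(state, x, dt, A, B, C, D=None, z=None,
                                  dt_bias=None, dt_softplus=False):
    # state: (batch, dim, dstate), updated in place
    # x, dt: (batch, dim); A: (dim, dstate); B, C: (batch, dstate); D: (dim,)
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dt = dt.float()
    # Discretize: state' = exp(dt * A) * state + (dt * B) * x
    dA = torch.exp(dt.unsqueeze(-1) * A)            # (batch, dim, dstate)
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)          # (batch, dim, dstate)
    new_state = state.float() * dA + dB * x.float().unsqueeze(-1)
    state.copy_(new_state.to(state.dtype))
    # Readout: y = C . state + D * x, optionally gated by silu(z)
    out = torch.einsum("bdn,bn->bd", new_state, C.float())
    if D is not None:
        out = out + D * x.float()
    if z is not None:
        out = out * F.silu(z.float())
    return out.to(x.dtype)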