@@ -57,12 +57,72 @@ def tearDown(self):
5757 shutil .rmtree (self .temp_dir )
5858
5959 def test_get_shapes_for_config (self ):
60+ # Test custom shapes
6061 shapes = get_shapes_for_config (
6162 self .test_config ["model_params" ][0 ]["matrix_shapes" ]
6263 )
6364 self .assertEqual (len (shapes ), 1 )
6465 self .assertEqual (shapes [0 ], ("custom" , [1024 , 1024 , 1024 ]))
6566
67+ # Test llama shapes
68+ llama_shapes = get_shapes_for_config ([{"name" : "llama" }])
69+ self .assertEqual (len (llama_shapes ), 4 ) # 4 LLaMa shapes
70+ self .assertTrue (
71+ any (name .startswith ("llama_attn.wqkv" ) for name , _ in llama_shapes )
72+ )
73+ self .assertTrue (
74+ any (name .startswith ("llama_attn.w0" ) for name , _ in llama_shapes )
75+ )
76+ self .assertTrue (
77+ any (name .startswith ("llama_ffn.w13" ) for name , _ in llama_shapes )
78+ )
79+ self .assertTrue (
80+ any (name .startswith ("llama_ffn.w2" ) for name , _ in llama_shapes )
81+ )
82+
83+ # Test pow2 shapes
84+ pow2_shapes = get_shapes_for_config (
85+ [{"name" : "pow2" , "min_power" : 10 , "max_power" : 12 }]
86+ )
87+ self .assertEqual (len (pow2_shapes ), 3 ) # 3 powers of 2 (10, 11, 12)
88+ self .assertEqual (pow2_shapes [0 ], ("pow2_0" , [1024 , 1024 , 1024 ])) # 2^10
89+ self .assertEqual (pow2_shapes [1 ], ("pow2_1" , [2048 , 2048 , 2048 ])) # 2^11
90+ self .assertEqual (pow2_shapes [2 ], ("pow2_2" , [4096 , 4096 , 4096 ])) # 2^12
91+
92+ # Test pow2_extended shapes
93+ pow2_extended_shapes = get_shapes_for_config (
94+ [{"name" : "pow2_extended" , "min_power" : 10 , "max_power" : 11 }]
95+ )
96+ self .assertEqual (
97+ len (pow2_extended_shapes ), 4
98+ ) # 2 powers of 2, each with 2 variants
99+ self .assertEqual (
100+ pow2_extended_shapes [0 ], ("pow2_extended_0" , [1024 , 1024 , 1024 ])
101+ ) # 2^10
102+ self .assertEqual (
103+ pow2_extended_shapes [1 ], ("pow2_extended_1" , [1536 , 1536 , 1536 ])
104+ ) # 2^10 + 2^9
105+ self .assertEqual (
106+ pow2_extended_shapes [2 ], ("pow2_extended_2" , [2048 , 2048 , 2048 ])
107+ ) # 2^11
108+ self .assertEqual (
109+ pow2_extended_shapes [3 ], ("pow2_extended_3" , [3072 , 3072 , 3072 ])
110+ ) # 2^11 + 2^10
111+
112+ # Test sweep shapes (limited to a small range for testing)
113+ sweep_shapes = get_shapes_for_config (
114+ [{"name" : "sweep" , "min_power" : 8 , "max_power" : 9 }]
115+ )
116+ # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
117+ self .assertEqual (len (sweep_shapes ), 8 )
118+ # Check that all shapes have the expected format
119+ for name , shape in sweep_shapes :
120+ self .assertTrue (name .startswith ("sweep_" ))
121+ self .assertEqual (len (shape ), 3 ) # [M, K, N]
122+ # Check that all dimensions are powers of 2 between 2^8 and 2^9
123+ for dim in shape :
124+ self .assertTrue (dim in [256 , 512 ]) # 2^8, 2^9
125+
66126 def test_get_param_combinations (self ):
67127 model_param = self .test_config ["model_params" ][0 ]
68128 shapes , params = get_param_combinations (model_param )
0 commit comments