@@ -129,6 +129,42 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR
129129 )
130130
131131
132+ def benchmark_qec_unsharded (
133+ args : argparse .Namespace , output_dir : str
134+ ) -> List [BenchmarkResult ]:
135+ tables = get_tables (TABLE_SIZES , is_pooled = False )
136+ sharder = TestQuantECSharder (
137+ sharding_type = "" ,
138+ kernel_type = EmbeddingComputeKernel .QUANT .value ,
139+ shardable_params = [table .name for table in tables ],
140+ )
141+
142+ module = QuantEmbeddingCollection (
143+ # pyre-ignore [6]
144+ tables = tables ,
145+ device = torch .device ("cpu" ),
146+ quant_state_dict_split_scale_bias = True ,
147+ )
148+
149+ args_kwargs = {
150+ argname : getattr (args , argname )
151+ for argname in dir (args )
152+ # Don't include output_dir since output_dir was modified
153+ if not argname .startswith ("_" ) and argname not in IGNORE_ARGNAME
154+ }
155+
156+ return benchmark_module (
157+ module = module ,
158+ sharder = sharder ,
159+ sharding_types = [],
160+ compile_modes = BENCH_COMPILE_MODES ,
161+ tables = tables ,
162+ output_dir = output_dir ,
163+ benchmark_unsharded = True , # benchmark unsharded module
164+ ** args_kwargs ,
165+ )
166+
167+
132168def benchmark_qebc_unsharded (
133169 args : argparse .Namespace , output_dir : str
134170) -> List [BenchmarkResult ]:
@@ -185,9 +221,10 @@ def main() -> None:
185221 "QuantEmbeddingCollection" ,
186222 ]
187223
188- # Only do unsharded QEBC benchmark when using CPU device
224+ # Only do unsharded QEBC/QEC benchmark when using CPU device
189225 if args .device_type == "cpu" :
190226 module_names .append ("unshardedQuantEmbeddingBagCollection" )
227+ module_names .append ("unshardedQuantEmbeddingCollection" )
191228
192229 for module_name in module_names :
193230 output_dir = args .output_dir + f"/run_{ datetime_sfx } "
@@ -197,9 +234,12 @@ def main() -> None:
197234 elif module_name == "QuantEmbeddingCollection" :
198235 output_dir += "_qec"
199236 benchmark_func = benchmark_qec
200- else :
237+ elif module_name == "unshardedQuantEmbeddingBagCollection" :
201238 output_dir += "_uqebc"
202239 benchmark_func = benchmark_qebc_unsharded
240+ else :
241+ output_dir += "_uqec"
242+ benchmark_func = benchmark_qec_unsharded
203243
204244 if not os .path .exists (output_dir ):
205245 # Place all outputs under the datetime folder
0 commit comments