File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -345,13 +345,17 @@ def forward(
345345
346346 if self .use_output :
347347 output_shape = output_shape if output_shape is not None else query .shape
348- output = torch .zeros (output_shape , dtype = output_dtype , device = query .device )
349348 hidden_size = output_shape [- 1 ]
349+
350+ # Use torch.empty to avoid initializing tensor with zero.
351+ output_numel = output_shape .numel ()
352+ output_shape = (output_numel // (self .num_heads * self .head_size ), self .num_heads , self .head_size )
353+ output = torch .empty (output_shape , dtype = output_dtype , device = query .device )
354+
350355 # Reshape the query, key, and value tensors.
351356 # NOTE(woosuk): We do this outside the custom op to minimize the
352357 # CPU overheads from the non-CUDA-graph regions.
353358 query = query .view (- 1 , self .num_heads , self .head_size )
354- output = output .view (- 1 , self .num_heads , self .head_size )
355359 if key is not None :
356360 key = key .view (- 1 , self .num_kv_heads , self .head_size )
357361 if value is not None :
You can’t perform that action at this time.
0 commit comments