
Commit

NFC: changes to llama
pavanimajety committed Aug 8, 2024
1 parent 4b5cbd8 commit 826a724
Showing 1 changed file with 1 addition and 22 deletions.
23 changes: 1 addition & 22 deletions vllm/model_executor/models/llama.py
@@ -54,8 +54,7 @@
 from .interfaces import SupportsLoRA
 from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
-from vllm.logger import init_logger
-logger = init_logger(__name__)
+
 class LlamaMLP(nn.Module):
 
     def __init__(
@@ -307,9 +306,6 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        if not torch.cuda.is_current_stream_capturing():
-            logger.info(f" input ids: {input_ids}")
-
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -321,11 +317,6 @@
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        if not torch.cuda.is_current_stream_capturing():
-            logger.info(f"hidden states: 0 {hidden_states}")
-            if residual:
-                logger.info(f"residual: {residual}")
-
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
@@ -335,9 +326,6 @@
                 attn_metadata,
                 residual,
             )
-            if not torch.cuda.is_current_stream_capturing():
-                logger.info(f"hidden states at {i} : {hidden_states}")
-                logger.info(f"residual {i} : {residual}")
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -431,9 +419,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-
-        if not torch.cuda.is_current_stream_capturing():
-            logger.info(f"starting with input_ids: {input_ids} ")
         model_output = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
         return model_output
@@ -476,10 +461,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        logger.info("Printing weights from weight loader: ")
         for name, loaded_weight in weights:
-            logger.info(f"name: {name}")
-            logger.info(f"loaded_weight: {loaded_weight}")
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -500,9 +482,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
                 continue
-            #To Do: Remove when ModelOpt fixes the quantized model.
-            if ("output_quantizer._amax") in name:
-                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
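For readers skimming the diff: every removed block is debug logging wrapped in torch.cuda.is_current_stream_capturing(), which returns True while a CUDA graph is being captured on the current stream. Below is a minimal illustrative sketch of that guard; it reuses vLLM's init_logger (visible in the diff), but log_tensor and its arguments are hypothetical names for illustration, not repository code.

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


def log_tensor(tag: str, tensor: torch.Tensor) -> None:
    # Skip host-side logging while CUDA graph capture is in progress;
    # captured work is replayed later, so logging at capture time would be
    # misleading and the host call can interfere with the capture itself.
    if not torch.cuda.is_current_stream_capturing():
        logger.info("%s: shape=%s dtype=%s", tag, tuple(tensor.shape),
                    tensor.dtype)

Deleting these calls (an NFC cleanup) also removes per-step string formatting from the hot forward path.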
