[Misc] Add customized information for models #4132

Merged: 6 commits, merged on May 1, 2024
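The diff below overrides `torch.nn.Module.extra_repr` on vLLM's custom layers so that printing a model (`print(model)` or `repr(model)`) shows each layer's key configuration, the same way stock PyTorch layers such as `nn.Linear` already do. As a refresher on the mechanism, here is a minimal, self-contained sketch using a hypothetical `ToyNorm` module (not part of this PR):

```python
import torch.nn as nn


class ToyNorm(nn.Module):
    """Hypothetical module showing how extra_repr feeds into print(module)."""

    def __init__(self, hidden_size: int = 4096, eps: float = 1e-6) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

    def extra_repr(self) -> str:
        # Whatever this returns is placed between the parentheses
        # when this module (or any parent module) is printed.
        return f"hidden_size={self.hidden_size}, eps={self.eps}"


print(ToyNorm())  # ToyNorm(hidden_size=4096, eps=1e-06)
```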
7 changes: 7 additions & 0 deletions vllm/attention/layer.py
@@ -47,3 +47,10 @@ def forward(
    ) -> torch.Tensor:
        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                                 kv_scale)

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"
        s += f", num_heads={self.impl.num_heads}"
        s += f", num_kv_heads={self.impl.num_kv_heads}"
        s += f", scale={self.impl.scale}"
        return s
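With this override, a model dump shows the attention geometry inline, e.g. a line like `(attn): Attention(head_size=128, num_heads=32, num_kv_heads=32, scale=0.0884)`. The values here are illustrative for a hypothetical 32-head model with 128-dimensional heads (scale is typically `1/sqrt(head_size)`), not taken from this PR.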
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -67,6 +67,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        ops.gelu_tanh_and_mul(out, x)
        return out

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'


class NewGELU(nn.Module):

5 changes: 5 additions & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -64,3 +64,8 @@ def forward(
            self.variance_epsilon,
        )
        return out

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -224,6 +224,14 @@ def forward(self, input_):
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s


class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism.
@@ -585,3 +593,11 @@ def forward(self, input_):
            output = output_
            output_bias = self.bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
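Note that both parallel linear layers report their per-partition shard sizes: the column-parallel layer prints `output_features=output_size_per_partition` while the row-parallel layer prints `input_features=input_size_per_partition`, so the same checkpoint printed under a larger `tp_size` shows narrower shards. For example, a hypothetical 4096-to-11008 column-parallel projection under `tp_size=2` would report `output_features=5504`.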
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/logits_processor.py
@@ -70,6 +70,11 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
            logits = logits[:, :self.org_vocab_size]
        return logits

    def extra_repr(self) -> str:
        s = "vocab_size={vocab_size}, org_vocab_size={org_vocab_size}"
        s += ", scale={scale}, logits_as_input={logits_as_input}"
        return s.format(**self.__dict__)
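The logits-processor (and rotary-embedding, below) overrides use a slightly different style from the other files: a plain template string filled from `self.__dict__` instead of f-strings. A minimal sketch of that pattern, with hypothetical attribute names, is:

```python
class Demo:
    """Hypothetical class illustrating the template + __dict__ style above."""

    def __init__(self) -> None:
        self.vocab_size = 32000
        self.scale = 1.0

    def extra_repr(self) -> str:
        s = "vocab_size={vocab_size}, scale={scale}"
        # str.format(**self.__dict__) replaces each {name} placeholder
        # with the instance attribute of the same name.
        return s.format(**self.__dict__)


print(Demo().extra_repr())  # vocab_size=32000, scale=1.0
```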


def _prune_hidden_states(
    hidden_states: torch.Tensor,
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -155,6 +155,12 @@ def forward(
                             self.cos_sin_cache, self.is_neox_style)
        return query, key

    def extra_repr(self) -> str:
        s = "head_size={head_size}, rotary_dim={rotary_dim}"
        s += ", max_position_embeddings={max_position_embeddings}"
        s += ", base={base}, is_neox_style={is_neox_style}"
        return s.format(**self.__dict__)


class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -105,6 +105,14 @@ def forward(self, input_):
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output

    def extra_repr(self) -> str:
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f', num_embeddings_padded={self.num_embeddings_padded}'
        s += f', tp_size={self.tp_size}'
        return s


class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
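Taken together, these overrides make the printed module tree self-describing. A rough, self-contained sketch of the effect with stand-in modules (the real vLLM classes have different constructors and tensor-parallel logic):

```python
import torch.nn as nn


class RMSNormDemo(nn.Module):
    """Stand-in mirroring the RMSNorm extra_repr added in this PR."""

    def __init__(self, hidden_size: int = 4096, eps: float = 1e-6) -> None:
        super().__init__()
        self.hidden_size, self.eps = hidden_size, eps

    def extra_repr(self) -> str:
        return f"hidden_size={self.hidden_size}, eps={self.eps}"


class DecoderLayerDemo(nn.Module):
    """Stand-in parent; each child's extra_repr appears in the parent's repr."""

    def __init__(self) -> None:
        super().__init__()
        self.input_layernorm = RMSNormDemo()
        self.post_attention_layernorm = RMSNormDemo()


print(DecoderLayerDemo())
# DecoderLayerDemo(
#   (input_layernorm): RMSNormDemo(hidden_size=4096, eps=1e-06)
#   (post_attention_layernorm): RMSNormDemo(hidden_size=4096, eps=1e-06)
# )
```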