Skip to content

Commit

Permalink
Consistently take prefix in model constructors (#2191)
Browse files Browse the repository at this point in the history
* Consistently take `prefix` in model constructors

* Release test check fix

* Misc refactor-related fixes
  • Loading branch information
danieldk authored Jul 5, 2024
1 parent 67ef064 commit 05c094f
Show file tree
Hide file tree
Showing 23 changed files with 210 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ jobs:
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == 'true') && '--release' || '' }}
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand Down
3 changes: 2 additions & 1 deletion server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
Expand Down Expand Up @@ -522,7 +523,7 @@ def get_model(
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
batch_class=BloomCausalLMBatch,
)
elif model_type == MPT:
return CausalLM(
Expand Down
3 changes: 2 additions & 1 deletion server/text_generation_server/models/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,8 @@ def __init__(
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)

model = model_class(config, weights)
prefix = ""
model = model_class(prefix, config, weights)

torch.distributed.barrier(group=self.process_group)
super().__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def forward(


class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.transformer = BloomModel(config, weights)

Expand Down
6 changes: 3 additions & 3 deletions server/text_generation_server/models/custom_modeling/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def forward(


class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
def __init__(self, prefix: str, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
Expand Down Expand Up @@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):

_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

def __init__(self, config: CLIPTextConfig):
def __init__(self, prefix, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformer(config)
self.text_model = CLIPTextTransformer(prefix, config)
# Initialize weights and apply final processing
self.post_init()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ def forward(self, hidden_states):


class FlashCohereLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
Expand Down Expand Up @@ -416,18 +416,19 @@ def forward(


class FlashCohereModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashCohereLayer(
prefix,
layer_id,
config,
weights,
Expand All @@ -436,7 +437,7 @@ def __init__(self, config, weights):
]
)
self.norm = FastLayerNorm.load_no_bias(
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
)

self.gradient_checkpointing = False
Expand Down Expand Up @@ -486,10 +487,15 @@ def forward(


class FlashCohereForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

self.model = FlashCohereModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"

self.model = FlashCohereModel(prefix, config, weights)
try:
self.lm_head = SpeculativeHead.load(
config,
Expand All @@ -499,7 +505,7 @@ def __init__(self, config, weights):
except RuntimeError:
self.lm_head = SpeculativeHead.load(
config,
prefix="model.embed_tokens",
prefix=f"{prefix}.embed_tokens",
weights=weights,
)
self.logit_scale = config.logit_scale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,9 +593,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class DbrxLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"transformer.blocks.{layer_id}"
prefix = f"{prefix}.blocks.{layer_id}"

self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
Expand Down Expand Up @@ -637,16 +637,17 @@ def forward(


class DbrxModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.wte", weights=weights
prefix=f"{prefix}.wte", weights=weights
)

self.layers = nn.ModuleList(
[
DbrxLayer(
prefix,
layer_id,
config,
weights,
Expand All @@ -655,7 +656,7 @@ def __init__(self, config, weights):
]
)
self.norm = FastLayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=1e-5
prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
)

self.head_size = self.layers[0].attn.self_attn.head_size
Expand Down Expand Up @@ -702,9 +703,14 @@ def forward(


class FlashDbrxForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"

self.model = DbrxModel(config, weights)
self.lm_head = SpeculativeHead.load(
config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(

class Gemma2FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
Expand All @@ -123,7 +123,7 @@ def forward(self, hidden_states, residual=None):
return hidden_states.to(self.dtype), residual


def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
Expand Down Expand Up @@ -305,7 +305,7 @@ def forward(self, hidden_states):


class FlashGemma2Layer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
super().__init__()
self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn",
Expand Down Expand Up @@ -376,7 +376,7 @@ def forward(


class FlashGemma2Model(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()

process_group = weights.process_group
Expand Down Expand Up @@ -442,7 +442,7 @@ def forward(


class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, *, causal: bool = True):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()

embed_norm = config.hidden_size**0.5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(

class GemmaFastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
Expand All @@ -123,7 +123,7 @@ def forward(self, hidden_states, residual=None):
return hidden_states.to(self.dtype), residual


def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
Expand Down Expand Up @@ -261,7 +261,7 @@ def forward(


class GemmaMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
Expand Down Expand Up @@ -299,7 +299,7 @@ def forward(self, hidden_states):


class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
Expand Down Expand Up @@ -354,7 +354,7 @@ def forward(


class FlashGemmaModel(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()

process_group = weights.process_group
Expand Down Expand Up @@ -419,7 +419,7 @@ def forward(


class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, *, causal: bool = True):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()

embed_norm = config.hidden_size**0.5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def forward(


class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
act = config.activation_function
self.act = (
Expand Down Expand Up @@ -298,7 +298,7 @@ def forward(self, hidden_states):


class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
Expand Down Expand Up @@ -350,7 +350,7 @@ def forward(


class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

process_group = weights.process_group
Expand Down Expand Up @@ -414,7 +414,7 @@ def forward(


class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

self.embed_tokens = TensorParallelEmbedding(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


def load_attention(config, prefix, weights, layer_id):
def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads
Expand Down Expand Up @@ -467,7 +467,7 @@ def forward(


class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

self.embed_tokens = TensorParallelEmbedding(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def forward(


class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
Expand Down Expand Up @@ -328,7 +328,7 @@ def forward(self, hidden_states, adapter_data):


class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn",
Expand Down Expand Up @@ -392,7 +392,7 @@ def forward(


class MistralModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()

process_group = weights.process_group
Expand Down Expand Up @@ -462,7 +462,7 @@ def forward(


class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, name=None):
def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
super().__init__()
Expand Down
Loading

0 comments on commit 05c094f

Please sign in to comment.