26 changes: 12 additions & 14 deletions colossalai/inference/core/engine.py
@@ -24,7 +24,7 @@
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -113,18 +113,15 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
             model_policy (Policy): the policy to replace the model
         """

-        casuallm = None
         if isinstance(model_or_path, str):
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
                 if arch in _supported_models.keys():
-                    casuallm = _supported_models[arch](hf_config)
-                    if isinstance(casuallm, AutoModelForCausalLM):
-                        # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                        model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
-                    else:
-                        model = _supported_models[arch](hf_config)
+                    # NOTE(lry89757) Currently we load the model using transformers-api,
+                    # but we will use lazy tensor and checkpoint io to accelerate
+                    # the model load process in the future.
+                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                 else:
                     raise ValueError(f"Model {arch} is not supported.")

@@ -175,13 +172,14 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )

-        if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io

-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)

         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
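Taken together, these hunks replace the build-from-config-then-load split with a single `from_pretrained` call. A minimal sketch of the resulting load path, assuming a registry shaped like the engine's `_supported_models` (the mapping contents and model class below are illustrative stand-ins, not part of the PR):

```python
# Minimal sketch of the unified load path in init_model after this change.
# The registry below is an illustrative stand-in for _supported_models.
from transformers import AutoConfig, LlamaForCausalLM

_supported_models = {"LlamaForCausalLM": LlamaForCausalLM}

def load_model(model_or_path: str):
    hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
    arch = hf_config.architectures[0]
    if arch not in _supported_models:
        raise ValueError(f"Model {arch} is not supported.")
    # Every supported architecture now goes through from_pretrained, so weights
    # are materialized up front instead of via a second checkpoint-io pass.
    return _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
```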
32 changes: 18 additions & 14 deletions colossalai/inference/executor/rpc_worker.py
@@ -1,4 +1,3 @@
-import os
 from typing import List, Tuple, Union

 import rpyc
@@ -19,7 +18,7 @@
     model_policy_map,
 )
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -178,15 +177,19 @@ def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
         """

         if isinstance(model_or_path, str):
-            is_local = os.path.isdir(model_or_path)
+            # is_local = os.path.isdir(model_or_path)
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
-                if is_local:
-                    model = _SUPPORTED_MODELS[arch](hf_config)
-                else:
-                    # load the real checkpoint
-                    model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # NOTE(lry89757) Currently we load the model using transformers-api,
+                # but we will use lazy tensor and checkpoint io to accelerate
+                # the model load process in the future.
+                model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # if is_local:
+                #     model = _SUPPORTED_MODELS[arch](hf_config)
+                # else:
+                #     # load the real checkpoint
+                #     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
             except Exception as e:
                 logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -235,13 +238,14 @@ def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )

-        if isinstance(model_or_path, str) and is_local:
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and is_local:
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io

-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)

         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
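The RPC worker gets the same consolidation. For reference, the disabled checkpoint-io path that both files keep as comments had the following shape; this is a sketch reconstructed from the commented-out lines, and the wrapper function is hypothetical:

```python
# Sketch of the now-disabled sharded-checkpoint path, reconstructed from the
# commented-out lines above; the PR notes it will return with lazy tensors.
from colossalai.inference.core.plugin import InferCheckpoint_io
from colossalai.inference.utils import has_index_file

def load_from_index(model, model_or_path: str):  # hypothetical wrapper
    cpt_io = InferCheckpoint_io()
    # has_index_file reports whether a shard index (*.index.json) exists
    # for the checkpoint directory, and returns its path.
    if_has_index_file, model_index_file = has_index_file(model_or_path)
    assert if_has_index_file, "the model path is invalid"
    cpt_io.load_model(model, model_index_file)
```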
83 changes: 42 additions & 41 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -645,48 +645,49 @@ def forward(
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "qkv_weight"
-        k1 = "q_proj.weight"
-        k2 = "k_proj.weight"
-        k3 = "v_proj.weight"
-        q_w = state_dict[prefix + k1]
-        k_w = state_dict[prefix + k2]
-        v_w = state_dict[prefix + k3]
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-
-        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+
+            key = "qkv_weight"
+            k1 = "q_proj.weight"
+            k2 = "k_proj.weight"
+            k3 = "v_proj.weight"
+            q_w = state_dict[prefix + k1]
+            k_w = state_dict[prefix + k2]
+            v_w = state_dict[prefix + k3]
+
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+
+            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+
+            input_param = nn.Parameter(
+                qkv_w
+            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+            param = local_state[key]
+
+            try:
+                with torch.no_grad():
+                    param.copy_(input_param)
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                )

-        strict = False  # to avoid unexpected_keys
+            strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
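The fused-QKV load is now gated on `num_heads == num_key_value_heads`, presumably because `torch.stack` requires identically shaped tensors and grouped-query-attention checkpoints ship smaller k/v projections than q. A sketch of that shape constraint, with all dimensions illustrative:

```python
# Why the fused load is gated on num_heads == num_key_value_heads:
# torch.stack requires identical shapes, and with grouped-query attention
# the k/v projections are smaller than q. All dimensions are illustrative.
import torch

hidden, head_dim = 4096, 128
num_heads = 32

# MHA checkpoint: q, k, v all have shape (4096, 4096), so stacking works.
num_kv_heads = 32
q_w = torch.randn(num_heads * head_dim, hidden)
k_w = torch.randn(num_kv_heads * head_dim, hidden)
v_w = torch.randn(num_kv_heads * head_dim, hidden)
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)  # (3, 4096, 4096)

# GQA checkpoint: k/v shrink to (1024, 4096), so the same stack would raise
# a shape-mismatch RuntimeError; the module therefore falls back to the
# default _load_from_state_dict path instead of fusing the weights.
num_kv_heads = 8
k_w = torch.randn(num_kv_heads * head_dim, hidden)
```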