[Model] Snowflake arctic model implementation #4652

Merged: 52 commits, merged on May 9, 2024

Commits
30b0a1a
yak tps 72 at bs=1 (#55)
sfc-gh-hazhang Apr 17, 2024
37d9262
renmaing
sfc-gh-hazhang Apr 23, 2024
bf3f61f
run linting
sfc-gh-hazhang Apr 24, 2024
bfe5011
Merge pull request #3 from Snowflake-Labs/arctic-linting
iamontheinet Apr 24, 2024
140af95
arctic load quantized checkpoint
aurickq Apr 29, 2024
1dbacc8
revert tensorize
aurickq Apr 29, 2024
13a8f6b
update
aurickq Apr 29, 2024
07ca1d1
wip
sfc-gh-aqiao Apr 29, 2024
dc4a13a
wip
sfc-gh-aqiao Apr 29, 2024
d4b2e26
update
aurickq Apr 30, 2024
6353101
update
aurickq Apr 30, 2024
9c82b9d
update example
sfc-gh-aqiao Apr 30, 2024
9f4a66d
update
sfc-gh-aqiao Apr 30, 2024
24f0fd9
update
aurickq Apr 30, 2024
d028829
update example
aurickq Apr 30, 2024
e35668b
update
aurickq Apr 30, 2024
ad00b54
Merge pull request #7 from Snowflake-Labs/arctic-quant
sfc-gh-aqiao Apr 30, 2024
6aa6de3
remove dummy path in arctic
sfc-gh-aqiao May 1, 2024
15de0c2
Merge pull request #8 from Snowflake-Labs/remove-dummy
sfc-gh-aqiao May 1, 2024
e68f9a1
merge main
sfc-gh-hazhang May 4, 2024
da5e148
adapt to latest changes
sfc-gh-hazhang May 6, 2024
0111d02
run linting
sfc-gh-hazhang May 6, 2024
332890b
sharded prequantized checkpoints (#12)
aurickq May 7, 2024
67c2d60
run linting
sfc-gh-hazhang May 7, 2024
8585450
pr cleanup
sfc-gh-aqiao May 8, 2024
1e5aabf
address comments
sfc-gh-hazhang May 8, 2024
c18a290
add docstring
sfc-gh-hazhang May 8, 2024
4d9ddaf
support remote code and HF repo name
sfc-gh-hazhang May 8, 2024
e832146
update
sfc-gh-aqiao May 8, 2024
4ef9a53
Merge branch 'sfc-gh-hazhang:arctic-prr' into arctic-prr
sfc-gh-aqiao May 8, 2024
904acc9
update
sfc-gh-aqiao May 8, 2024
f2490a6
Merge branch 'arctic-prr' of github.com:sfc-gh-aqiao/vllm into arctic…
sfc-gh-aqiao May 8, 2024
de08a4a
update
sfc-gh-aqiao May 8, 2024
3235680
Merge pull request #1 from sfc-gh-aqiao/arctic-prr
aurickq May 8, 2024
9dd2653
fix mypy
sfc-gh-aqiao May 8, 2024
98d60ce
Merge branch 'sfc-gh-hazhang:arctic-prr' into arctic-prr
sfc-gh-aqiao May 8, 2024
ed34596
Merge pull request #2 from sfc-gh-aqiao/arctic-prr
aurickq May 8, 2024
1584db8
expose public save_sharded_state interface
sfc-gh-aqiao May 8, 2024
0a1a10f
update
sfc-gh-aqiao May 8, 2024
6c50193
Merge pull request #3 from sfc-gh-aqiao/arctic-prr
aurickq May 8, 2024
46b9024
fix ruff
sfc-gh-aqiao May 8, 2024
6de25cf
Merge branch 'sfc-gh-hazhang:arctic-prr' into arctic-prr
sfc-gh-aqiao May 8, 2024
58570be
Merge pull request #4 from sfc-gh-aqiao/arctic-prr
aurickq May 8, 2024
9593598
isort
sfc-gh-aqiao May 8, 2024
13a57ef
Merge branch 'arctic-prr' of github.com:sfc-gh-aqiao/vllm into arctic…
sfc-gh-aqiao May 8, 2024
e4e4299
Merge branch 'sfc-gh-hazhang:arctic-prr' into arctic-prr
sfc-gh-aqiao May 8, 2024
056273c
Merge pull request #5 from sfc-gh-aqiao/arctic-prr
aurickq May 8, 2024
462df46
Update vllm/model_executor/layers/quantization/deepspeedfp.py
aurickq May 8, 2024
319104a
separate out sharded state loader
sfc-gh-aqiao May 8, 2024
7d84bb0
Merge pull request #6 from sfc-gh-aqiao/arctic-prr
aurickq May 8, 2024
1bef4d6
remove
sfc-gh-hazhang May 9, 2024
8ec57a8
remove empty cache
sfc-gh-hazhang May 9, 2024
Changes from all commits
26 changes: 26 additions & 0 deletions examples/offline_inference_arctic.py
@@ -0,0 +1,26 @@
+from vllm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(model="snowflake/snowflake-arctic-instruct",
+          quantization="deepspeedfp",
+          tensor_parallel_size=8,
+          trust_remote_code=True)
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,7 +1,9 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_moe, get_config_file_name)
+    fused_experts, fused_moe, fused_topk, get_config_file_name)
 
 __all__ = [
     "fused_moe",
+    "fused_topk",
+    "fused_experts",
     "get_config_file_name",
 ]
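With fused_topk and fused_experts now exported alongside fused_moe, routing and expert computation can be called together or as two separate steps. Below is a minimal sketch of the two equivalent call paths, using toy shapes and assuming a CUDA device; none of these sizes or values come from the PR itself.

import torch

from vllm.model_executor.layers.fused_moe import (fused_experts, fused_moe,
                                                  fused_topk)

# Toy sizes: 4 tokens, 8 experts, hidden size 64, intermediate size 128.
M, E, H, I = 4, 8, 64, 128
hidden_states = torch.randn(M, H, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.float16, device="cuda")  # gate/up proj
w2 = torch.randn(E, H, I, dtype=torch.float16, device="cuda")      # down proj
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")

# One-shot path: routing and expert GEMMs in a single call.
out = fused_moe(hidden_states, w1, w2, gating_output, topk=2, renormalize=True)

# Two-step path: compute top-k routing first, then run the fused expert kernels.
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk=2,
                                    renormalize=True)
out_split = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)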
137 changes: 90 additions & 47 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int,
     return None
 
 
-def fused_moe(
+def fused_topk(
     hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
+):
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
+
     M, _ = hidden_states.shape
-    E, N, _ = w1.shape
 
     if is_hip():
         # The MoE kernels are not yet supported on ROCm.
@@ -393,6 +349,33 @@ def fused_moe(
     del token_expert_indicies  # Not used. Will be used in the future.
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_experts(hidden_states: torch.Tensor,
+                  w1: torch.Tensor,
+                  w2: torch.Tensor,
+                  topk_weights: torch.Tensor,
+                  topk_ids: torch.Tensor,
+                  inplace: bool = False,
+                  override_config: Optional[Dict[str, Any]] = None,
+                  use_fp8: bool = False,
+                  w1_scale: Optional[torch.Tensor] = None,
+                  w2_scale: Optional[torch.Tensor] = None,
+                  a1_scale: Optional[torch.Tensor] = None,
+                  a2_scale: Optional[torch.Tensor] = None):
+    # Check constraints.
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
+
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
+
     if override_config:
         config = override_config
@@ -477,3 +460,63 @@ def fused_moe(
                          out=hidden_states)
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+
+    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                        renormalize)
+    return fused_experts(hidden_states,
+                         w1,
+                         w2,
+                         topk_weights,
+                         topk_ids,
+                         inplace=inplace,
+                         override_config=override_config,
+                         use_fp8=use_fp8,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale,
+                         a1_scale=a1_scale,
+                         a2_scale=a2_scale)
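The point of this split is that a model can now run its own gating logic and still hand the heavy per-expert GEMMs to fused_experts. The sketch below uses a hypothetical plain-PyTorch router, not the routing shipped in this PR; the dtypes are chosen to mirror what fused_topk is expected to return (float32 weights, int32 expert ids).

import torch

from vllm.model_executor.layers.fused_moe import fused_experts


def route_then_fuse(hidden_states: torch.Tensor,
                    w1: torch.Tensor,
                    w2: torch.Tensor,
                    gate_logits: torch.Tensor,
                    topk: int) -> torch.Tensor:
    """Hypothetical custom router: softmax in float32, plain torch.topk,
    then reuse the fused expert kernels for the expert GEMMs."""
    probs = torch.softmax(gate_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    # Renormalize the selected weights so they sum to 1 per token.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Pass float32 weights and int32 expert ids, matching fused_topk's output.
    return fused_experts(hidden_states, w1, w2,
                         topk_weights,
                         topk_ids.to(torch.int32))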
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -4,6 +4,8 @@
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.deepspeedfp import (
+    DeepSpeedFPConfig)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import (
@@ -19,6 +21,7 @@
     "squeezellm": SqueezeLLMConfig,
     "gptq_marlin": GPTQMarlinConfig,
     "marlin": MarlinConfig,
+    "deepspeedfp": DeepSpeedFPConfig
 }
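With the "deepspeedfp" key registered, the DeepSpeed FP config class can be resolved by name, which is the same registry the quantization="deepspeedfp" argument in the example script above is resolved against. A minimal sketch, assuming the module's existing get_quantization_config helper, which is not shown in this diff:

from vllm.model_executor.layers.quantization import get_quantization_config

# Look up the config class registered under the new key.
quant_cls = get_quantization_config("deepspeedfp")
print(quant_cls.__name__)  # DeepSpeedFPConfig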