[BUG] Running FP8 quantized model fails on NVIDIA L4 (repack_fp8_for_marlin) #2388

DrNochi opened this issue Aug 9, 2024 · 5 comments

DrNochi commented Aug 9, 2024

System Info

  • Hardware: AWS g6.12xlarge (us-east-2) / 4x NVIDIA L4 GPU
  • OS: Ubuntu 24.04 LTS (Noble Numbat)
  • NVIDIA Driver: nvidia-open 560.28.03
  • CUDA: 12.6
  • Docker: Docker version 27.1.1, build 6312585
  • NVIDIA Container Toolkit: 1.16.1
  • TGI: ghcr.io/huggingface/text-generation-inference:latest (4b44be4c038f)
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03              Driver Version: 560.28.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L4                      Off |   00000000:38:00.0 Off |                    0 |
| N/A   41C    P8             16W /   72W |       1MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA L4                      Off |   00000000:3A:00.0 Off |                    0 |
| N/A   42C    P8             17W /   72W |       1MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA L4                      Off |   00000000:3C:00.0 Off |                    0 |
| N/A   41C    P8             17W /   72W |       1MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA L4                      Off |   00000000:3E:00.0 Off |                    0 |
| N/A   38C    P8             16W /   72W |       1MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

Information

  • Docker
  • The CLI directly

Tasks

  • An officially supported command
  • My own modifications

Reproduction

To reproduce, please run the following shell script:

model=neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8
volume=$PWD/weights
token=<REDACTED>

docker run --rm --runtime nvidia --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:latest --model-id $model

The following exception appears during startup:

2024-08-09T12:44:30.578630Z  INFO text_generation_launcher: Args {
    model_id: "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8",
    revision: None,
    validation_workers: 2,
    sharded: None,
    num_shard: None,
    quantize: None,
    speculate: None,
    dtype: None,
    trust_remote_code: false,
    max_concurrent_requests: 128,
    max_best_of: 2,
    max_stop_sequences: 4,
    max_top_n_tokens: 5,
    max_input_tokens: None,
    max_input_length: None,
    max_total_tokens: None,
    waiting_served_ratio: 0.3,
    max_batch_prefill_tokens: None,
    max_batch_total_tokens: None,
    max_waiting_tokens: 20,
    max_batch_size: None,
    cuda_graphs: None,
    port: 80,
    shard_uds_path: "/tmp/text-generation-server",
    master_addr: "localhost",
    master_port: 29500,
    huggingface_hub_cache: Some(
        "/data",
    ),
    weights_cache_override: None,
    disable_custom_kernels: false,
    cuda_memory_fraction: 1.0,
    rope_scaling: None,
    rope_factor: None,
    json_output: false,
    otlp_endpoint: None,
    otlp_service_name: "text-generation-inference.router",
    cors_allow_origin: [],
    api_key: None,
    watermark_gamma: None,
    watermark_delta: None,
    ngrok: false,
    ngrok_authtoken: None,
    ngrok_edge: None,
    tokenizer_config_path: None,
    disable_grammar_support: false,
    env: false,
    max_client_batch_size: 4,
    lora_adapters: None,
    usage_stats: On,
}
2024-08-09T12:44:30.578721Z  INFO hf_hub: Token file not found "/root/.cache/huggingface/token"
2024-08-09T12:44:30.632826Z  INFO text_generation_launcher: Model supports up to 131072 but tgi will now set its default to 4096 instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens=131122 --max-total-tokens=131072 --max-input-tokens=131071`.
2024-08-09T12:44:30.632839Z  INFO text_generation_launcher: Default `max_input_tokens` to 4095
2024-08-09T12:44:30.632841Z  INFO text_generation_launcher: Default `max_total_tokens` to 4096
2024-08-09T12:44:30.632843Z  INFO text_generation_launcher: Default `max_batch_prefill_tokens` to 4145
2024-08-09T12:44:30.632844Z  INFO text_generation_launcher: Using default cuda graphs [1, 2, 4, 8, 16, 32]
2024-08-09T12:44:30.632939Z  INFO download: text_generation_launcher: Starting check and download process for neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8
2024-08-09T12:44:33.840860Z  INFO text_generation_launcher: Files are already present on the host. Skipping download.
2024-08-09T12:44:34.638228Z  INFO download: text_generation_launcher: Successfully downloaded weights for neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8
2024-08-09T12:44:34.638608Z  INFO shard-manager: text_generation_launcher: Starting shard rank=0
2024-08-09T12:44:39.179449Z  INFO text_generation_launcher: GPU does not support FP8, using Marlin FP8 kernel
2024-08-09T12:44:39.213012Z ERROR text_generation_launcher: Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 109, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 274, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 229, in serve_inner
    model = get_model_with_lora_adapters(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 1195, in get_model_with_lora_adapters
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 766, in get_model
    return FlashCausalLM(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 896, in __init__
    model = model_class(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 528, in __init__
    self.model = FlashLlamaModel(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 418, in __init__
    FlashLlamaLayer(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 346, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 166, in __init__
    self.query_key_value = load_attention(config, prefix, weights, index)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 94, in load_attention
    base_layer = TensorParallelColumnLinear.load_multi(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/tensor_parallel.py", line 179, in load_multi
    linear = get_linear(weight, bias)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/linear.py", line 102, in get_linear
    return weight.get_linear(bias)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/fp8.py", line 185, in get_linear
    return get_fp8_linear().from_fp8(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/fp8.py", line 66, in from_fp8
    return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/fp8.py", line 45, in __init__
    qweight, scales = repack_fp8_for_marlin(qweight, scales)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/fp8.py", line 138, in repack_fp8_for_marlin
    scales = permute_scales(scales)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/util.py", line 48, in permute_scales
    scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
RuntimeError: shape '[-1, 32]' is invalid for input of size 3
2024-08-09T12:44:40.147176Z ERROR shard-manager: text_generation_launcher: Shard complete standard error output:

2024-08-09 12:44:36.570 | INFO     | text_generation_server.utils.import_utils:<module>:73 - Detected system cuda
/opt/conda/lib/python3.10/site-packages/text_generation_server/utils/sgmv.py:18: UserWarning: Could not import SGMV kernel from Punica, falling back to loop.
  warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:159: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:232: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout):
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py:508: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py:567: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'text_generation_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
  return func(*args, **kwargs)
Traceback (most recent call last):

  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 109, in serve
    server.serve(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 274, in serve
    asyncio.run(

  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 229, in serve_inner
    model = get_model_with_lora_adapters(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 1195, in get_model_with_lora_adapters
    model = get_model(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 766, in get_model
    return FlashCausalLM(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 896, in __init__
    model = model_class(prefix, config, weights)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 528, in __init__
    self.model = FlashLlamaModel(prefix, config, weights)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 418, in __init__
    FlashLlamaLayer(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 346, in __init__
    self.self_attn = FlashLlamaAttention(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 166, in __init__
    self.query_key_value = load_attention(config, prefix, weights, index)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 94, in load_attention
    base_layer = TensorParallelColumnLinear.load_multi(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/tensor_parallel.py", line 179, in load_multi
    linear = get_linear(weight, bias)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/linear.py", line 102, in get_linear
    return weight.get_linear(bias)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/fp8.py", line 185, in get_linear
    return get_fp8_linear().from_fp8(

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/fp8.py", line 66, in from_fp8
    return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/fp8.py", line 45, in __init__
    qweight, scales = repack_fp8_for_marlin(qweight, scales)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/fp8.py", line 138, in repack_fp8_for_marlin
    scales = permute_scales(scales)

  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/marlin/util.py", line 48, in permute_scales
    scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]

RuntimeError: shape '[-1, 32]' is invalid for input of size 3
 rank=0
2024-08-09T12:44:40.244304Z ERROR text_generation_launcher: Shard 0 failed to start
Error: ShardCannotStart

Expected behavior

I expect no RuntimeError during shard initialization. TGI should start up and serve the model without problems.


DrNochi commented Aug 9, 2024

While I did not find this issue reported here before, I did find a similar issue reported against vLLM. They have already fixed it on their side, so maybe their changes could be ported over to TGI?
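
For context, here is a minimal sketch of why the reshape in permute_scales fails. This is my reading of the traceback, not TGI's actual loading code: the fused q/k/v projection appears to carry only three per-tensor FP8 scales, while the Marlin repacking expects per-column scales in multiples of the 32-entry permutation.

# Minimal sketch, not TGI code: reproduces the reshape error from the traceback.
# Assumption: loading the fused q/k/v projection yields one per-tensor scale per weight
# (3 values total), while permute_scales expects per-column scales whose count is a
# multiple of the 32-entry Marlin permutation.
import torch

scale_perm_single = list(range(32))  # stand-in for the 32-entry permutation in marlin/util.py
scales = torch.ones(3)               # three per-tensor scales (q, k, v): "input of size 3"

# Raises: RuntimeError: shape '[-1, 32]' is invalid for input of size 3
scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]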


ErikKaum commented Aug 9, 2024

Hi @DrNochi 👋

Thanks for reporting this! I'll tag @danieldk, he is a Marlin expert.


DrNochi commented Aug 9, 2024

Side note: why is TGI even falling back to the Marlin kernels? As far as I know, the NVIDIA L4 uses the Ada Lovelace architecture with CUDA compute capability 8.9, which should have hardware support for FP8 (see the NVIDIA CUDA docs). Am I missing something?

Having a quick look through the code, I found PR #2277, which was part of the latest release and basically "blocks" TGI from utilizing native FP8 support by forcing the Marlin kernels for CC 8.9. I did not find any issues or further explanation pertaining to these changes. Maybe @OlivierDehaene could shed some light on the reasoning behind this change?

@danieldk (Member)

> Side note: why is TGI even falling back to the Marlin kernels? As far as I know, the NVIDIA L4 uses the Ada Lovelace architecture with CUDA compute capability 8.9, which should have hardware support for FP8 (see the NVIDIA CUDA docs). Am I missing something?

We switched to fbgemm-gpu for FP8 matmul. However, it uses TMA (the Tensor Memory Accelerator), which is not supported on CC 8.9 (TMA requires Hopper, CC 9.0).
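
To make that concrete, here is a hypothetical check (not TGI code) of the two capability thresholds involved:

# Hypothetical check, not TGI code: CC 8.9 (Ada Lovelace, e.g. the NVIDIA L4) has an FP8
# data type in hardware, but the fbgemm-gpu FP8 matmul path relies on TMA, which needs
# CC >= 9.0 (Hopper). Hence the "GPU does not support FP8, using Marlin FP8 kernel" log line.
import torch

major, minor = torch.cuda.get_device_capability(0)  # (8, 9) on an NVIDIA L4
has_fp8_dtype = (major, minor) >= (8, 9)             # FP8 tensor cores available
can_use_fbgemm_fp8 = (major, minor) >= (9, 0)        # TMA required by fbgemm-gpu
print(has_fp8_dtype, can_use_fbgemm_fp8)             # True False -> Marlin FP8 fallback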

@nbroad1881 (Contributor)

If you want to use Llama 3.1 8B Instruct in FP8, you can use the original repo:

docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.3.0 \
    --model-id meta-llama/Meta-Llama-3.1-8B-Instruct --quantize fp8
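
Once the container is serving, a quick smoke test against TGI's /generate endpoint could look like the following (the prompt and parameters are my own example values, not from this thread):

# Smoke test for the server started above; adjust host/port if you changed the mapping.
import requests

resp = requests.post(
    "http://localhost:8080/generate",
    json={"inputs": "What is FP8 quantization?", "parameters": {"max_new_tokens": 64}},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["generated_text"])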
