
Recipe stuck when predicting #99

Closed
@Kai-Piontek

Description

Thank you, Francois, for fixing #48; I can now train the model successfully.

But when I run inference with
eole predict --src wmt17_en_de/test.src.bpe --model_path wmt17_en_de/bigwmt17/step_50000 --beam_size 5 --batch_size 4096 --batch_type tokens --output wmt17_en_de/pred.trg.bpe --gpu 0

I get the error below. What might be wrong?

Best, Kai

[2024-09-08 18:23:26,170 INFO] Loading checkpoint from wmt17_en_de/bigwmt17/step_50000
[2024-09-08 18:23:26,893 INFO] Building model...
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,901 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,925 INFO] Loading data into the model
[2024-09-08 18:23:27,136 INFO] Transforms applied: []
Traceback (most recent call last):
File "/usr/local/bin/eole", line 33, in
sys.exit(load_entry_point('EOLE', 'console_scripts', 'eole')())
File "/eole/eole/bin/main.py", line 39, in main
bin_cls.run(args)
File "/eole/eole/bin/run/predict.py", line 42, in run
predict(config)
File "/eole/eole/bin/run/predict.py", line 18, in predict
_, _, _ = engine.infer_file()
File "/eole/eole/inference_engine.py", line 37, in infer_file
scores, estims, preds = self._predict(infer_iter)
File "/eole/eole/inference_engine.py", line 163, in _predict
scores, estims, preds = self.predictor._predict(
File "/eole/eole/predict/inference.py", line 454, in _predict
batch_data = self.predict_batch(batch, attn_debug)
File "/eole/eole/predict/translator.py", line 121, in predict_batch
return self._translate_batch_with_strategy(batch, decode_strategy)
File "/eole/eole/predict/translator.py", line 194, in _translate_batch_with_strategy
log_probs, attn = self._decode_and_generate(
File "/eole/eole/predict/inference.py", line 664, in _decode_and_generate
dec_out, dec_attn = self.model.decoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/eole/eole/decoders/transformer_decoder.py", line 200, in forward
emb, attn, attn_align = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/eole/eole/decoders/transformer_base.py", line 76, in forward
layer_out, attns = self._forward(*args, **kwargs)
File "/eole/eole/decoders/transformer_decoder.py", line 95, in _forward
self_attn, _ = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/eole/eole/modules/multi_headed_attn.py", line 690, in forward
return super()._forward2(
File "/eole/eole/modules/multi_headed_attn.py", line 461, in _forward2
attn_output = self.flash_attn_func(
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 831, in flash_attn_func
return FlashAttnFunc.apply(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 598, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 511, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
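
For reference, the dtype restriction seems to come from flash-attn itself rather than from the recipe. A minimal sketch of my own (assuming flash-attn 2.x and a CUDA device; the tensor shapes are arbitrary) hits the same RuntimeError as soon as the inputs are fp32:

```python
# Illustrative only: shows the flash-attn dtype constraint in isolation.
# Assumes flash-attn 2.x is installed and a CUDA GPU is available.
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) -- arbitrary small shapes
q = torch.randn(1, 16, 8, 64, device="cuda", dtype=torch.float32)
k = torch.randn_like(q)
v = torch.randn_like(q)

try:
    flash_attn_func(q, k, v)  # fp32 inputs -> same RuntimeError as in the log above
except RuntimeError as e:
    print(e)

out = flash_attn_func(q.half(), k.half(), v.half())  # fp16 inputs are accepted
print(out.shape)  # torch.Size([1, 16, 8, 64])
```

So I suspect the predictor is running the model in fp32 here, and my question is how the recipe is supposed to select fp16/bf16 at inference time.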
