Description
Thank you, Francois, for fixing #48; I can now train the model successfully.
But when I try to run inference with
eole predict --src wmt17_en_de/test.src.bpe --model_path wmt17_en_de/bigwmt17/step_50000 --beam_size 5 --batch_size 4096 --batch_type tokens --output wmt17_en_de/pred.trg.bpe --gpu 0
I get the error below. What might be wrong?
Best, Kai
[2024-09-08 18:23:26,170 INFO] Loading checkpoint from wmt17_en_de/bigwmt17/step_50000
[2024-09-08 18:23:26,893 INFO] Building model...
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,894 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,901 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-09-08 18:23:26,925 INFO] Loading data into the model
[2024-09-08 18:23:27,136 INFO] Transforms applied: []
Traceback (most recent call last):
File "/usr/local/bin/eole", line 33, in
sys.exit(load_entry_point('EOLE', 'console_scripts', 'eole')())
File "/eole/eole/bin/main.py", line 39, in main
bin_cls.run(args)
File "/eole/eole/bin/run/predict.py", line 42, in run
predict(config)
File "/eole/eole/bin/run/predict.py", line 18, in predict
_, _, _ = engine.infer_file()
File "/eole/eole/inference_engine.py", line 37, in infer_file
scores, estims, preds = self._predict(infer_iter)
File "/eole/eole/inference_engine.py", line 163, in _predict
scores, estims, preds = self.predictor._predict(
File "/eole/eole/predict/inference.py", line 454, in _predict
batch_data = self.predict_batch(batch, attn_debug)
File "/eole/eole/predict/translator.py", line 121, in predict_batch
return self._translate_batch_with_strategy(batch, decode_strategy)
File "/eole/eole/predict/translator.py", line 194, in _translate_batch_with_strategy
log_probs, attn = self._decode_and_generate(
File "/eole/eole/predict/inference.py", line 664, in _decode_and_generate
dec_out, dec_attn = self.model.decoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/eole/eole/decoders/transformer_decoder.py", line 200, in forward
emb, attn, attn_align = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/eole/eole/decoders/transformer_base.py", line 76, in forward
layer_out, attns = self._forward(*args, **kwargs)
File "/eole/eole/decoders/transformer_decoder.py", line 95, in _forward
self_attn, _ = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/eole/eole/modules/multi_headed_attn.py", line 690, in forward
return super()._forward2(
File "/eole/eole/modules/multi_headed_attn.py", line 461, in _forward2
attn_output = self.flash_attn_func(
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 831, in flash_attn_func
return FlashAttnFunc.apply(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 598, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 511, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
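For reference, here is a minimal sketch (outside of eole, just plain PyTorch plus the flash-attn package, assuming a CUDA device that flash-attn supports and illustrative tensor shapes) that reproduces the same dtype constraint: flash_attn_func rejects float32 inputs and only accepts fp16/bf16 tensors, which looks like the condition the decoder hits here at predict time.

```python
# Minimal reproduction sketch of the dtype check, not eole code.
# Assumes the flash-attn package is installed and a supported CUDA GPU is available.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 16, 8, 64  # arbitrary example shapes
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda")  # float32 by default
k = torch.randn_like(q)
v = torch.randn_like(q)

try:
    flash_attn_func(q, k, v, causal=True)  # float32 inputs
except RuntimeError as e:
    print(e)  # "FlashAttention only support fp16 and bf16 data type"

# Casting the inputs (i.e. running the model in half precision) avoids the error:
out = flash_attn_func(q.half(), k.half(), v.half(), causal=True)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```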