
minicpm-v full-parameter finetune reports an error #677

@SeanLiaoy

Description


Describe the bug
Training fails after a few steps. The same dataset trains qwen-vl-chat without any problem. The launch command is as follows:

NPROC_PER_NODE=8 \
MASTER_PORT=29500 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
swift sft \
    --sft_type full \
    --model_type minicpm-v-3b-chat \
    --model_id_or_path ${BASE_MODEL_PATH} \
    --check_model_is_latest false \
    --dtype fp16 \
    --num_train_epochs 2 \
    --custom_train_dataset_path ${TRAIN_DATASET} \
    --custom_val_dataset_path ${VAL_DATASET} \
    --train_dataset_sample -1 \
    --max_length 400 \
    --learning_rate 1e-5 \
    --batch_size 8 \
    --gradient_accumulation_steps 8 \
    --output_dir ${SAVE_PATH} \
    --deepspeed 'default-zero2' \
    --eval_steps 100 \
    --save_steps 200 \
    --save_total_limit 2
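
As a quick sanity check for the images referenced by --custom_train_dataset_path (the traceback below ends in a channel mismatch inside the image transform), a small script can flag any image that is not 3-channel RGB, such as RGBA PNGs or grayscale JPEGs. This is only a sketch: the jsonl layout with an "images" field and the file name are assumptions, so adapt them to however the dataset actually stores its image paths.

```python
import json
from PIL import Image

def find_non_rgb(jsonl_path: str):
    """Return (line_no, image_path, mode) for every referenced image that is not RGB."""
    bad = []
    with open(jsonl_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            record = json.loads(line)
            # "images" is an assumed field name; adjust to the real dataset schema.
            for img_path in record.get("images", []):
                with Image.open(img_path) as img:
                    if img.mode != "RGB":
                        bad.append((line_no, img_path, img.mode))
    return bad

if __name__ == "__main__":
    for line_no, path, mode in find_non_rgb("train.jsonl"):  # hypothetical path
        print(f"line {line_no}: {path} is {mode}, expected RGB")
```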

Your hardware and system info

Additional context

Time to load fused_adam op: 0.10169339179992676 seconds
{'loss': 4.241786, 'acc': 0.43225819, 'learning_rate': 0.0, 'epoch': 0.0, 'global_step': 1}                                                                                                                                                                              
{'loss': 3.26653814, 'acc': 0.47446066, 'learning_rate': 4.42e-06, 'epoch': 0.01, 'global_step': 5}                                                                                                                                                                      
{'loss': 2.87632065, 'acc': 0.48180504, 'learning_rate': 6.33e-06, 'epoch': 0.03, 'global_step': 10}                                                                                                                                                                     
{'loss': 2.60134201, 'acc': 0.53103466, 'learning_rate': 7.44e-06, 'epoch': 0.04, 'global_step': 15}                                                                                                                                                                     
Train:   2%|█████▌                          | 19/760 [01:08<44:15,  3.58s/it]
Traceback (most recent call last):
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/swift/cli/sft.py", line 5, in <module>
    sft_main()
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/swift/utils/run_utils.py", line 31, in x_main
    result = llm_x(args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/swift/llm/sft.py", line 236, in llm_sft
    trainer.train(training_args.resume_from_checkpoint)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/swift/trainers/trainers.py", line 50, in train
    res = super().train(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/transformers/trainer.py", line 1821, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/accelerate/data_loader.py", line 462, in __iter__
    next_batch = next(dataloader_iter)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/swift/llm/utils/utils.py", line 200, in __getitem__
    res = self._try_fetch(idx)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/swift/llm/utils/utils.py", line 210, in _try_fetch
    res = self.template.encode(data)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/swift/llm/utils/template.py", line 1115, in encode
    pixel_values = self.model.transform(image)[None].to(
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torchvision/transforms/transforms.py", line 277, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torchvision/transforms/functional.py", line 363, in normalize
    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
  File "/data/miniconda3/envs/env-3.10.6/lib/python3.10/site-packages/torchvision/transforms/_functional_tensor.py", line 928, in normalize
    return tensor.sub_(mean).div_(std)
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
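
The mismatch suggests a 4-channel image (for example an RGBA PNG) reaching a Normalize transform whose mean/std have only 3 entries. Below is a minimal sketch, outside swift, that reproduces the same RuntimeError and shows that converting the image to RGB avoids it; the mean/std values are illustrative, not necessarily the model's own.

```python
from PIL import Image
from torchvision import transforms

# Normalize with 3-channel statistics, as in the transform from the traceback.
transform = transforms.Compose([
    transforms.ToTensor(),                       # RGBA image -> tensor of shape [4, H, W]
    transforms.Normalize(mean=(0.5, 0.5, 0.5),   # illustrative values
                         std=(0.5, 0.5, 0.5)),
])

rgba = Image.new("RGBA", (224, 224), (255, 0, 0, 255))

try:
    transform(rgba)                              # 4 channels vs. 3-element mean/std
except RuntimeError as e:
    print(e)  # The size of tensor a (4) must match the size of tensor b (3) ...

pixel_values = transform(rgba.convert("RGB"))    # dropping the alpha channel works
print(pixel_values.shape)                        # torch.Size([3, 224, 224])
```

If the dataset contains such images, converting them to RGB before training (or in a preprocessing step) should avoid the error.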

Labels: bug (Something isn't working), solved
