You can retrain the base decoder of ChartVLM and reproduce our results via the following steps:

```bash
cd base_decoder/train
```
The following datasets are used in our paper:

- ChartQA [Dataset Page]
- PlotQA [Dataset Page]
- Chart2Text [Dataset Page]
- SimChart9K [Download]
To speed up data I/O during training of the base decoder, we preprocess the downloaded chart data and save it in `.npy` format.

You must preprocess the data before starting training. In each script below, change the root path to the absolute path of the corresponding downloaded dataset (a minimal sketch of this conversion is given after the commands).
```bash
cd tools/data_preprocess/

# Change the root path to the downloaded ChartQA or Chart2Text dataset
python data_preprocess_ChartQA_Chart2Text.py

# Change the root path to the downloaded PlotQA dataset
python data_preprocess_PlotQA.py

# Change the root path to the downloaded SimChart9K dataset
python data_preprocess_SimChart9K.py

# Return to the 'tools' directory
cd ..
```
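For reference, the snippet below is a minimal sketch of what such a `.npy` conversion typically looks like: it walks a folder of chart images, resizes them, and stores each image together with its ground-truth CSV string as a single `.npy` record. The paths, directory layout, and image size are illustrative assumptions, not the actual logic of the `data_preprocess_*.py` scripts.

```python
# Hypothetical sketch of chart-to-.npy preprocessing; paths, layout, and
# image size are assumptions for illustration, not the repository's code.
import os

import numpy as np
from PIL import Image

DATASET_ROOT = "/absolute/path/to/ChartQA"    # assumed layout: png/ and csv/ subfolders
OUTPUT_ROOT = "/absolute/path/to/ChartQA_npy"
IMAGE_SIZE = (448, 448)                       # illustrative resolution

os.makedirs(OUTPUT_ROOT, exist_ok=True)

for name in sorted(os.listdir(os.path.join(DATASET_ROOT, "png"))):
    stem, ext = os.path.splitext(name)
    if ext.lower() != ".png":
        continue

    # Load and resize the chart image, then store it as a uint8 array.
    image = Image.open(os.path.join(DATASET_ROOT, "png", name)).convert("RGB")
    pixels = np.asarray(image.resize(IMAGE_SIZE), dtype=np.uint8)

    # Read the ground-truth CSV (the image-to-CSV training target) as raw text.
    with open(os.path.join(DATASET_ROOT, "csv", stem + ".csv"), "r") as f:
        csv_text = f.read()

    # Pack both pieces into one record so training-time I/O is a single np.load.
    record = {"image": pixels, "csv": csv_text}
    np.save(os.path.join(OUTPUT_ROOT, stem + ".npy"), np.array(record, dtype=object))
```

A record saved this way can be read back with `np.load(path, allow_pickle=True).item()`.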
- Train the Base Model using multiple GPUs:

```bash
sh scripts/dist_train.sh 8 \
    --config ./cfgs/image_to_csv_base_merge_all_trained.yaml \
    --VAL_PER_EPOCH 0
```
- Train the Base Model using multiple machines (`${PARTITION}`, `${JOB_NAME}`, and `${NUM_NODES}` are your Slurm partition, job name, and number of nodes):

```bash
sh scripts/slurm_train.sh ${PARTITION} ${JOB_NAME} ${NUM_NODES} \
    --cfg_file ./cfgs/image_to_csv_base_merge_all_trained.yaml \
    --VAL_PER_EPOCH 0
```
- Train the Large Model using multiple GPUs:

```bash
sh scripts/dist_train.sh 8 \
    --config ./cfgs/image_to_csv_large_merge_all_trained.yaml \
    --VAL_PER_EPOCH 0
```
- Train the Large Model using multiple machines:

```bash
sh scripts/slurm_train.sh ${PARTITION} ${JOB_NAME} ${NUM_NODES} \
    --cfg_file ./cfgs/image_to_csv_large_merge_all_trained.yaml \
    --VAL_PER_EPOCH 0
```
- Evaluate the Model using multiple GPUs, the SCRM metric, and 1280 output tokens:

```bash
sh scripts/dist_test.sh 4 \
    --config ./cfgs/image_to_csv_large_merge_all_trained.yaml \
    --criterion csv_metric \
    --num_token 1280
```
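The `csv_metric` criterion scores the predicted CSV tables against the ground truth with the SCRM metric from the paper. Purely as an illustration of this kind of table comparison, and not the SCRM implementation itself, the sketch below parses two CSV strings and reports the fraction of ground-truth cells that match exactly or, for numeric cells, within a relative tolerance; the function names and the 5% tolerance are assumptions.

```python
# Illustration only: a relaxed cell-match score between two CSV tables.
# This is NOT the SCRM metric behind `--criterion csv_metric`; it just shows
# the general idea of comparing a predicted table against the ground truth.
import csv
import io


def cell_match(pred: str, gt: str, rel_tol: float = 0.05) -> bool:
    """A cell matches if the strings are equal or the values are within rel_tol."""
    if pred.strip() == gt.strip():
        return True
    try:
        p, g = float(pred), float(gt)
    except ValueError:
        return False
    return abs(p - g) <= rel_tol * max(abs(g), 1e-8)


def table_score(pred_csv: str, gt_csv: str) -> float:
    """Fraction of ground-truth cells matched by the prediction (0.0 to 1.0)."""
    pred_rows = list(csv.reader(io.StringIO(pred_csv)))
    gt_rows = list(csv.reader(io.StringIO(gt_csv)))
    total, matched = 0, 0
    for r, gt_row in enumerate(gt_rows):
        for c, gt_cell in enumerate(gt_row):
            total += 1
            if r < len(pred_rows) and c < len(pred_rows[r]):
                matched += cell_match(pred_rows[r][c], gt_cell)
    return matched / total if total else 0.0


if __name__ == "__main__":
    gt = "Year,Sales\n2020,1000\n2021,1500\n"
    pred = "Year,Sales\n2020,1020\n2021,1500\n"
    print(f"relaxed cell match: {table_score(pred, gt):.2f}")  # 1.00 at 5% tolerance
```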