diff --git a/llava/train/train.py b/llava/train/train.py
index cbfcc1bb4..f418a42ad 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -654,7 +654,7 @@ def modality_lengths(self):
         length_list = []
         for sample in self.list_data_dict:
             cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
-            cur_len = cur_len if 'image' in sample else -cur_len
+            cur_len = cur_len if 'images' in sample else -cur_len
             length_list.append(cur_len)
         return length_list
 
@@ -700,11 +700,11 @@ def expand2square(pil_img, background_color):
 
         # image exist in the data
         if 'image' in self.list_data_dict[i]:
-            data_dict['image'] = image
+            data_dict['images'] = image
         elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
             crop_size = self.data_args.image_processor.crop_size
-            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+            data_dict['images'] = torch.zeros(3, crop_size['height'], crop_size['width'])
         return data_dict
 
 
@@ -732,8 +732,8 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
             attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
         )
 
-        if 'image' in instances[0]:
-            images = [instance['image'] for instance in instances]
+        if 'images' in instances[0]:
+            images = [instance['images'] for instance in instances]
             if all(x is not None and x.shape == images[0].shape for x in images):
                 batch['images'] = torch.stack(images)
             else:
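
Below is a minimal, standalone sketch (not part of the patch) of the collation path the last hunk touches, assuming per-sample tensors of shape (3, 336, 336); the `instances` list and the `batch_images` variable are illustrative stand-ins for the collator's real inputs and `batch['images']`.

# Illustrative sketch only: with the renamed per-sample key 'images',
# same-shape tensors are stacked into one batch tensor, mirroring the
# behavior of the collator hunk above.
import torch

instances = [
    {'images': torch.zeros(3, 336, 336)},  # preprocessed image from __getitem__
    {'images': torch.zeros(3, 336, 336)},  # text-only sample padded with zeros by __getitem__
]

if 'images' in instances[0]:
    images = [instance['images'] for instance in instances]
    if all(x is not None and x.shape == images[0].shape for x in images):
        batch_images = torch.stack(images)  # shape: (2, 3, 336, 336)
    else:
        batch_images = images               # ragged shapes are kept as a plain list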