Skip to content

Commit

Permalink
fix inputs & multimodal model bug (#809)
Browse files: browse the repository at this point in the history
  • Loading branch information
Jintao-Huang authored Apr 26, 2024
1 parent 4818e3d commit 52ee111
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions swift/llm/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def _get_dataset(*args, **kwargs):

# only use train_dataset
dataset = get_dataset(data)[0]
logger.info(f'quant_dataset: {dataset}')
dataset = dataset.shuffle()

samples = []
Expand Down
16 changes: 16 additions & 0 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,8 @@ def encode(
self, example: Dict[str,
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
inputs, tokenizer_kwargs = super().encode(example)
if len(inputs) == 0:
return inputs, tokenizer_kwargs
inputs.pop('loss_scale', None)
inputs.update(tokenizer_kwargs)
return inputs, tokenizer_kwargs
Expand Down Expand Up @@ -711,6 +713,8 @@ def encode(
self, example: Dict[str,
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
inputs, _ = super().encode(example)
if len(inputs) == 0:
return inputs, {}
inputs.pop('loss_scale', None)
from llava.mm_utils import expand2square
model = self.model.model
Expand Down Expand Up @@ -909,6 +913,8 @@ def encode(
image = self.model.vis_processor(image)
images.append(image.to(dtype))
inputs, _ = super().encode(example)
if len(inputs) == 0:
return inputs, {}
inputs.pop('loss_scale', None)
input_ids = inputs['input_ids']
labels = inputs['labels']
Expand Down Expand Up @@ -1055,6 +1061,8 @@ def encode(
self, example: Dict[str,
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
inputs, _ = super().encode(example)
if len(inputs) == 0:
return inputs, {}
images_path = example['images']
images = []
for image_path in images_path:
Expand Down Expand Up @@ -1152,6 +1160,8 @@ def encode(
example['query'], history, '<image_placeholder>')

inputs, _ = super().encode(example)
if len(inputs) == 0:
return inputs, {}
images = []
for image_path in images_path:
image = _read_from_path(image_path)
Expand Down Expand Up @@ -1256,6 +1266,8 @@ def encode(
assert len(images_path) == 1
image = _read_from_path(images_path[0])
inputs, _ = super().encode(example)
if len(inputs) == 0:
return inputs, {}
inputs.pop('loss_scale', None)
model = self.model
inputs2 = model.build_conversation_input_ids(
Expand Down Expand Up @@ -1338,6 +1350,8 @@ def encode(
assert len(images_path) == 1
image = _read_from_path(images_path[0])
inputs, _ = super().encode(example)
if len(inputs) == 0:
return inputs, {}
input_ids = inputs['input_ids']
labels = inputs['labels']

Expand Down Expand Up @@ -1510,6 +1524,8 @@ def encode(
image = image.resize((max_edge, max_edge))
images.append(image)
inputs, _ = super().encode(example)
if len(inputs) == 0:
return inputs, {}
input_ids = inputs['input_ids']
labels = inputs['labels']
images = process_images(images, image_processor)
Expand Down

0 comments on commit 52ee111

Please sign in to comment.