Commit bac70f7

Update

[ghstack-poisoned]

swolchok committed Sep 24, 2024
1 parent 43ad34e commit bac70f7
Showing 2 changed files with 18 additions and 9 deletions.
1 change: 1 addition & 0 deletions torchchat/generate.py
@@ -818,6 +818,7 @@ def chat(
if text_transformer_args is not None
else 2048
),
+max_seq_length
)

max_seq_length = (
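
The one-line addition passes an explicit max_seq_length cap alongside the existing 2048 fallback in chat(). A minimal sketch of the clamping pattern this implies, with hypothetical names since the surrounding call site is truncated in this view:

    # Sketch only (hypothetical helper): bound the generation length by the
    # prompt plus new tokens, the model's context window (2048 fallback),
    # and the explicit max_seq_length cap.
    def effective_seq_len(prompt_len, max_new_tokens, block_size, max_seq_length):
        window = block_size if block_size is not None else 2048
        return min(prompt_len + max_new_tokens, window, max_seq_length)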
26 changes: 17 additions & 9 deletions torchchat/model.py
@@ -127,10 +127,10 @@ def forward(
)

return self.decoder(decoder_input, input_pos=input_pos)

def setup_caches(self, batch_size, max_seq_len) -> None:
self.decoder.setup_caches(batch_size, max_seq_len)

def _encoder_feature_select(self, encoder_output) -> Tensor:
selected_image_feature = encoder_output[1][0].view(
*encoder_output[1][0].shape[2:]
Expand All @@ -154,7 +154,7 @@ def _get_decoder_input(
image_embeds = self.mm_projector(encoder_output)
if post_tokens is None:
return torch.cat((pre_img_embed, image_embeds), dim=1)

post_img_embed = self.tok_embeddings(post_tokens)
return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)

@@ -227,7 +227,7 @@ def _llava(cls):
},
fusion_class=ConcateFusion,
)

@classmethod
def get_recipe(cls, model_type):
match model_type:
@@ -338,7 +338,7 @@ def _sanity_check(
def from_params(cls, params_path):
with open(params_path, "r") as f:
loaded_params = json.loads(f.read())

if (model_type_name := loaded_params.get("model_type", None)) is None:
# The model params is in the transformer_args format
# set the model_type to TextOnly and reformat the params
@@ -460,14 +460,14 @@ def build_model(self) -> nn.Module:
modules[name] = module_class(**config_args)

return recipe.fusion_class(**modules)

def _replace_known_params(self, params):
patterns = {"QuickGELUActivation()": QuickGELUActivation()}
for key, value in params.items():
if isinstance(value, Hashable) and value in patterns:
params[key] = patterns[value]
return params

@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError("forward method is not implemented")
@@ -939,7 +939,15 @@ def __init__(self, config, path) -> None:
self.model_ = exec_lib._load_for_executorch(str(path))

self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])

+# TODO: attempt to use "get_max_seq_len" method on the model after
+# ExecuTorch bug is fixed.
+max_seq_len = 128
+# try:
+# max_seq_len = self.model_.run_method("get_max_seq_len", [])
+# except Exception as e:
+# pass
+self.text_transformer_args.max_seq_length = max_seq_len

def forward(self, x, input_pos):
# model_.forward expects inputs to be wrapped in a tuple
forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -958,6 +966,6 @@ def forward(self, x, input_pos):

def setup_caches(self, max_batch_size, max_seq_length):
pass

except:
pass
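
The new __init__ block pins max_seq_len to 128 because querying the exported model currently trips an ExecuTorch bug (per the TODO). A sketch of the intended behavior once that bug is fixed, assuming the commented-out run_method call works as written:

    # Intended follow-up (sketch): ask the exported model for its maximum
    # sequence length and fall back to the conservative default on failure.
    max_seq_len = 128
    try:
        max_seq_len = self.model_.run_method("get_max_seq_len", [])
    except Exception:
        pass  # keep the hardcoded default
    self.text_transformer_args.max_seq_length = max_seq_len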
