feat(causal message passing) #2
Conversation
vahanhov left a comment:
Please check the comments before merging
if isinstance(succ_node, SequenceElement):
    sequence_end_idx = succ_node.end_idx
elif isinstance(edge, SequenceElement):
    sequence_end_idx = edge.end_idx
else:
    sequence_end_idx = pred_node.end_idx
pred_node, edge, succ_node = edge_sequence[0]
if isinstance(succ_node, SequenceElement):
    sequence_start_idx = succ_node.end_idx
elif isinstance(edge, SequenceElement):
    sequence_start_idx = edge.end_idx
else:
    sequence_start_idx = pred_node.end_idx
Suggested change:
-if isinstance(succ_node, SequenceElement):
-    sequence_end_idx = succ_node.end_idx
-elif isinstance(edge, SequenceElement):
-    sequence_end_idx = edge.end_idx
-else:
-    sequence_end_idx = pred_node.end_idx
-pred_node, edge, succ_node = edge_sequence[0]
-if isinstance(succ_node, SequenceElement):
-    sequence_start_idx = succ_node.end_idx
-elif isinstance(edge, SequenceElement):
-    sequence_start_idx = edge.end_idx
-else:
-    sequence_start_idx = pred_node.end_idx
+if succ_node is not None:
+    sequence_end_idx = succ_node.end_idx
+elif edge is not None:
+    sequence_end_idx = edge.end_idx
+else:
+    sequence_end_idx = pred_node.end_idx
+pred_node, edge, succ_node = edge_sequence[0]
+if succ_node is not None:
+    sequence_start_idx = succ_node.end_idx
+elif edge is not None:
+    sequence_start_idx = edge.end_idx
+else:
+    sequence_start_idx = pred_node.end_idx
If this is correct, I find it more intuitive. Otherwise I had to go look into the code to answer the question of why these checks are mutually exclusive.
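For what it's worth, the suggested priority order could also be factored into a small helper. The sketch below is purely illustrative: `_last_defined_end_idx` is a hypothetical name, and it assumes the slots of the `(pred_node, edge, succ_node)` tuple that are not sequence elements are simply None, which is what makes the checks mutually exclusive.

def _last_defined_end_idx(pred_node, edge, succ_node) -> int:
    # Prefer the successor node's end index, then the edge's, then fall back
    # to the predecessor node's -- the same priority as the if/elif/else chain.
    for element in (succ_node, edge, pred_node):
        if element is not None:
            return element.end_idx
    raise ValueError("expected at least pred_node to be set")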
        assert new_t_embeddings.shape == t_embeddings.shape
        new_token_embeddings.append(new_t_embeddings.unsqueeze(0))
    return torch.cat(new_token_embeddings, dim=0) + token_embeddings

class CausalMessagePassingLayer(torch.nn.Module):
Eventually we should add a link to the paper here, because what's happening here is not at all trivial.
sounds good
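Until the paper link lands, here is a rough, non-authoritative sketch of the general idea behind a causal message-passing layer with a residual add. The class name, the adjacency argument, and the single linear projection are all assumptions, not this PR's actual code; only the final `+ token_embeddings` residual mirrors the diff above.

import torch

class CausalMessagePassingLayerSketch(torch.nn.Module):
    # Hypothetical sketch: aggregate messages from earlier positions permitted
    # by a causal (strictly lower-triangular) adjacency and add them residually.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.message_proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, token_embeddings: torch.Tensor, adjacency: torch.Tensor) -> torch.Tensor:
        # token_embeddings: (batch, seq_len, hidden); adjacency: (batch, seq_len, seq_len)
        # Keep only edges pointing backwards in the sequence so no token sees the future.
        causal = torch.tril(adjacency, diagonal=-1)
        messages = causal @ self.message_proj(token_embeddings)
        return token_embeddings + messages  # residual update, as in the diff above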
    **kwargs,
) -> dict:
    # only last token for input_ids if past is not None
    truncated_input_ids = input_ids
Suggested change:
-    truncated_input_ids = input_ids
    # past_key_values = None
    if None:
        truncated_input_ids = input_ids[:, -1].unsqueeze(-1)

    # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
    if past_key_values[0][0].shape[0] == input_ids.shape[0]:
        past_key_values = self._convert_to_bloom_cache(past_key_values)

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
Suggested change:
-    # past_key_values = None
-    if None:
-        truncated_input_ids = input_ids[:, -1].unsqueeze(-1)
-    # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
-    if past_key_values[0][0].shape[0] == input_ids.shape[0]:
-        past_key_values = self._convert_to_bloom_cache(past_key_values)
-    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    # if inputs_embeds is not None and past_key_values is None:
    #     model_inputs = {"inputs_embeds": inputs_embeds}
    # else:
    model_inputs = {"input_ids": truncated_input_ids}

    model_inputs.update(
        {
Suggested change:
-    # if inputs_embeds is not None and past_key_values is None:
-    #     model_inputs = {"inputs_embeds": inputs_embeds}
-    # else:
-    model_inputs = {"input_ids": truncated_input_ids}
-    model_inputs.update(
-        {
+    return {
| # "full_input_ids": input_ids, | ||
| } | ||
| ) | ||
| return model_inputs |
| # "full_input_ids": input_ids, | |
| } | |
| ) | |
| return model_inputs | |
| } |
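Taken together, the suggestions above would collapse `prepare_inputs_for_generation` to roughly the sketch below. Only the `input_ids` handling and the commented-out `full_input_ids` line are visible in this diff; the argument list and the remaining dict keys are assumptions for illustration.

def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    **kwargs,
) -> dict:
    # Build the generation inputs directly instead of going through an
    # intermediate `model_inputs` dict that is updated and then returned.
    return {
        "input_ids": input_ids,
        # "full_input_ids": input_ids,
        "past_key_values": past_key_values,    # assumed key, not shown in the diff
        "attention_mask": attention_mask,      # assumed key, not shown in the diff
        "use_cache": kwargs.get("use_cache"),  # assumed key, not shown in the diff
    }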
    return new_positions.unsqueeze(0)


def _get_all_edge_previous_positions(
I didn't understand what's happening with these positions. Are you actually using this one?
Yes, I just debugged it yesterday: it sets all previous edges in a sequence to be positioned directly before the current edge.
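A toy illustration of that remapping, with a hypothetical helper standing in for `_get_all_edge_previous_positions` (whose real signature and inputs are not shown in this diff): every earlier edge in the sequence is reassigned a position immediately preceding the current edge.

import torch

def _previous_positions_sketch(num_previous_edges: int, current_edge_position: int) -> torch.Tensor:
    # Place the `num_previous_edges` earlier edges at the positions directly
    # before `current_edge_position`, i.e. current-3, current-2, current-1 in order.
    start = current_edge_position - num_previous_edges
    return torch.arange(start, current_edge_position)

# e.g. with 3 previous edges and the current edge at position 10:
# _previous_positions_sketch(3, 10) -> tensor([7, 8, 9])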
* novelty debugging
* running solution
* message passing slightly better
* simplified serialize
* current code
* flamingo inspired
* message passing correctly implemented
* positions update
* removing commented code
* causal message passing
* edge case in case using another model besides serialize
* update message passing and position embedding
* Update src/transformers/models/bloom/modeling_bloom.py
* removed unnecessary code
causal message passing