Add offline dataset generation #39


Open
wants to merge 7 commits into main

Conversation

@vkkhare vkkhare commented Jun 22, 2025

Description

Adds offline dataset generation for faster training cycles for the predictors.

  • Creates a new activation-capture class used for both training and accuracy benchmarking
  • generate_dataset.py writes Hugging Face-compatible sparsity datasets for the predictors (a rough sketch of this step follows below)

Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
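
For context, a minimal sketch of what the dataset-writing step in generate_dataset.py might look like. The helper name, column layout, and the hidden_states/mlp_activations inputs are assumptions for illustration, not the PR's actual code:

import numpy as np
from datasets import Dataset

def build_sparsity_dataset(hidden_states, mlp_activations, layer_idx):
    # hidden_states / mlp_activations: lists of 1-D numpy arrays, one per token.
    records = {
        "layer": [layer_idx] * len(hidden_states),
        "hidden_state": [h.tolist() for h in hidden_states],
        # Binary sparsity label: which MLP neurons were active for this token.
        "active_neurons": [(a > 0).astype(np.int8).tolist() for a in mlp_activations],
    }
    return Dataset.from_dict(records)

# Hypothetical usage:
# ds = build_sparsity_dataset(hidden, acts, layer_idx=0)
# ds.save_to_disk("sparsity_dataset/layer_0")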
@vkkhare vkkhare marked this pull request as draft June 22, 2025 17:59
Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
vkkhare and others added 4 commits June 22, 2025 18:41
Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
@vkkhare vkkhare marked this pull request as ready for review June 26, 2025 06:25
@vkkhare vkkhare changed the title from "[WIP] Add offline dataset generation" to "Add offline dataset generation" Jun 26, 2025
Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
for i, layer in enumerate(model.model.layers):
    # Capture hidden states before MLP
    handle = layer.register_forward_hook(
Contributor

Won't this capture the states pre-attention?

Contributor Author

Yup, that's what Deja Vu recommends for better parallelisation: with the pre-attention hidden state, the predictor can run in parallel with attention. I can move it post layer normalization too.
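
For reference, a minimal sketch of the two hook placements being discussed, assuming the Hugging Face LLaMA module layout (model.model.layers, post_attention_layernorm); not the PR's actual code:

def attach_capture_hooks(model, captured):
    # Shared hook factory: works for forward pre-hooks (no output arg)
    # and forward hooks (output arg present).
    def make_hook(key):
        def hook(module, inputs, output=None):
            tensor = inputs[0] if output is None else output
            captured[key] = tensor.detach().cpu()
        return hook

    handles = []
    for i, layer in enumerate(model.model.layers):
        # Option A: pre-attention hidden state (the layer's input), the Deja Vu-style
        # choice so the predictor can run alongside attention.
        handles.append(layer.register_forward_pre_hook(make_hook(f"layer_{i}_pre_attn")))
        # Option B: post-attention layernorm output, i.e. the tensor that actually feeds the MLP.
        handles.append(layer.post_attention_layernorm.register_forward_hook(
            make_hook(f"layer_{i}_pre_mlp")))
    return handles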


# Capture MLP gate activations (after activation function)
if hasattr(layer.mlp, 'gate_proj'):
    handle = layer.mlp.gate_proj.register_forward_hook(
Contributor

Won't this capture the states before the activation function rather than after?

Contributor Author

This returns the output of the gate_proj and up_proj linear layers; we compute the activations separately before storing them.
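
Roughly, that capture side might look like the sketch below. The hook-key names and the mlp_activations dict are assumptions, and a LLaMA-style gate_proj/up_proj layout is assumed:

def attach_mlp_hooks(model, mlp_activations):
    # Store the raw (pre-activation) outputs of gate_proj / up_proj;
    # the nonlinearity is applied later, when the sample is written out.
    def make_hook(key):
        def hook(module, inputs, output):
            mlp_activations[key] = output.detach().cpu()
        return hook

    handles = []
    for i, layer in enumerate(model.model.layers):
        handles.append(layer.mlp.gate_proj.register_forward_hook(make_hook(f"layer_{i}_gate")))
        handles.append(layer.mlp.up_proj.register_forward_hook(make_hook(f"layer_{i}_up")))
    return handles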

up_act = self.mlp_activations[up_key]

# Apply SwiGLU activation: silu(gate) * up
gated_act = F.silu(gate_act) * up_act
Contributor

We should support non-SiLU activations here; OPT uses ReLU at the very least.

Contributor Author

++ Will take the activation function as an input.
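
One possible shape for that, sketched with an activation-name argument; the ACTIVATIONS table and gated_activation helper are hypothetical, not the PR's code:

import torch.nn.functional as F

# Hypothetical: take the activation as a parameter instead of hard-coding SiLU.
ACTIVATIONS = {
    "silu": F.silu,   # LLaMA-style SwiGLU gate
    "relu": F.relu,   # OPT
    "gelu": F.gelu,
}

def gated_activation(gate_act, up_act=None, act_name="silu"):
    act_fn = ACTIVATIONS[act_name]
    # Gated MLPs (LLaMA): act(gate) * up.  Plain MLPs (OPT): just act(gate).
    return act_fn(gate_act) * up_act if up_act is not None else act_fn(gate_act)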

self.remove_hooks()

# Hook into each transformer layer
for i, layer in enumerate(model.model.layers):
Contributor

model.model.layers will cause problems with OPT due to it having an inconvenient wrapper around the decoder, but there are a few ways of fixing this - I can either refactor the OPT code or we can have some kind of model.get_layers() method that we call here.

Contributor Author

model.get_layers() looks like the more idiomatic approach.
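
A possible shape for that helper, sketched as a free function here; the attribute checks assume the Hugging Face LLaMA and OPT layouts, and the PR may expose this as a method on the model wrapper instead:

def get_layers(model):
    # LLaMA-style models expose the decoder layers directly.
    if hasattr(model.model, "layers"):
        return model.model.layers
    # OPT wraps them one level deeper: OPTForCausalLM.model.decoder.layers.
    if hasattr(model.model, "decoder"):
        return model.model.decoder.layers
    raise ValueError(f"Unsupported model type: {type(model).__name__}")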
