feat: reduce cg memory footprint #225
Conversation
Tests pass.
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
assert (
    transcription == "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
)
About this: is it a good idea to use "in"? It has happened that the beginning was correct but not the end.
It doesn't fail with beam > 1 or a larger model, but in the test we need it to be fast, so I guess this is the only way to make it pass. Do you see another way that doesn't require more computation?
I mean using "in" instead of "=="
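For illustration, the difference between the strict and relaxed checks being discussed can be sketched like this (the strings below are hypothetical examples, not actual model output):

```python
# Hypothetical example of the two assertion styles under discussion.
# With greedy decoding and a small model, the tail of the transcription
# can diverge while the beginning stays correct.
expected = "mister quilter is the apostle of the middle classes"
decoded = "mister quilter is the apostle of the middle classes and we are glad"

# Relaxed check: passes as long as the expected prefix appears somewhere.
assert expected in decoded
# Strict equality would fail here because the tail differs.
assert expected != decoded
```

The trade-off raised above: `in` tolerates a divergent tail, which keeps the fast test green, but it can also hide a genuinely wrong ending.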
src/kernl/optimizer/attention.py
Outdated
@@ -31,8 +42,11 @@ def attention_wrapper(q, k, v, output, sm_scale, is_causal, attention_mask):

     # When there is a large difference between those dimensions, our kernel become inefficient
     # (almost no parallelization), so we use pytorch instead
-    if q.size(-2) == 1 and k.size(-2) > 50:
+    if q.size(-2) == 1 and k.size(-2) > 50 and (attention_mask is None) and not is_causal:
         attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
I don't understand "(attention_mask is None) and not is_causal" here: that condition is only needed for attention_vec_mat_forward; for attention_reference it's fine without it.
You are right, I'll modify it.
done
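The dispatch agreed on in this thread can be sketched as follows. This is a pure-Python illustration: the `T` shape class and the stub kernel bodies are stand-ins for tracing the routing logic, not kernl's actual implementation.

```python
class T:
    """Minimal tensor stand-in exposing only .size(dim) for the sketch."""
    def __init__(self, *shape):
        self.shape = shape

    def size(self, dim):
        return self.shape[dim]


calls = []  # records which backend each call was routed to


def attention_vec_mat_forward(q, k, v, output, sm_scale):
    calls.append("vec_mat")


def attention_reference(q, k, v, output, sm_scale, is_causal, attention_mask):
    calls.append("reference")


def attention_forward(q, k, v, output, sm_scale, is_causal, attention_mask):
    calls.append("triton")


def attention_wrapper(q, k, v, output, sm_scale, is_causal, attention_mask):
    # A single query row against a long key sequence parallelizes poorly
    # in the Triton kernel, so route it to a specialized path instead.
    if q.size(-2) == 1 and k.size(-2) > 50:
        if attention_mask is None and not is_causal:
            # The vec-mat path handles neither masks nor causality,
            # so only the unmasked, non-causal case goes through it.
            attention_vec_mat_forward(q, k, v, output, sm_scale)
        else:
            # The PyTorch reference path is fine with masks/causality.
            attention_reference(q, k, v, output, sm_scale, is_causal, attention_mask)
    else:
        attention_forward(q, k, v, output, sm_scale, is_causal, attention_mask)
```

With this shape, the extra mask/causality condition gates only the specialized vec-mat kernel, matching the reviewer's point above.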
Tests pass after all the commits above.
minor changes only
src/kernl/optimizer/attention.py
Outdated
k = k.view(k.size(0), 1, k.size(-2), k.size(-1))
v = v.view(v.size(0), 1, v.size(-2), v.size(-1))
output = output.view(output.size(0), 1, output.size(-2), output.size(-1))
q.unsqueeze_(dim=1)
You are mutating the input?
Yes. In PyTorch, a trailing underscore on a method always means the op is done in place, i.e. the original tensor is mutated.
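A minimal illustration of that convention, using arbitrary example shapes:

```python
import torch

# In PyTorch, a trailing underscore marks the in-place variant:
# unsqueeze returns a new view, unsqueeze_ mutates the tensor itself.
q = torch.zeros(2, 16, 64)
view = q.unsqueeze(dim=1)  # out of place: q keeps shape (2, 16, 64)
q.unsqueeze_(dim=1)        # in place: q itself is now (2, 1, 16, 64)

# squeeze_ is the in-place inverse: it drops the size-1 dim from q.
q.squeeze_(dim=1)          # q is back to (2, 16, 64)
```

Mutating in place avoids allocating a fresh tensor, which matters for the memory-footprint goal of this PR, but it means callers must not rely on the argument keeping its original shape.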
src/kernl/optimizer/attention.py
Outdated
else:
    attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)

if extend_head:
    return output.view(q.size(0), q.size(-2), q.size(-1))
output.squeeze_(dim=1)
Idem, you are mutating the input?
Same as above
@@ -1,5 +1,5 @@
 triton==2.0.0.dev20221202
-torch==2.0.0.dev20221214+cu117
+torch==2.0.0.dev20230104+cu117
The Dockerfile update is missing.
updated
v_col_major = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
# mutate v, so its storage is col major
v.set_(source=v_col_major)
# print("q", q.size(), q.stride(), len(q.untyped_storage()), q.dtype)
removed
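A standalone sketch of the col-major trick from the diff above (the shape here is an arbitrary example): permute, then `contiguous()`, then permute back yields the same logical shape, but with the last two dims stored column major; `set_` then swaps `v`'s storage in place.

```python
import torch

v = torch.randn(1, 2, 8, 4)

# permute -> contiguous -> permute back: identical logical shape,
# but storage is now column major over the last two dimensions.
v_col_major = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)

# Swap v's storage/strides in place so downstream code sees the new layout.
v.set_(source=v_col_major)

# v keeps shape (1, 2, 8, 4), but stride(-2) == 1: columns are contiguous.
```

Because the swap is in place, any code holding a reference to `v` transparently reads the new layout without re-plumbing the call sites.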
@@ -29,7 +29,7 @@ def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
     return cuda_graphs_wrapper(gm, example_inputs)


-def optimize_model(original_model: PreTrainedModel) -> None:
+def optimize_model(model: PreTrainedModel) -> None:
This is a breaking change. I think we will make breaking changes anyway, but it would be nice to document it.
Where?
@@ -12,31 +12,77 @@
 # See the License for the specific language governing permissions and
I don't know why this file is in this folder; it should be in optimizer. Maybe a mistake in previous commits.
done
Fixes #206
Fixes #205