Conversation

pommedeterresautee
Member

fix #206
fix #205

@pommedeterresautee pommedeterresautee self-assigned this Dec 21, 2022
@pommedeterresautee pommedeterresautee added performance make things faster, always and removed feature labels Dec 21, 2022
@pommedeterresautee
Member Author

Tests pass:

================================================================================================== 2849 passed in 8941.45s (2:29:01) ===================================================================================================

transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
assert (
    transcription == "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
)
Contributor

About this: is it a good idea to use "in"? It has happened that the beginning was correct but not the end.

Member Author

It doesn't fail with beam > 1 or a larger model, but in tests we need it to be fast, so I guess this is the only way to make it pass. Do you see another way that doesn't require more computation?

Contributor

I mean using "in" instead of "=="
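A minimal sketch of the trade-off under discussion (the strings are illustrative, not the actual test fixture):

```python
reference = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"

# strict check with "==": the decoded text must match exactly, any divergence fails
exact_output = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
assert exact_output == reference

# looser check with "in": the full reference must appear somewhere inside the
# output, so extra tokens around it pass, while a truncated output still fails
# because the complete reference string is no longer contained in it
padded_output = reference + " amen"
assert reference in padded_output

truncated_output = "mister quilter is the apostle of the middle classes"
assert reference not in truncated_output
```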

@@ -31,8 +42,11 @@ def attention_wrapper(q, k, v, output, sm_scale, is_causal, attention_mask):

     # When there is a large difference between those dimensions, our kernel become inefficient
     # (almost no parallelization), so we use pytorch instead
-    if q.size(-2) == 1 and k.size(-2) > 50:
         attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
+    if q.size(-2) == 1 and k.size(-2) > 50 and (attention_mask is None) and not is_causal:
Contributor

I don't understand "(attention_mask is None) and not is_causal": that condition is only needed for attention_vec_mat_forward; for attention_reference it's OK.

Member Author

You are right, I'll modify it.

Member Author

done
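For reference, a hypothetical sketch of the dispatch logic after that fix. The kernel names come from the diff, but the routing is my reading of the thread, with strings standing in for the real kernel calls:

```python
def pick_attention_kernel(q_len: int, k_len: int, has_mask: bool, is_causal: bool) -> str:
    # decode case: a single query token against a long K/V sequence gives the
    # Triton kernel almost no parallelism, so a specialized path is preferred
    if q_len == 1 and k_len > 50:
        # the vec-mat kernel supports neither masking nor causality, so the
        # "(attention_mask is None) and not is_causal" guard belongs here only
        if not has_mask and not is_causal:
            return "attention_vec_mat_forward"
        # the PyTorch reference path handles masks and causality fine
        return "attention_reference"
    # general case: the Triton attention kernel
    return "attention_forward"
```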

@pommedeterresautee
Member Author

Tests pass after all the commits above:


================================================================================================== 2855 passed in 9623.61s (2:40:23) ===================================================================================================

Contributor

@gaetansnl gaetansnl left a comment

minor changes only

k = k.view(k.size(0), 1, k.size(-2), k.size(-1))
v = v.view(v.size(0), 1, v.size(-2), v.size(-1))
output = output.view(output.size(0), 1, output.size(-2), output.size(-1))
q.unsqueeze_(dim=1)
Contributor

You are mutating the input?

Member Author

Yes, in PyTorch a trailing underscore on a method name always means the op is done in place, i.e. the original object is mutated.
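A small illustration of that convention (the shapes are made up for the example; assumes PyTorch is installed):

```python
import torch

q = torch.randn(2, 16, 64)       # (batch, seq, dim), no head dimension yet
q.unsqueeze_(dim=1)              # trailing underscore: mutates q in place
assert q.shape == (2, 1, 16, 64)  # the caller's tensor now has a head dim

# the non-mutating spelling returns a new view and leaves the input untouched
k = torch.randn(2, 16, 64)
k4d = k.unsqueeze(dim=1)
assert k.shape == (2, 16, 64)
assert k4d.shape == (2, 1, 16, 64)
```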

else:
attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)

if extend_head:
return output.view(q.size(0), q.size(-2), q.size(-1))
output.squeeze_(dim=1)
Contributor

Same here: you are mutating the input?

Member Author

Same as above

@@ -1,5 +1,5 @@
 triton==2.0.0.dev20221202
-torch==2.0.0.dev20221214+cu117
+torch==2.0.0.dev20230104+cu117
Contributor

Dockerfile update is missing.

Member Author

updated

v_col_major = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
# mutate v, so its storage is col major
v.set_(source=v_col_major)
# print("q", q.size(), q.stride(), len(q.untyped_storage()), q.dtype)
Contributor

Leftover print statement.

Member Author

removed
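To make the layout trick in that snippet concrete, a small sketch with an illustrative tensor size (assumes PyTorch is installed):

```python
import torch

v = torch.arange(24.0).reshape(1, 1, 4, 6)
assert v.stride()[-2:] == (6, 1)              # row-major storage

# permute -> contiguous -> permute back: same logical shape and values,
# but the storage now walks the last two dims in column-major order
v_col = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
assert v_col.shape == v.shape
assert torch.equal(v_col, v)
assert v_col.stride()[-2:] == (1, 4)          # column-major storage

# set_ then swaps v's own storage, mutating the caller's tensor
v.set_(source=v_col)
assert v.stride()[-2:] == (1, 4)
```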

@@ -29,7 +29,7 @@ def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
     return cuda_graphs_wrapper(gm, example_inputs)


-def optimize_model(original_model: PreTrainedModel) -> None:
+def optimize_model(model: PreTrainedModel) -> None:
Contributor

This is a breaking change. I think we will make breaking changes anyway, but it would be nice to document it.

Member Author

@pommedeterresautee pommedeterresautee Jan 5, 2023

Where?
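A tiny sketch of why the parameter rename is breaking for keyword callers (hypothetical stubs, not the real function):

```python
def optimize_model_before(original_model):   # old signature
    return original_model

def optimize_model_after(model):             # new signature
    return model

optimize_model_before(original_model="gpt2")     # worked before the rename

try:
    optimize_model_after(original_model="gpt2")  # keyword no longer exists
    broke = False
except TypeError:
    broke = True
assert broke
assert optimize_model_after("gpt2") == "gpt2"    # positional callers are unaffected
```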

@@ -12,31 +12,77 @@
# See the License for the specific language governing permissions and
Contributor

I don't know why this file is in this folder; it should be in optimizer. Maybe a mistake in a previous commit.

Member Author

done

@pommedeterresautee pommedeterresautee merged commit 8b0ec72 into main Jan 6, 2023
@pommedeterresautee pommedeterresautee deleted the feat/recycle_tensor_cg branch January 6, 2023 10:22
Successfully merging this pull request may close these issues.

- no copy of recycled K/V cache in CUDA graphs
- Reduce memory overhead of CUDA graphs