Copy-in embeddings in reduced precision and handle precision conversion during inference #73
base: main
@@ -2,7 +2,9 @@

import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public class TransformerComputeKernels {

@@ -19,6 +21,18 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) {
        }
    }

    public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, FloatArray wrapX) {
        int i = context.globalIdx;
        wrapX.set(i, x.get(i).getFloat32());

-        wrapX.set(i, x.get(i).getFloat32());
+        if (i < x.getSize() && i < wrapX.getSize()) {
+            wrapX.set(i, x.get(i).getFloat32());
+        }
Copilot
AI
Nov 26, 2025
Extra whitespace after parameter. There are two spaces between the comma and FloatArray - should be one space.
-    public static void convertFP32toFP16(KernelContext context,  FloatArray wrapX, HalfFloatArray x) {
+    public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
Copilot
AI
Nov 26, 2025
Missing bounds check in kernel. The kernel should validate that context.globalIdx is within the valid range of both arrays before accessing them to prevent out-of-bounds access. Add a check like if (i < wrapX.getSize() && i < x.getSize()).
-        wrapX.set(i, x.get(i).getFloat32());
-    }
-    public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
-        int i = context.globalIdx;
-        float valInput = wrapX.get(i);
-        HalfFloat val = new HalfFloat(valInput);
-        x.set(i,val);
+        if (i < wrapX.getSize() && i < x.getSize()) {
+            wrapX.set(i, x.get(i).getFloat32());
+        }
+    }
+    public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
+        int i = context.globalIdx;
+        if (i < wrapX.getSize() && i < x.getSize()) {
+            float valInput = wrapX.get(i);
+            HalfFloat val = new HalfFloat(valInput);
+            x.set(i, val);
+        }
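To illustrate why the guard matters, here is a minimal plain-Java sketch that does not use TornadoVM. It assumes the launch grid is padded up to a multiple of the work-group size, so some thread indices fall past the end of the arrays. The class name GuardedKernelSketch, its methods, and the use of float[] in place of FloatArray and HalfFloatArray are all hypothetical stand-ins, not part of this PR.

```java
import java.util.Arrays;

// Hypothetical, CPU-only illustration; not TornadoVM code.
class GuardedKernelSketch {

    // Stand-in for a kernel body: invoked once per simulated thread index.
    // The guard skips indices that fall outside either array.
    static void convert(int globalIdx, float[] src, float[] dst) {
        if (globalIdx < src.length && globalIdx < dst.length) {
            dst[globalIdx] = src[globalIdx];
        }
    }

    // Rounds n up to the next multiple of localSize (assumed grid padding).
    static int paddedGlobalSize(int n, int localSize) {
        return ((n + localSize - 1) / localSize) * localSize;
    }

    // Simulates a launch over the padded grid with a sequential loop.
    static float[] launch(float[] src, int localSize) {
        float[] dst = new float[src.length];
        int globalSize = paddedGlobalSize(src.length, localSize);
        for (int idx = 0; idx < globalSize; idx++) {
            convert(idx, src, dst);
        }
        return dst;
    }

    public static void main(String[] args) {
        float[] src = {1f, 2f, 3f, 4f, 5f};
        // 5 elements with a local size of 4 gives 8 simulated threads;
        // indices 5..7 are skipped by the guard.
        System.out.println(Arrays.toString(launch(src, 4)));
    }
}
```

Without the guard, indices 5 to 7 in this sketch would throw ArrayIndexOutOfBoundsException. On a device, the equivalent unguarded access may read or write memory outside the buffer without raising an error.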
Copilot
AI
Nov 26, 2025
Missing space after comma. Add a space after the comma for consistency with code style.
-        x.set(i,val);
+        x.set(i, val);
Duplicate field initialization. The field fields.embeddingX is initialized twice with the same value: once at line 70 and again at line 80. The second initialization should be removed.