-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: unsafe_bitcast #43065
RFC: unsafe_bitcast #43065
Conversation
This is required for using these macros before `esc` is defined. An alternative fix may be to define `esc` earlier. However, avoid using `esc` here seems to be more compatible with other macros defined at this stage.
datatype_pointerfree(T) || | ||
throw(ArgumentError("output type $T may contain a boxed object")) | ||
datatype_pointerfree(S) || | ||
throw(ArgumentError("input type $S may contain a boxed object")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
datatype_pointerfree(T) || | |
throw(ArgumentError("output type $T may contain a boxed object")) | |
datatype_pointerfree(S) || | |
throw(ArgumentError("input type $S may contain a boxed object")) | |
isbitstype(T) || throw(ArgumentError("output type $T has undefined layout")) | |
isbitstype(S) || throw(ArgumentError("input type $S has undefined layout")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm... Can we support Union fields? That's one of the main motivations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC, you are prohibited from accessing the union-bits fields in most cases (TBAA violation)
How do you operate on a a tuple of bytes with an unknown format? I don't really understand how this would be used. |
The point is to use opaque bytes only for "semantic-agnostic" low-level APIs like transport and storage. I think a good example is function CUDA.shfl_recurse(op, x)
xbytes = Base.unsafe_bitcast(NTuple{sizeof(x),UInt8}, x)
ybytes = map(op, xbytes)
y = Base.unsafe_bitcast(typeof(x), ybytes)
return y
end Here, we serialize the input From We can then use, e.g., val = something(shfl_down_sync(mask, Some{T}(val), delta)) to move The use case for concurrent data structures would be similar; i.e., pointer-free values are internally serialized into and deserialized from opaque bytes which are stored into (say) |
What does |
Sorry, I should've chosen more concrete example. function shfl_down_sync(mask, x, delta)
@assert sizeof(x) == 16 # the following code is the specialization for 16-byte typeof(x)
xb1, xb2, xb3, xb4 = Base.unsafe_bitcast(NTuple{4,Int32}, x)
yb1 = shfl_down_sync(mask, xb1, delta) # calls llvm.nvvm.shfl.sync.down.i32 intrinsic
yb2 = shfl_down_sync(mask, xb2, delta)
yb3 = shfl_down_sync(mask, xb3, delta)
yb4 = shfl_down_sync(mask, xb4, delta)
ybytes = (yb1, yb2, yb3, yb4)
y = Base.unsafe_bitcast(typeof(x), ybytes)
return y
end (I changed |
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
This comment was marked as off-topic.
I hide the above comments as off-topic. This PR is all about casting immutable pointerfree values. I think it would be more fruitful to discuss arrays and buffers separately elsewhere. |
Some comments by @vtjnash (as I understood them, it had been a while since I looked at this PR):
|
Just cross-referencing this here: #41071 |
Another note to myself from the chat with Jameson and Tim: All we can do with Footnotes
|
|
It may still be valuable to have this as an unsafe operation, because function vecdot_q4_kernel_MVP(scales)
kmask1, kmask2, kmask3 = 0x3f3f3f3f, 0x0f0f0f0f, 0x03030303
scales_uint32 = reinterpret(NTuple{3, UInt32}, scales)
utmp0, utmp1, utmp2 = scales[1], scales[2], scales[3]
return
end
function main()
scales = UInt8.((1,2,3,4, 2,3,4,5, 3,4,5,6)) # NTuple{12, UInt8}
@cuda threads=1 blocks=1 vecdot_q4_kernel_MVP(scales)
end
Worth reopening over? Or can we could improve |
That seems more like a CUDA.jl issue (refusing to constant fold packedsize and padding)? |
Ah, if that's expected I will look into that. Thanks. |
This PR tries to add (hopefully) well-defined low-level type punning facility usable for arbitrary pointer-free immutable objects.
The main use case is to support converting an arbitrary pointer-free immutable object to an opaque chunk of bytes (e.g.,
NTuple{N,UInt8}
) and back. For example, this will be useful for using a rich set of immutable objects on GPU API likeCUDA.shfl_sync
etc. which are currently defined byreinterpret
ing some small subset of types to integers. If we can cast nested structs withUnion
fields, we can correctly execute various transducers with complex state transitions on GPU. Another important use case is an emulation of tearable atomics which is useful for efficient concurrent algorithms such as work-stealing dequeue and seqlock.However, casting one type to another (aka type punning) is known to be hard when the compiler wants to infer something from the type system. If you can find various one-hour technical talks (e.g., CppCon 2019: Timur Doumler “Type punning in modern C++” - YouTube) on how to do it correctly, it's a good indication that we need to have an API in
Base
with a clear definition of when and how it can be used. For example, C++20 now hasstd::bit_cast
as a similar API.I haven't had time to dig deep into this to convince myself that the API I came up with was OK. But given the recent discussion (#32660, #42968, #43035) in expanding what
reinterpret
does, I think it's worth opening it as an alternative take at it; i.e., unsafe (narrow contract) API with wider use cases but weaker guarantee (no cross-process roundtrip). So, I'd appreciate it if people who know Julia and LLVM compiler can look at it.One aspect of the API that I'm still worried about is what we can say about the returned object when the input type contains some padding. I wonder if we should rather create an "asymmetric" API
unsafe_bitembed(T, x::S) -> y::T
andunsafe_bitextract(S, y::T) -> x::S
whereS
can contain padding butT
must not. We can then clearly document thatT
is a chunk of opaque bytes that can only be usable in a meaningful way afterunsafe_bitextract
. I don't know if it helps the compiler, though.ping @vtjnash @JeffBezanson @Keno @vchuravy @maleadt