Skip to content

Commit 01310a9

Browse files
MoelfN5N3CameronBieganek
authored
map of unequal length bitarray (#47013)
* fix `bit_map!` with unequal length. * relax three-arg `bit_map!`. * add bit array test for unequal length. Co-authored-by: N5N3 <2642243996@qq.com> Co-authored-by: Cameron Bieganek <8310743+CameronBieganek@users.noreply.github.com>
1 parent 186c0be commit 01310a9

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

base/bitarray.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,26 +1779,42 @@ end
17791779
# map across the chunks. Otherwise, fall-back to the AbstractArray method that
17801780
# iterates bit-by-bit.
17811781
function bit_map!(f::F, dest::BitArray, A::BitArray) where F
1782-
size(A) == size(dest) || throw(DimensionMismatch("sizes of dest and A must match"))
1782+
length(A) <= length(dest) || throw(DimensionMismatch("length of destination must be >= length of collection"))
17831783
isempty(A) && return dest
17841784
destc = dest.chunks
17851785
Ac = A.chunks
1786-
for i = 1:(length(Ac)-1)
1786+
len_Ac = length(Ac)
1787+
for i = 1:(len_Ac-1)
17871788
destc[i] = f(Ac[i])
17881789
end
1789-
destc[end] = f(Ac[end]) & _msk_end(A)
1790+
# the last effected UInt64's original content
1791+
dest_last = destc[len_Ac]
1792+
_msk = _msk_end(A)
1793+
# first zero out the bits mask is going to change
1794+
destc[len_Ac] = (dest_last & (~_msk))
1795+
# then update bits by `or`ing with a masked RHS
1796+
destc[len_Ac] |= f(Ac[len_Ac]) & _msk
17901797
dest
17911798
end
17921799
function bit_map!(f::F, dest::BitArray, A::BitArray, B::BitArray) where F
1793-
size(A) == size(B) == size(dest) || throw(DimensionMismatch("sizes of dest, A, and B must all match"))
1800+
min_bitlen = min(length(A), length(B))
1801+
min_bitlen <= length(dest) || throw(DimensionMismatch("length of destination must be >= length of smallest input collection"))
17941802
isempty(A) && return dest
1803+
isempty(B) && return dest
17951804
destc = dest.chunks
17961805
Ac = A.chunks
17971806
Bc = B.chunks
1798-
for i = 1:(length(Ac)-1)
1807+
len_Ac = min(length(Ac), length(Bc))
1808+
for i = 1:len_Ac-1
17991809
destc[i] = f(Ac[i], Bc[i])
18001810
end
1801-
destc[end] = f(Ac[end], Bc[end]) & _msk_end(A)
1811+
# the last effected UInt64's original content
1812+
dest_last = destc[len_Ac]
1813+
_msk = _msk_end(min_bitlen)
1814+
# first zero out the bits mask is going to change
1815+
destc[len_Ac] = (dest_last & ~(_msk))
1816+
# then update bits by `or`ing with a masked RHS
1817+
destc[len_Ac] |= f(Ac[end], Bc[end]) & _msk
18021818
dest
18031819
end
18041820

test/bitarray.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,6 +1494,51 @@ timesofar("reductions")
14941494
C17970 = map(x -> x ? false : true, A17970)
14951495
@test C17970::BitArray{1} == map(~, A17970)
14961496
end
1497+
1498+
#=
1499+
|<----------------dest----------(original_tail)->|
1500+
|<------------------b2(l)------>| extra_l |
1501+
|<------------------b3(l)------>|
1502+
|<------------------b4(l+extra_l)--------------->|
1503+
|<--------------desk_inbetween-------->| extra÷2 |
1504+
=#
1505+
@testset "Issue #47011, map! over unequal length bitarray" begin
1506+
for l = [0, 1, 63, 64, 65, 127, 128, 129, 255, 256, 257, 6399, 6400, 6401]
1507+
for extra_l = [10, 63, 64, 65, 127, 128, 129, 255, 256, 257, 6399, 6400, 6401]
1508+
1509+
dest = bitrand(l+extra_l)
1510+
b2 = bitrand(l)
1511+
original_tail = last(dest, extra_l)
1512+
for op in (!, ~)
1513+
map!(op, dest, b2)
1514+
@test first(dest, l) == map(op, b2)
1515+
# check we didn't change bits we're not suppose to
1516+
@test last(dest, extra_l) == original_tail
1517+
end
1518+
1519+
b3 = bitrand(l)
1520+
b4 = bitrand(l+extra_l)
1521+
# when dest is longer than one source but shorter than the other
1522+
dest_inbetween = bitrand(l + extra_l÷2)
1523+
original_tail_inbetween = last(dest_inbetween, extra_l÷2)
1524+
for op in (|, )
1525+
map!(op, dest, b2, b3)
1526+
@test first(dest, l) == map(op, b2, b3)
1527+
# check we didn't change bits we're not suppose to
1528+
@test last(dest, extra_l) == original_tail
1529+
1530+
map!(op, dest, b2, b4)
1531+
@test first(dest, l) == map(op, b2, b4)
1532+
# check we didn't change bits we're not suppose to
1533+
@test last(dest, extra_l) == original_tail
1534+
1535+
map!(op, dest_inbetween, b2, b4)
1536+
@test first(dest_inbetween, l) == map(op, b2, b4)
1537+
@test last(dest_inbetween, extra_l÷2) == original_tail_inbetween
1538+
end
1539+
end
1540+
end
1541+
end
14971542
end
14981543

14991544
## Filter ##

0 commit comments

Comments
 (0)