Skip to content

Commit

Permalink
Allow rotation matrix to be batched in PointCloud, TriMesh
Browse files Browse the repository at this point in the history
Fixes #32
  • Loading branch information
nirmal-suthar committed Sep 12, 2020
1 parent fed0da4 commit 7b07e68
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 10 deletions.
16 changes: 13 additions & 3 deletions src/transforms/mesh_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,13 @@ scale(m::TriMesh, factor) = scale(m, Float32.(factor))

"""
rotate!(m::TriMesh, rotmat::AbstractArray{<:Number,2})
rotate!(m::TriMesh, rotmat::AbstractArray{<:Number,3})
Rotate the TriMesh `m` by rotation matrix `rotmat`
and overwrite `m` with rotated TriMesh.
Rotation matrix `rotmat` should be of size `(3,3)`
Rotation matrix `rotmat` should be of size `(3,3)` or `(3,3,B)`
where B is the batch size of TriMesh.
See also: [`rotate`](@ref)
Expand All @@ -222,7 +224,15 @@ function rotate!(m::TriMesh, rotmat::AbstractArray{Float32,2})
return m
end

rotate!(m::TriMesh, rotmat::AbstractArray{<:Number,2}) = rotate!(m, Float32.(rotmat))
function rotate!(m::TriMesh, rotmat::AbstractArray{Float32,3})
size(rotmat) == (3, 3, m.N) ||
error("rotmat must be (3, 3, $(m.N)) array, but instead got $(size(rotmat)) array")
verts_padded = Flux.batched_mul(Flux.batched_transpose(rotmat),get_verts_padded(m))
m._verts_padded = verts_padded
return m
end

rotate!(m::TriMesh, rotmat::AbstractArray{<:Number}) = rotate!(m, Float32.(rotmat))

"""
rotate(m::TriMesh, rotmat::AbstractArray{<:Number,2})
Expand All @@ -240,7 +250,7 @@ julia> rotmat = rand(3,3)
julia> m = rotate(m, rotmat)
```
"""
function rotate(m::TriMesh, rotmat::AbstractArray{<:Number,2})
function rotate(m::TriMesh, rotmat::AbstractArray{<:Number})
m = deepcopy(m)
rotate!(m, rotmat)
return m
Expand Down
21 changes: 14 additions & 7 deletions src/transforms/pcloud_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ scale(pcloud::PointCloud, factor::Number) = scale(pcloud, Float32(factor))
Rotate the PointCloud `pcloud` by rotation matrix `rotmat`
and overwrite `pcloud` with rotated PointCloud.
Rotation matrix `rotmat` should be of size `(3,3)`
Rotation matrix `rotmat` should be of size `(3,3)` or `(3,3,B)`
where B is the batch size of PointCloud.
See also: [`rotate`](@ref)
Expand All @@ -114,11 +115,20 @@ function rotate!(pcloud::PointCloud, rotmat::AbstractArray{Float32,2})
return pcloud
end

rotate!(pcloud::PointCloud, rotmat::AbstractArray{<:Number,2}) =
function rotate!(pcloud::PointCloud, rotmat::AbstractArray{Float32,3})
_B = size(pcloud.points, 3)
size(rotmat) == (3, 3, _B) ||
error("rotmat must be (3, 3, $(_B)) array, but instead got $(size(rotmat)) array")
size(pcloud.points, 1) == 3 || error("dimension of points in PointCloud must be 3")
pcloud.points = Flux.batched_mul(Flux.batched_transpose(rotmat),pcloud.points)
return pcloud
end

rotate!(pcloud::PointCloud, rotmat::AbstractArray{<:Number}) =
rotate!(pcloud, Float32.(rotmat))

"""
rotate(pcloud::PointCloud, rotmat::Array{Number,2})
rotate(pcloud::PointCloud, rotmat::Array{<:Number})
Rotate the PointCloud `pcloud` by rotation matrix `rotmat`.
Expand All @@ -133,15 +143,12 @@ julia> rotmat = rand(3,3)
julia> p = rotate(p, rotmat)
```
"""
function rotate(pcloud::PointCloud, rotmat::AbstractArray{Float32,2})
function rotate(pcloud::PointCloud, rotmat::AbstractArray{<:Number})
p = deepcopy(pcloud)
rotate!(p, rotmat)
return p
end

rotate(pcloud::PointCloud, rotmat::AbstractArray{Number,2}) =
rotate(pcloud, Float32.(rotmat))

"""
realign!(src::PointCloud, tgt::PointCloud)
realign!(src::PointCloud, tgt_min::AbstractArray{<:Number,2}, tgt_max::AbstractArray{<:Number,2})
Expand Down
11 changes: 11 additions & 0 deletions test/transforms/mesh_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,29 @@
m = deepcopy(_mesh)
rotmat = 2 .* one(rand(Float32, 3, 3))
rotmat_inv = inv(rotmat)
rotmat_b = cat(rotmat, rotmat, dims=3)
rotmat_inv_b = cat(rotmat_inv, rotmat_inv, dims=3)
m2 = FUNC(FUNC(m, rotmat), rotmat_inv)
m3 = FUNC(FUNC(m, rotmat_b), rotmat_inv_b)
if inplace
@test m2 === m
@test m3 === m
else
@test m2 !== m
@test m3 !== m
end
@test all(isapprox.(
get_verts_packed(_mesh),
get_verts_packed(m2),
rtol = 1e-5,
atol = 1e-5,
))
@test all(isapprox.(
get_verts_packed(_mesh),
get_verts_packed(m3),
rtol = 1e-5,
atol = 1e-5,
))
end
end

Expand Down
6 changes: 6 additions & 0 deletions test/transforms/pcloud_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,20 @@
p = rand(Float32, 3, 8, 2)
rotmat = 2 .* one(rand(Float32, 3, 3))
rotmat_inv = inv(rotmat)
rotmat_b = cat(rotmat, rotmat, dims=3)
rotmat_inv_b = cat(rotmat_inv, rotmat_inv, dims=3)
pc1 = PointCloud(p)
pc2 = FUNC(FUNC(pc1, rotmat), rotmat_inv)
pc3 = FUNC(FUNC(pc1, rotmat_b), rotmat_inv_b)
if inplace
@test pc1 == pc2
@test pc1 == pc3
else
@test pc1 != pc2
@test pc1 != pc3
end
@test all(isapprox.(p, pc2.points, rtol = 1e-5, atol = 1e-5))
@test all(isapprox.(p, pc3.points, rtol = 1e-5, atol = 1e-5))
end
end

Expand Down

0 comments on commit 7b07e68

Please sign in to comment.