Skip to content

Commit

Permalink
Merge pull request #52 from FluxML/gpumesh
Browse files Browse the repository at this point in the history
Adapt to  CUDA.jl v3
  • Loading branch information
nirmal-suthar authored Sep 2, 2021
2 parents 9e4d734 + 618677f commit dfe43bb
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
<a href="https://github.com/FluxML/Flux3D.jl/actions" alt="Build Status">
<img src="https://github.com/FluxML/Flux3D.jl/workflows/CI/badge.svg"/>
</a>
<a href="https://buildkite.com/julialang/flux3d-dot-jl" alt="BuildKite Build Status">
<img src="https://badge.buildkite.com/40bca770b8b1183fa75cb172d706bc71d5cb5ed960cdcb6d2a.svg"/>
</a>
<a href="https://codecov.io/gh/FluxML/Flux3D.jl" alt="Codecov">
<img src="https://codecov.io/gh/FluxML/Flux3D.jl/branch/master/graph/badge.svg?token=8kpPqDfChf"/>
</a>
Expand Down
14 changes: 7 additions & 7 deletions src/rep/mesh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export TriMesh,

import GeometryBasics, Printf, MeshIO
import GeometryBasics:
Point3f0, GLTriangleFace, NgonFace, convert_simplex, meta, triangle_mesh, value
Point3f, GLTriangleFace, NgonFace, convert_simplex, meta, triangle_mesh, value

import Zygote: @ignore

Expand Down Expand Up @@ -74,8 +74,8 @@ mutable struct TriMesh{T<:AbstractFloat,R<:Integer,S} <: AbstractObject
equalised::Bool
valid::BitArray{1}
offset::Int8
_verts_len::S
_faces_len::S
_verts_len::Vector{Int64}
_faces_len::Vector{Int64}

_verts_packed::S
_verts_padded::S
Expand Down Expand Up @@ -131,8 +131,8 @@ function TriMesh(
verts = [T.(v) for v in verts]
faces = [R.(f) for f in faces]

_verts_len = S(size.(verts, 2))
_faces_len = S(size.(faces, 2))
_verts_len = size.(verts, 2)
_faces_len = size.(faces, 2)

N = length(verts)
V = maximum(_verts_len)
Expand All @@ -143,7 +143,7 @@ function TriMesh(

_verts_list = verts::Vector{<:S{T,2}}
_faces_list = faces::Vector{Array{R,2}}

return TriMesh{T,R,S}(
N,
V,
Expand Down Expand Up @@ -242,7 +242,7 @@ Initialize GeometryBasics.Mesh from triangle mesh in TriMesh `m` at `index`.
See also: [`gbmeshes`](@ref)
"""
function GBMesh(verts::AbstractArray{T,2}, faces::AbstractArray{R,2}) where {T,R}
points = Point3f0[GeometryBasics.Point{3,Float32}(verts[:, i]) for i = 1:size(verts, 2)]
points = Point3f[GeometryBasics.Point{3,Float32}(verts[:, i]) for i = 1:size(verts, 2)]
verts_dim = size(faces, 1)
poly_face = NgonFace{verts_dim,UInt32}[
NgonFace{verts_dim,UInt32}(faces[:, i]) for i = 1:size(faces, 2)
Expand Down
7 changes: 4 additions & 3 deletions src/transforms/mesh_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ function sample_points(
samples = Zygote.bufferfrom(samples)

for (i, _len) in enumerate(m._faces_len)
probvec = faces_areas_prob[1, 1:_len, i]
# Distributions package expects Array!
probvec = cpu(faces_areas_prob[1, 1:_len, i])
dist = Distributions.Categorical(probvec)
sample_faces_idx = @ignore rand(dist, num_samples)
sample_faces = faces_padded[:, sample_faces_idx, i]
Expand Down Expand Up @@ -95,9 +96,9 @@ julia> m = load_trimesh("teapot.obj")
julia> normalize!(m)
```
"""
function normalize!(m::TriMesh)
function normalize!(m::TriMesh{T,R,S}) where {T,R,S}
verts_padded = get_verts_padded(m)
_len = reshape(m._verts_len, 1, 1, :)
_len = S(reshape(m._verts_len, 1, 1, :)) # move to `S` storage type
_centroid = sum(verts_padded; dims = 2) ./ _len
_correction = ((_centroid .^ 2) .* (m.V .- _len))
_std =
Expand Down
2 changes: 1 addition & 1 deletion test/cuda/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
end
end

verts = get_verts_packed(m)
verts = get_verts_packed(m) |> cpu
L = L * transpose(verts)
L = Flux3D._norm(L; dims = 2)
@test isapprox(mean(L), laplacian_loss(m))
Expand Down
4 changes: 2 additions & 2 deletions test/cuda/rep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
@info "Testing VoxelGrid..."
@testset "VoxelGrid" begin
res = 4
voxels = rand(Float32, res, res, res, 2)
voxels = rand(Float32, res, res, res, 2) |> gpu
v = VoxelGrid(voxels) |> gpu
@test Flux3D._assert_voxel(v)
@test v isa VoxelGrid{Float32}
Expand All @@ -35,7 +35,7 @@ end
@test v.voxels isa CuArray{Float32,4}
@test size(v.voxels) == (res, res, res, 2)

voxels = rand(Float32, res, res, res)
voxels = rand(Float32, res, res, res) |> gpu
v2 = VoxelGrid(voxels) |> gpu
@test v2 isa VoxelGrid{Float32}
@test v2.voxels isa CuArray{Float32,4}
Expand Down

0 comments on commit dfe43bb

Please sign in to comment.