Skip to content

Commit

Permalink
kdlite for mpas
Browse files Browse the repository at this point in the history
  • Loading branch information
guo.2154 committed Jan 28, 2024
1 parent 74f8d56 commit 6650eeb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 14 deletions.
74 changes: 68 additions & 6 deletions include/ftk/basic/kd_lite.hh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace ftk {
// implementation of https://www.nvidia.com/content/gtc-2010/pdfs/2140_gtc2010.pdf
template <int nd, typename I=int, typename F=double>
__host__
void kd_build_recursive(
void kdlite_build_recursive(
const I n,
const I current,
const F *X, // coordinates
Expand Down Expand Up @@ -47,15 +47,15 @@ void kd_build_recursive(
// fprintf(stderr, "current=%d, offset=%d, length=%d, lbm=%d, median=%d\n", current, offset, length, lbm, heap[current]);

if (lbm - 1 >= 1)
kd_build_recursive<nd, I, F>(n, current*2+1, X, level+1, offset, lbm-1, heap, ids); // left
kdlite_build_recursive<nd, I, F>(n, current*2+1, X, level+1, offset, lbm-1, heap, ids); // left
if (length - lbm >= 1)
kd_build_recursive<nd, I, F>(n, current*2+2, X, level+1, offset+lbm, length-lbm, heap, ids); // right
kdlite_build_recursive<nd, I, F>(n, current*2+2, X, level+1, offset+lbm, length-lbm, heap, ids); // right
}
}

template <int nd, typename I=int, typename F=double>
__host__
void kd_build(
void kdlite_build(
const I n, // number of points
const F *X, // coordinates
I *heap) // out: pre-allocated heap
Expand All @@ -65,15 +65,77 @@ void kd_build(
for (int i = 0; i < n; i ++)
ids[i] = i;

kd_build_recursive<nd, I, F>(n, 0, X, 0, 0, n, heap, ids.data());
kdlite_build_recursive<nd, I, F>(n, 0, X, 0, 0, n, heap, ids.data());

// for (int i = 0; i < n; i ++)
// fprintf(stderr, "i=%d, heap=%d\n", i, heap[i]);
}

// Nearest-neighbor query on a kd-tree stored as an implicit binary heap.
//
// Layout assumed by this routine: node i of the tree is heap[i], its children
// are nodes 2*i+1 (left) and 2*i+2 (right), a child exists iff its node index
// is < n, and the splitting axis of a node is (depth % nd) with
// depth = floor(log2(i+1)).
//
//   n     number of points / heap nodes
//   X     point coordinates, nd consecutive values per point
//   heap  heap[i] is the index (into X) of the point stored at node i
//   x     query point, nd values
//
// Returns the index (into X) of a point closest to x in squared Euclidean
// distance, or -1 if n <= 0.  When several points are equally close, which
// one is returned depends on the traversal order.
template <int nd, typename I=int, typename F=double>
__device__ __host__
I kdlite_nearest(I n, const F *X, const I *heap, const F *x)
{
  // A depth-first walk that pushes at most two children per popped node never
  // holds more than (tree depth + 1) entries, and the depth of an implicit
  // heap indexed by I is bounded by the number of bits in I.  Being a
  // compile-time constant, this is a valid array bound on host and device.
  constexpr int max_stack_size = 8 * (int)sizeof(I) + 1;

  if (n <= 0)
    return -1; // empty tree: heap[0] must not be read

  I S[max_stack_size]; // node indices still to visit
  F B[max_stack_size]; // lower bound on the squared distance from x to any point in that subtree
  int top = 0;

  S[top] = 0; // push root
  B[top] = 0;
  top ++;

  I best = -1;   // no best yet
  F best_d2 = 0; // only meaningful once best >= 0

  while (top != 0) { // stack is not empty
    -- top; // pop stack
    const I i = S[top];
    const F bound = B[top];

    // The best distance may have shrunk since this entry was pushed;
    // re-check before descending.  Ties are still explored.
    if (best >= 0 && bound > best_d2)
      continue;

    const I xid = heap[i];

    // depth = floor(log2(i+1)), computed with integers only
    I depth = 0;
    for (I k = i + 1; k > 1; k >>= 1)
      depth ++;
    const I axis = depth % nd;

    const F d2 = vector_dist_2norm2<F>(nd, x, X + nd*xid); // distance to the current node
    if (best < 0 || d2 < best_d2) {
      best = xid;
      best_d2 = d2;
    }

    I next, other;
    if (x[axis] < X[nd*xid+axis]) {
      next = i * 2 + 1; // left child
      other = i * 2 + 2; // right child
    } else {
      next = i * 2 + 2; // right child
      other = i * 2 + 1; // left child
    }

    const F dp = x[axis] - X[nd*xid+axis]; // signed distance to the splitting plane
    const F dp2 = dp * dp;

    // Push the far side first so that the near side is popped (and searched)
    // first; by the time the far side is popped, best_d2 is as tight as possible.
    if (other < n && dp2 <= best_d2) {
      assert(top < max_stack_size);
      S[top] = other;
      B[top] = dp2 > bound ? dp2 : bound; // everything on the far side lies beyond the plane
      top ++;
    }

    if (next < n) { // the next node exists
      assert(top < max_stack_size);
      S[top] = next;
      B[top] = bound; // the near side is at least as far as its parent cell
      top ++;
    }
  }

  return best;
}

template <int nd, typename I=int, typename F=double>
__device__ __host__
I kdlite_nearest_bfs(I n, const F *X, const I *heap, const F *x)
{
static size_t max_queue_size = 32768; // TODO
typedef struct {
Expand Down
24 changes: 16 additions & 8 deletions include/ftk/mesh/mpas_mesh.hh
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public:

public:
std::shared_ptr<kd_t<F, 3>> kd_cells, kd_vertices;
std::vector<int> kdl_cells;
std::vector<int> kdlite_heap;

ndarray<I> cellsOnVertex, cellsOnEdge, cellsOnCell,
edgesOnCell,
Expand Down Expand Up @@ -126,12 +126,12 @@ public:
// Builds the kd-lite heap over the cell coordinates (xyzCells).
template <typename I, typename F>
void mpas_mesh<I, F>::initialize()
{
  // The kd_t-based locator is kept here for reference only; the kd-lite heap
  // below is what gets built.
  // kd_cells.reset(new kd_t<F, 3>);
  // kd_cells->set_inputs(this->xyzCells);
  // kd_cells->build();

  // The heap is pre-filled with -1 before kdlite_build writes into it.
  // NOTE(review): the heap is sized 3*n_cells(), while kdlite_nearest only
  // visits node indices in [0, n_cells()) -- confirm whether kdlite_build
  // needs the extra slack.
  kdlite_heap.resize( n_cells()*3, -1 );
  kdlite_build<3>((int)n_cells(), xyzCells.data(), kdlite_heap.data());
}

template <typename I, typename F>
Expand Down Expand Up @@ -738,14 +738,22 @@ bool mpas_mesh<I, F>::point_in_cell_i(const I cell_i, const F x[]) const
template <typename I, typename F>
I mpas_mesh<I, F>::locate_cell_i(const F x[]) const
{
I cell_i = kd_cells->find_nearest(x);
const I cell_i = kdlite_nearest<3, I, F>(
(int)n_cells(),
xyzCells.data(),
kdlite_heap.data(),
x);
#if 0
I cell_i = kd_cells->find_nearest(x);
I cell_i1 = kd_nearest<3, I, F>((int)n_cells(),
xyzCells.data(),
kdl_cells.data(),
x);

fprintf(stderr, "cell_i=%d, %d\n", cell_i, cell_i1);
F d1 = vector_dist_2norm2<3>(x, xyzCells.data() + cell_i*3);
F d2 = vector_dist_2norm2<3>(x, xyzCells.data() + cell_i1*3);

fprintf(stderr, "cell_i=%d, %d, dist=%f, %f\n", cell_i, cell_i1, d1, d2);
assert(cell_i == cell_i1);
#endif

Expand Down

0 comments on commit 6650eeb

Please sign in to comment.