Skip to content

Commit

Permalink
much faster with numba
Browse files Browse the repository at this point in the history
  • Loading branch information
carderne committed Feb 16, 2024
1 parent b0c1d5c commit c0578a7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 66 deletions.
9 changes: 2 additions & 7 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@
"guess_out = folder_out / 'guess.tif'\n",
"guess_skeletonized_out = folder_out / 'guess_skel.tif'\n",
"guess_nulled = folder_out / 'guess_nulled.tif'\n",
"guess_vec_out = folder_out / 'guess.gpkg'\n",
"animate_out = folder_out / 'animated'"
"guess_vec_out = folder_out / 'guess.gpkg'"
]
},
{
Expand Down Expand Up @@ -220,11 +219,7 @@
"metadata": {},
"outputs": [],
"source": [
"dist = gf.optimise(targets, costs, start,\n",
" jupyter=True,\n",
" animate=True,\n",
" affine=affine,\n",
" animate_path=animate_out)\n",
"dist = gf.optimise(targets, costs, start)\n",
"save_raster(dist_out, dist, affine)\n",
"plt.imshow(dist)"
]
Expand Down
97 changes: 38 additions & 59 deletions gridfinder/gridfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,24 @@
Implements Dijkstra's algorithm on a cost-array to create an MST.
"""

import os
import pickle
import sys
from heapq import heapify, heappop, heappush
from math import sqrt
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numba as nb
import numpy as np
import rasterio
from affine import Affine
from IPython.display import Markdown, display

from gridfinder.util import save_raster

sys.setrecursionlimit(100000)


def get_targets_costs(
targets_in: Union[str, Path], costs_in: Union[str, Path],
targets_in: Union[str, Path],
costs_in: Union[str, Path],
) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int], Affine]:
"""Load the targets and costs arrays from the given file paths.
Expand Down Expand Up @@ -49,7 +47,7 @@ def get_targets_costs(
start = tuple(target_list[0].tolist())

targets = targets.astype(np.int8)
costs = costs.astype(np.float16)
costs = costs.astype(np.float32)

return targets, costs, start, affine

Expand Down Expand Up @@ -78,15 +76,16 @@ def estimate_mem_use(targets: np.ndarray, costs: np.ndarray) -> float:
return est_mem / 1e9


@nb.njit
def optimise(
targets: np.ndarray,
costs: np.ndarray,
start: Tuple[int, int],
silent: bool = False,
jupyter: bool = False,
animate: bool = False,
affine: Optional[Affine] = None,
animate_path: str = "",
silent: bool = False,
) -> np.ndarray:
"""Run the Dijkstra algorithm for the supplied arrays.
Expand All @@ -95,7 +94,7 @@ def optimise(
targets : 2D array of targets.
costs : 2D array of costs.
start : tuple with row, col of starting point.
jupyter : Whether the code is being run from a Jupyter Notebook.
silent : whether to print progress
Returns
-------
Expand All @@ -104,48 +103,28 @@ def optimise(
on-grid point. Values of 0 imply that cell is part of an MV grid line.
"""

if jupyter or animate or affine or animate_path:
print(
"Warning: the following parameters are ignored: jupyter, animate, affine, animate_path" # NoQA
)

max_i = costs.shape[0]
max_j = costs.shape[1]
shape = costs.shape
max_i = shape[0]
max_j = shape[1]

visited = np.zeros_like(targets, dtype=np.int8)
dist = np.full_like(costs, np.nan, dtype=np.float32)

# want to set this to dtype='int32, int32'
# but then the if type(prev_loc) == tuple check will break
# becuas it gets instantiated with tuples
prev = np.full_like(costs, np.nan, dtype=object)
visited = np.zeros(shape, dtype=np.int8)
dist = np.full(shape, np.nan, dtype=np.float32)
prev = np.full((shape[0], shape[1], 2), -1, dtype=np.int32)

dist[start] = 0

# dist, loc
queue = [[0, start]]
# dist, loc
queue: List[Tuple[float, Tuple[int, int]]] = [(0.0, start)]
heapify(queue)

def zero_and_heap_path(loc: Tuple[int, int]) -> None:
"""Zero the location's distance value and follow upstream doing same.
Parameters
----------
loc : row, col of current point.
"""

if not dist[loc] == 0:
dist[loc] = 0
visited[loc] = 1

heappush(queue, [0, loc])
prev_loc = prev[loc]

if type(prev_loc) == tuple:
zero_and_heap_path(prev_loc)

counter = 0
progress = 0
max_cells = targets.shape[0] * targets.shape[1]
handle = None
if jupyter:
handle = display(Markdown(""), display_id=True)

while len(queue):
current = heappop(queue)
Expand Down Expand Up @@ -175,7 +154,16 @@ def zero_and_heap_path(loc: Tuple[int, int]) -> None:
# if the location is connected
if targets[next_loc]:
prev[next_loc] = current_loc
zero_and_heap_path(next_loc)
zero_loc = next_loc
while not dist[zero_loc] == 0.0:
dist[zero_loc] = 0.0
visited[zero_loc] = 1

heappush(queue, (0.0, zero_loc))
new_zero_loc = prev[zero_loc]
zero_loc = (new_zero_loc[0], new_zero_loc[1])
if zero_loc[0] == -1:
break

# otherwise it's a normal queue cell
else:
Expand All @@ -191,28 +179,19 @@ def zero_and_heap_path(loc: Tuple[int, int]) -> None:
if next_dist < dist[next_loc]:
dist[next_loc] = next_dist
prev[next_loc] = current_loc
heappush(queue, [next_dist, next_loc])
heappush(queue, (next_dist, next_loc))

else:
heappush(queue, [next_dist, next_loc])
heappush(queue, (next_dist, next_loc))
visited[next_loc] = 1
dist[next_loc] = next_dist
prev[next_loc] = current_loc

counter += 1
progress_new = 100 * counter / max_cells
if int(progress_new) > int(progress):
progress = progress_new
message = f"{progress:.2f} %"
if jupyter and handle:
handle.update(message)
elif not silent:
print(message)
if animate:
i = int(progress)
path = os.path.join(animate_path, f"arr{i:03d}.tif")
if not affine:
raise ValueError("Must provide an affine when animate=True") # NoQA
save_raster(path, dist, affine)
if not silent:
counter += 1
progress_new = int(100 * counter / max_cells)
if progress_new > progress + 4:
progress = progress_new
print(progress, "%")

return dist
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"click ~= 8.0",
"fiona ~= 1.9",
"geopandas ~= 0.14",
"numba ~= 0.58",
"numpy ~= 1.26",
"pandas ~= 2.2",
"pillow ~= 10.2",
Expand Down

0 comments on commit c0578a7

Please sign in to comment.