Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions detectree2/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def __init__(self, eval_period, model, data_loader, img_per_dataset=6):
eval_period (int): The number of iterations between evaluations.
model (torch.nn.Module): The model to evaluate.
data_loader (torch.utils.data.DataLoader): The data loader for evaluation.
patience (int): The number of evaluation periods to wait for improvement before early stopping.
img_per_dataset (int): The number of images per dataset to visualize.
"""
self._model = model
self._period = eval_period
Expand All @@ -339,11 +339,11 @@ def __init__(self, eval_period, model, data_loader, img_per_dataset=6):

def after_step(self):
"""
Hook to be called after each training iteration to evaluate the model and manage checkpoints.
Hook to be called after each training iteration to visualize model predictions.

- Evaluates the model at regular intervals.
- Saves the best model checkpoint based on the AP50 metric.
- Implements early stopping if the AP50 does not improve after a set number of evaluations.
This hook runs at regular intervals to perform inference on a sample of the validation dataset.
It then visualizes the predictions and logs the resulting images to the event storage,
making them accessible through tools like TensorBoard.
"""
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
Expand Down Expand Up @@ -703,9 +703,9 @@ def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = Non
"""Get the tree dictionaries.

Args:
directory: Path to directory
classes: List of classes to include
classes_at: Signifies which column (if any) corresponds to the class labels
directory: Path to directory containing geojson annotation files.
class_mapping: A dictionary mapping class labels from geojson properties to category indices.
If None, all annotations are assigned to category 0 (tree).

Returns:
List of dictionaries corresponding to segmentations of trees. Each dictionary includes
Expand Down Expand Up @@ -919,12 +919,13 @@ def remove_registered_data(name="tree"):
MetadataCatalog.remove(name + "_" + d)


def register_test_data(test_location, name="tree"):
def register_test_data(test_location, name="tree", class_mapping_file=None):
"""Register data for testing.

Args:
test_location: directory containing test data
name: string to name data
class_mapping_file: Path to the class mapping file (json or pickle).
"""
d = "test"

Expand Down Expand Up @@ -993,7 +994,6 @@ def setup_cfg(
base_lr: base learning rate
weight_decay: weight decay for optimizer
max_iter: maximum number of iterations
num_classes: number of classes
eval_period: number of iterations between evaluations
out_dir: directory to save outputs
resize: resize strategy for images
Expand Down
174 changes: 110 additions & 64 deletions detectree2/preprocessing/tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,27 @@ def process_tile(img_path: str,
"""Process a single tile for making predictions.

Args:
img_path: Path to the orthomosaic
out_dir: Output directory
buffer: Overlapping buffer of tiles in meters (UTM)
tile_width: Tile width in meters
tile_height: Tile height in meters
dtype_bool: Flag to edit dtype to prevent black tiles
minx: Minimum x coordinate of tile
miny: Minimum y coordinate of tile
crs: Coordinate reference system
tilename: Name of the tile
img_path: Path to the orthomosaic.
out_dir: Output directory.
buffer: Overlapping buffer of tiles in meters (UTM).
tile_width: Tile width in meters.
tile_height: Tile height in meters.
dtype_bool: Flag to edit dtype to prevent black tiles.
minx: Minimum x coordinate of the tile.
miny: Minimum y coordinate of the tile.
crs: Coordinate reference system.
tilename: Name of the tile.
crowns: Crown polygons as a GeoDataFrame used to skip tiles if coverage is below `threshold`.
threshold: Minimum fraction [0,1] of tile coverage by `crowns` required to avoid skipping the tile.
nan_threshold: Maximum proportion [0,1] of the tile that can be nodata or NaN values before skipping.
mask_gdf: A GeoDataFrame containing polygons that act as masks for the tile. Only the interior is kept, the rest of the image will become nodata.
additional_nodata: List of additional pixel values to treat as nodata.
image_statistics: A list of dictionaries where each dictionary contains information about the pixel distribution of that band. One list element per band.
ignore_bands_indices: List of integer indices of bands to ignore during processing.
use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of annotated training crowns.

Returns:
None
A tuple containing the rasterio dataset, output path root, overlapping crowns, and tile parameters (minx, miny, buffer), or None if the tile is skipped.
"""
try:
with rasterio.open(img_path) as data:
Expand Down Expand Up @@ -290,22 +298,30 @@ def process_tile_ms(img_path: str,
image_statistics: List[Dict[str, float]] = None,
ignore_bands_indices: List[int] = [],
use_convex_mask: bool = True):
"""Process a single tile for making predictions.
"""Process a single multispectral tile for making predictions.

Args:
img_path: Path to the orthomosaic
out_dir: Output directory
buffer: Overlapping buffer of tiles in meters (UTM)
tile_width: Tile width in meters
tile_height: Tile height in meters
dtype_bool: Flag to edit dtype to prevent black tiles
minx: Minimum x coordinate of tile
miny: Minimum y coordinate of tile
crs: Coordinate reference system
tilename: Name of the tile
img_path: Path to the orthomosaic.
out_dir: Output directory.
buffer: Overlapping buffer of tiles in meters (UTM).
tile_width: Tile width in meters.
tile_height: Tile height in meters.
dtype_bool: Flag to edit dtype to prevent black tiles.
minx: Minimum x coordinate of the tile.
miny: Minimum y coordinate of the tile.
crs: Coordinate reference system.
tilename: Name of the tile.
crowns: Crown polygons as a GeoDataFrame used to skip tiles if coverage is below `threshold`.
threshold: Minimum fraction [0,1] of tile coverage by `crowns` required to avoid skipping the tile.
nan_threshold: Maximum proportion [0,1] of the tile that can be nodata or NaN values before skipping.
mask_gdf: A GeoDataFrame containing polygons that act as masks for the tile. Only the interior is kept, the rest of the image will become nodata.
additional_nodata: List of additional pixel values to treat as nodata.
image_statistics: A list of dictionaries where each dictionary contains information about the pixel distribution of that band. One list element per band.
ignore_bands_indices: List of integer indices of bands to ignore during processing.
use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of annotated crowns.

Returns:
None
A tuple containing the rasterio dataset, output path root, overlapping crowns, and tile parameters (minx, miny, buffer), or None if the tile is skipped.
"""
try:
with rasterio.open(img_path) as data:
Expand Down Expand Up @@ -457,19 +473,26 @@ def process_tile_train(
"""Process a single tile for training data.

Args:
img_path: Path to the orthomosaic
out_dir: Output directory
buffer: Overlapping buffer of tiles in meters (UTM)
tile_width: Tile width in meters
tile_height: Tile height in meters
dtype_bool: Flag to edit dtype to prevent black tiles
minx: Minimum x coordinate of tile
miny: Minimum y coordinate of tile
crs: Coordinate reference system
tilename: Name of the tile
crowns: Crown polygons as a geopandas dataframe
threshold: Min proportion of the tile covered by crowns to be accepted {0,1}
nan_theshold: Max proportion of tile covered by nans
img_path: Path to the orthomosaic.
out_dir: Output directory.
buffer: Overlapping buffer of tiles in meters (UTM).
tile_width: Tile width in meters.
tile_height: Tile height in meters.
dtype_bool: Flag to edit dtype to prevent black tiles.
minx: Minimum x coordinate of tile.
miny: Minimum y coordinate of tile.
crs: Coordinate reference system.
tilename: Name of the tile.
crowns: Crown polygons as a geopandas DataFrame.
threshold: Min proportion of the tile covered by crowns to be accepted [0,1].
nan_threshold: Max proportion of tile covered by NaNs.
mode: Type of the raster data ("rgb" or "ms").
class_column: Name of the column in `crowns` DataFrame for class-based tiling.
mask_gdf: A GeoDataFrame containing polygons that act as masks for the tile. Only the interior is kept, the rest of the image will become nodata.
additional_nodata: List of additional pixel values to treat as nodata.
image_statistics: A list of dictionaries where each dictionary contains information about the pixel distribution of that band. One list element per band.
ignore_bands_indices: List of integer indices of bands to ignore during processing.
use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of annotated crowns.

Returns:
None
Expand Down Expand Up @@ -544,7 +567,20 @@ def _calculate_tile_placements(
tile_placement: str = "grid",
overlapping_tiles: bool = False,
) -> List[Tuple[int, int]]:
"""Internal method for calculating the placement of tiles"""
"""Internal method for calculating the placement of tiles.

Args:
img_path: Path to the orthomosaic.
buffer: Overlapping buffer of tiles in meters (UTM).
tile_width: Tile width in meters.
tile_height: Tile height in meters.
crowns: Crown polygons as a GeoDataFrame. Required for 'adaptive' placement.
tile_placement: Strategy for placing tiles ('grid' or 'adaptive').
overlapping_tiles: If True, generates additional tiles offset by half a tile size.

Returns:
A list of (minx, miny) coordinates giving the minimum-x, minimum-y (lower-left) corner of each tile.
"""

if tile_placement == "grid":
with rasterio.open(img_path) as data:
Expand Down Expand Up @@ -622,17 +658,18 @@ def calculate_image_statistics(file_path,
min_windows=100,
mode="rgb",
ignore_bands_indices: List[int] = []):
"""
Calculate statistics for a raster using either whole image or sampled windows.
"""Calculate statistics for a raster using either whole image or sampled windows.

Parameters:
- file_path: str, path to the raster file.
- values_to_ignore: list, values to ignore in statistics (e.g., NaN, custom values).
- window_size: int, size of square window for sampling.
- min_windows: int, minimum number of valid windows to include in statistics.
Args:
file_path: Path to the raster file.
values_to_ignore: Values to ignore in statistics (e.g., NaN, custom values).
window_size: Size of square window for sampling.
min_windows: Minimum number of valid windows to include in statistics.
mode: Type of the raster data ("rgb" or "ms").
ignore_bands_indices: List of integer indices of bands to ignore during statistics calculation.

Returns:
- List of dictionaries containing statistics for each band.
List of dictionaries containing statistics for each band.
"""
if values_to_ignore is None:
values_to_ignore = []
Expand Down Expand Up @@ -769,26 +806,34 @@ def tile_data(
) -> None:
"""Tiles up orthomosaic and corresponding crowns (if supplied) into training/prediction tiles.

Tiles up large rasters into manageable tiles for training and prediction. If crowns are not supplied, the function
will tile up the entire landscape for prediction. If crowns are supplied, the function will tile these with the image
and skip tiles without a minimum coverage of crowns. The 'threshold' can be varied to ensure good coverage of
crowns across a training tile. Tiles that do not have sufficient coverage are skipped.
Tiles up large rasters into manageable tiles for training and prediction. If crowns are not
supplied, the function will tile up the entire landscape for prediction. If crowns are supplied,
the function will tile these with the image and skip tiles without a minimum coverage of crowns.
The 'threshold' can be varied to ensure good coverage of crowns across a training tile. Tiles
that do not have sufficient coverage are skipped.

Args:
img_path: Path to the orthomosaic
out_dir: Output directory
buffer: Overlapping buffer of tiles in meters (UTM)
tile_width: Tile width in meters
tile_height: Tile height in meters
crowns: Crown polygons as a GeoPandas DataFrame
threshold: Minimum proportion of the tile covered by crowns to be accepted [0,1]
nan_threshold: Maximum proportion of tile covered by NaNs [0,1]
dtype_bool: Flag to edit dtype to prevent black tiles
mode: Type of the raster data ("rgb" or "ms")
class_column: Name of the column in `crowns` DataFrame for class-based tiling
img_path: Path to the orthomosaic.
out_dir: Output directory.
buffer: Overlapping buffer of tiles in meters (UTM).
tile_width: Tile width in meters.
tile_height: Tile height in meters.
crowns: Crown polygons as a GeoDataFrame.
threshold: Minimum proportion of the tile covered by crowns to be accepted [0,1].
nan_threshold: Maximum proportion of the tile covered by NaNs [0,1].
dtype_bool: Flag to edit dtype to prevent black tiles.
mode: Type of the raster data ("rgb" or "ms").
class_column: Name of the column in `crowns` DataFrame for class-based tiling.
tile_placement: Strategy for placing tiles.
"grid" for fixed grid placement based on the bounds of the input image, optimized for speed.
"adaptive" for dynamic placement of tiles based on crowns, adjusts based on data features for better coverage.
mask_path: Path to a mask file to use for tiling.
multithreaded: Flag to enable multithreaded processing.
random_subset: Number of randomly selected tiles to attempt to process per image. If -1, all tiles are processed.
additional_nodata: List of additional pixel values to treat as nodata.
overlapping_tiles: Flag to enable overlapping tiles for more training data generation. More useful for training the detection part of the Mask R-CNN model.
ignore_bands_indices: List of integer indices of bands to ignore during processing.
use_convex_mask: If True, creates a convex mask around crown polygons to exclude areas outside of the crowns.

Returns:
None
Expand Down Expand Up @@ -1154,7 +1199,7 @@ def image_details(fileroot):
fileroot: image filename without file extension

Returns:
Box structure
A list of two tuples representing the bounding box with buffer: [(xmin, xmax), (ymin, ymax)].
"""
image_info = fileroot.split("_")
minx = int(image_info[-5])
Expand All @@ -1171,11 +1216,11 @@ def is_overlapping_box(test_boxes_array, train_box):
"""Check if the train box overlaps with any of the test boxes.

Args:
test_boxes_array:
train_box:
test_boxes_array: A list of bounding boxes to check against.
train_box: The bounding box to test for overlap.

Returns:
Boolean
True if `train_box` overlaps with any box in `test_boxes_array`, False otherwise.
"""
for test_box in test_boxes_array:
test_box_x = test_box[0]
Expand Down Expand Up @@ -1245,6 +1290,7 @@ def to_traintest_folders( # noqa: C901
test_frac: fraction of tiles to be used for testing
folds: number of folds to split the data into
strict: if True, training/validation files will be removed if there is any overlap with test files (including the buffer)
seed: Random seed for shuffling to ensure reproducibility.

Returns:
None
Expand Down
Loading