Skip to content

Commit

Permalink
Jb/july24 (#148)
Browse files Browse the repository at this point in the history
* parallelise tiling

* multispectral compatibility

* updated multi-class functionality
  • Loading branch information
PatBall1 authored Sep 17, 2024
1 parent 2f29886 commit 1c30ea4
Show file tree
Hide file tree
Showing 13 changed files with 1,562 additions and 523 deletions.
1 change: 0 additions & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
*.pth filter=lfs diff=lfs merge=lfs -text
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,14 @@
<!-- <a href="https://github.com/hhatto/autopep8"><img alt="Code style: autopep8" src="https://img.shields.io/badge/code%20style-autopep8-000000.svg"></a> -->


Python package for automatic tree crown delineation based on Mask R-CNN. Pre-trained models can be picked in the [`model_garden`](https://github.com/PatBall1/detectree2/tree/master/model_garden).
A tutorial on how to prepare data, train models and make predictions is available [here](https://patball1.github.io/detectree2/tutorial.html). For questions, collaboration proposals and requests for data email [James Ball](mailto:ball.jgc@gmail.com). Some example data is available for download [here](https://doi.org/10.5281/zenodo.8136161).
Python package for automatic tree crown delineation in aerial RGB and multispectral imagery based on Mask R-CNN. Pre-trained models can be picked in the [`model_garden`](https://github.com/PatBall1/detectree2/tree/master/model_garden).
A tutorial on how to prepare data, train models and make predictions is available [here](https://patball1.github.io/detectree2/tutorial.html). For questions, collaboration proposals and requests for data email [James Ball](mailto:ball.jgc@gmail.com). Some example data is available to download [here](https://doi.org/10.5281/zenodo.8136161).

Detectree2是一个基于Mask R-CNN的自动树冠检测与分割的Python包。您可以在[`model_garden`](https://github.com/PatBall1/detectree2/tree/master/model_garden)中选择预训练模型。[这里](https://patball1.github.io/detectree2/tutorial.html)提供了如何准备数据、训练模型和进行预测的教程。如果有任何问题,合作提案或者需要样例数据,可以邮件联系[James Ball](mailto:ball.jgc@gmail.com)。一些示例数据可以在[这里](https://doi.org/10.5281/zenodo.8136161)下载。

| <a href="https://www.conservation.cam.ac.uk/"><img src="./report/cam_logo.png" width="140"></a> | <sup> Code developed by James Ball, Seb Hickman, Thomas Koay, Oscar Jiang, Luran Wang, Panagiotis Ioannou, James Hinton and Matthew Archer in the [Forest Ecology and Conservation Group](https://coomeslab.org/) at the University of Cambridge. The Forest Ecology and Conservation Group is led by Professor David Coomes and is part of the University of Cambridge [Conservation Research Institute](https://www.conservation.cam.ac.uk/). </sup>|
| :---: | :--- |

<br/><br/>
> [!NOTE]
> To save bandwidth, trained models have been moved to [Zenodo](https://zenodo.org/records/10522461). Download models directly with `wget` or equivalent.

## Citation

Expand Down
74 changes: 74 additions & 0 deletions detectree2/data_loading/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import cv2
import detectron2.data.transforms as T
import numpy as np
import rasterio
import torch
from detectron2.structures import BitMasks, BoxMode, Instances
from torch.utils.data import Dataset


class CustomTIFFDataset(Dataset):
def __init__(self, annotations, transforms=None):
"""
Args:
annotations (list): List of dictionaries containing image file paths and annotations.
transforms (callable, optional): Optional transform to be applied on a sample.
"""
self.annotations = annotations
self.transforms = transforms

def __len__(self):
return len(self.annotations)

def __getitem__(self, idx):
# Load the TIFF image
img_info = self.annotations[idx]
with rasterio.open(img_info['file_name']) as src:
# Read all bands (assuming they are all needed)
image = src.read()
# Normalize or rescale if necessary
image = image.astype(np.float32) / 255.0 # Example normalization
# If the number of bands is not 3, reduce to 3 or handle accordingly
#if image.shape[0] > 3:
# image = image[:3, :, :] # Taking the first 3 bands (e.g., RGB)
# Convert to HWC format expected by Detectron2
#image = np.transpose(image, (1, 2, 0))

# Prepare annotations (this part needs to be adapted to your specific annotations)
target = {
"image_id": idx,
"annotations": img_info["annotations"],
"width": img_info["width"],
"height": img_info["height"],
}

if self.transforms is not None:
augmentations = T.AugmentationList(self.transforms)
image, target = augmentations(image, target)

# Convert to Detectron2-compatible format
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
instances = self.get_detectron_instances(target)

return image, instances

def get_detectron_instances(self, target):
"""
Converts annotations into Detectron2's format.
This example assumes annotations are in COCO format, and you'll need to adapt it for your needs.
"""
boxes = [obj["bbox"] for obj in target["annotations"]]
boxes = torch.as_tensor(boxes, dtype=torch.float32)
boxes = BoxMode.convert(boxes, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)

# Create BitMasks from the binary mask data (assuming the mask is a binary numpy array)
masks = [obj["segmentation"] for obj in target["annotations"]] # Replace with actual mask loading
masks = BitMasks(torch.stack([torch.from_numpy(mask) for mask in masks]))

instances = Instances(
image_size=(target["height"], target["width"]),
gt_boxes=boxes,
gt_classes=torch.tensor([obj["category_id"] for obj in target["annotations"]], dtype=torch.int64),
gt_masks=masks
)
return instances
Loading

0 comments on commit 1c30ea4

Please sign in to comment.