
Commit a6a2c25

Add dota8.yaml and O tests (ultralytics#7394)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
1 parent d0562d7 commit a6a2c25

File tree

13 files changed

+176
-16
lines changed


README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ See [OBB Docs](https://docs.ultralytics.com/tasks/obb/) for usage examples with
199199
| [YOLOv8l-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-obb.pt) | 1024 | 80.7 | 1278.42 | 11.83 | 44.5 | 433.8 |
200200
| [YOLOv8x-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-obb.pt) | 1024 | 81.36 | 1759.10 | 13.23 | 69.5 | 676.7 |
201201

202-
- **mAP<sup>test</sup>** values are for single-model multi-scale on [DOTAv1](https://captain-whu.github.io/DOTA/index.html) dataset. <br>Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test`
202+
- **mAP<sup>test</sup>** values are for single-model multi-scale on [DOTAv1](https://captain-whu.github.io/DOTA/index.html) dataset. <br>Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test` and submit merged results to [DOTA evaluation](https://captain-whu.github.io/DOTA/evaluation.html).
203203
- **Speed** averaged over DOTAv1 val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. <br>Reproduce by `yolo val obb data=DOTAv1.yaml batch=1 device=0|cpu`
204204

205205
</details>

docs/en/datasets/obb/dota-v2.md

+28-1
@@ -66,13 +66,40 @@ Typically, datasets incorporate a YAML (Yet Another Markup Language) file detail
6666
--8<-- "ultralytics/cfg/datasets/DOTAv1.yaml"
6767
```
6868

69+
## Split DOTA images
70+
71+
To train on the DOTA dataset, we split the original high-resolution DOTA images into 1024x1024 images in a multi-scale way.
72+
73+
!!! Example "Split images"
74+
75+
=== "Python"
76+
77+
```python
78+
from ultralytics.data.split_dota import split_trainval, split_test
79+
80+
# split train and val set, with labels.
81+
split_trainval(
82+
data_root='path/to/DOTAv1.0/',
83+
save_dir='path/to/DOTAv1.0-split/',
84+
rates=[0.5, 1.0, 1.5], # multi-scale
85+
gap=500
86+
)
87+
# split test set, without labels.
88+
split_test(
89+
data_root='path/to/DOTAv1.0/',
90+
save_dir='path/to/DOTAv1.0-split/',
91+
rates=[0.5, 1.0, 1.5], # multi-scale
92+
gap=500
93+
)
94+
```
95+
6996
## Usage
7097

7198
To train a model on the DOTA v1 dataset, you can utilize the following code snippets. Always refer to your model's documentation for a thorough list of available arguments.
7299

73100
!!! Warning
74101

75-
Please note that all images and associated annotations in the DOTAv2 dataset can be used for academic purposes, but commercial use is prohibited. Your understanding and respect for the dataset creators' wishes are greatly appreciated!
102+
Please note that all images and associated annotations in the DOTAv1 dataset can be used for academic purposes, but commercial use is prohibited. Your understanding and respect for the dataset creators' wishes are greatly appreciated!
76103

77104
!!! Example "Train Example"
78105
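
The multi-scale split described in the `dota-v2.md` change can be pictured as a sliding window over each image. The sketch below is a minimal illustration, not the code in `ultralytics.data.split_dota`: it assumes that `gap` is the overlap in pixels between neighbouring windows and that a rate `r` is emulated by cropping windows of size `1024 / r`. `window_starts` and `multiscale_windows` are hypothetical helper names.

```python
from math import ceil


def window_starts(length, crop, gap):
    """Start offsets of crop-sized windows along one axis.

    Assumption: adjacent windows overlap by `gap` pixels.
    """
    if crop <= gap:
        raise ValueError("crop must be larger than gap")
    if length <= crop:
        return [0]  # a single window covers the whole axis
    step = crop - gap
    n = ceil((length - crop) / step) + 1
    starts = [i * step for i in range(n)]
    # Shift the last window back so that it ends at the image border.
    if starts[-1] + crop > length:
        starts[-1] = length - crop
    return starts


def multiscale_windows(width, height, crop=1024, gap=500, rates=(0.5, 1.0, 1.5)):
    """Return (x1, y1, x2, y2) crop windows for every rate."""
    windows = []
    for r in rates:
        # Assumption: scaling the image by r is equivalent to scaling
        # the crop size and the gap by 1 / r on the original image.
        c, g = int(crop / r), int(gap / r)
        for y in window_starts(height, c, g):
            for x in window_starts(width, c, g):
                windows.append((x, y, min(x + c, width), min(y + c, height)))
    return windows
```

Under these assumptions, a larger `gap` produces more, more heavily overlapping crops, which lowers the chance that an object is cut by a window border at the cost of more images to train on.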

docs/en/datasets/obb/dota8.md

+81
@@ -0,0 +1,81 @@
1+
---
2+
comments: true
3+
description: Discover the versatile DOTA8 dataset, perfect for testing and debugging oriented detection models. Learn how to get started with YOLOv8-obb model training.
4+
keywords: Ultralytics, YOLOv8, oriented detection, DOTA8 dataset, dataset, model training, YAML
5+
---
6+
7+
# DOTA8 Dataset
8+
9+
## Introduction
10+
11+
[Ultralytics](https://ultralytics.com) DOTA8 is a small but versatile oriented object detection dataset composed of the first 8 images of the split DOTAv1 set, 4 for training and 4 for validation. This dataset is ideal for testing and debugging oriented object detection models, or for experimenting with new detection approaches. With 8 images, it is small enough to be easily manageable, yet diverse enough to test training pipelines for errors and act as a sanity check before training on larger datasets.
12+
13+
This dataset is intended for use with Ultralytics [HUB](https://hub.ultralytics.com) and [YOLOv8](https://github.com/ultralytics/ultralytics).
14+
15+
## Dataset YAML
16+
17+
A YAML (Yet Another Markup Language) file is used to define the dataset configuration. It contains information about the dataset's paths, classes, and other relevant information. In the case of the DOTA8 dataset, the `dota8.yaml` file is maintained at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/dota8.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/dota8.yaml).
18+
19+
!!! Example "ultralytics/cfg/datasets/dota8.yaml"
20+
21+
```yaml
22+
--8<-- "ultralytics/cfg/datasets/dota8.yaml"
23+
```
24+
25+
## Usage
26+
27+
To train a YOLOv8n-obb model on the DOTA8 dataset for 100 epochs with an image size of 640, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page.
28+
29+
!!! Example "Train Example"
30+
31+
=== "Python"
32+
33+
```python
34+
from ultralytics import YOLO
35+
36+
# Load a model
37+
model = YOLO('yolov8n-obb.pt') # load a pretrained model (recommended for training)
38+
39+
# Train the model
40+
results = model.train(data='dota8.yaml', epochs=100, imgsz=640)
41+
```
42+
43+
=== "CLI"
44+
45+
```bash
46+
# Start training from a pretrained *.pt model
47+
yolo obb train data=dota8.yaml model=yolov8n-obb.pt epochs=100 imgsz=640
48+
```
49+
50+
## Sample Images and Annotations
51+
52+
Here are some examples of images from the DOTA8 dataset, along with their corresponding annotations:
53+
54+
<img src="https://github.com/Laughing-q/assets/assets/61612323/965d3ff7-5b9b-4add-b62e-9795921b60de" alt="Dataset sample image" width="800">
55+
56+
- **Mosaiced Image**: This image demonstrates a training batch composed of mosaiced dataset images. Mosaicing is a technique used during training that combines multiple images into a single image to increase the variety of objects and scenes within each training batch. This helps improve the model's ability to generalize to different object sizes, aspect ratios, and contexts.
57+
58+
The example showcases the variety and complexity of the images in the DOTA8 dataset and the benefits of using mosaicing during the training process.
59+
60+
## Citations and Acknowledgments
61+
62+
If you use the DOTA dataset in your research or development work, please cite the following paper:
63+
64+
!!! Quote ""
65+
66+
=== "BibTeX"
67+
68+
```bibtex
69+
@article{9560031,
70+
author={Ding, Jian and Xue, Nan and Xia, Gui-Song and Bai, Xiang and Yang, Wen and Yang, Michael and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei},
71+
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
72+
title={Object Detection in Aerial Images: A Large-Scale Benchmark and Challenges},
73+
year={2021},
74+
volume={},
75+
number={},
76+
pages={1-1},
77+
doi={10.1109/TPAMI.2021.3117983}
78+
}
79+
```
80+
81+
A special note of gratitude to the team behind the DOTA datasets for their commendable effort in curating this dataset. For an exhaustive understanding of the dataset and its nuances, please visit the [official DOTA website](https://captain-whu.github.io/DOTA/index.html).

docs/en/tasks/obb.md

+7-9
@@ -32,14 +32,12 @@ YOLOv8 pretrained OBB models are shown here, which are pretrained on the [DOTAv1
3232
| [YOLOv8l-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-obb.pt) | 1024 | 80.7 | 1278.42 | 11.83 | 44.5 | 433.8 |
3333
| [YOLOv8x-obb](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-obb.pt) | 1024 | 81.36 | 1759.10 | 13.23 | 69.5 | 676.7 |
3434

35-
- **mAP<sup>test</sup>** values are for single-model multi-scale on [DOTAv1 test](http://cocodataset.org) dataset. <br>Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test`
35+
- **mAP<sup>test</sup>** values are for single-model multi-scale on [DOTAv1 test](https://captain-whu.github.io/DOTA/index.html) dataset. <br>Reproduce by `yolo val obb data=DOTAv1.yaml device=0 split=test` and submit merged results to [DOTA evaluation](https://captain-whu.github.io/DOTA/evaluation.html).
3636
- **Speed** averaged over DOTAv1 val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. <br>Reproduce by `yolo val obb data=DOTAv1.yaml batch=1 device=0|cpu`
3737

3838
## Train
3939

40-
<!-- TODO: probably we should create a sample dataset like coco128.yaml, named dota128.yaml? -->
41-
42-
Train YOLOv8n-obb on the dota128.yaml dataset for 100 epochs at image size 640. For a full list of available arguments see the [Configuration](../usage/cfg.md) page.
40+
Train YOLOv8n-obb on the dota8.yaml dataset for 100 epochs at image size 640. For a full list of available arguments see the [Configuration](../usage/cfg.md) page.
4341

4442
!!! Example
4543

@@ -54,19 +52,19 @@ Train YOLOv8n-obb on the dota128.yaml dataset for 100 epochs at image size 640.
5452
model = YOLO('yolov8n-obb.yaml').load('yolov8n.pt') # build from YAML and transfer weights
5553

5654
# Train the model
57-
results = model.train(data='dota128-obb.yaml', epochs=100, imgsz=640)
55+
results = model.train(data='dota8.yaml', epochs=100, imgsz=640)
5856
```
5957
=== "CLI"
6058

6159
```bash
6260
# Build a new model from YAML and start training from scratch
63-
yolo obb train data=dota128-obb.yaml model=yolov8n-obb.yaml epochs=100 imgsz=640
61+
yolo obb train data=dota8.yaml model=yolov8n-obb.yaml epochs=100 imgsz=640
6462

6563
# Start training from a pretrained *.pt model
66-
yolo obb train data=dota128-obb.yaml model=yolov8n-obb.pt epochs=100 imgsz=640
64+
yolo obb train data=dota8.yaml model=yolov8n-obb.pt epochs=100 imgsz=640
6765

6866
# Build a new model from YAML, transfer pretrained weights to it and start training
69-
yolo obb train data=dota128-obb.yaml model=yolov8n-obb.yaml pretrained=yolov8n-obb.pt epochs=100 imgsz=640
67+
yolo obb train data=dota8.yaml model=yolov8n-obb.yaml pretrained=yolov8n-obb.pt epochs=100 imgsz=640
7068
```
7169

7270
### Dataset format
@@ -75,7 +73,7 @@ OBB dataset format can be found in detail in the [Dataset Guide](../datasets/obb
7573

7674
## Val
7775

78-
Validate trained YOLOv8n-obb model accuracy on the dota128-obb dataset. No argument need to passed as the `model`
76+
Validate trained YOLOv8n-obb model accuracy on the DOTA8 dataset. No arguments need to be passed, as the `model`
7977
retains its training `data` and arguments as model attributes.
8078

8179
!!! Example

docs/mkdocs.yml

+1
@@ -260,6 +260,7 @@ nav:
260260
- Oriented Bounding Boxes (OBB):
261261
- datasets/obb/index.md
262262
- DOTAv2: datasets/obb/dota-v2.md
263+
- DOTA8: datasets/obb/dota8.md
263264
- Multi-Object Tracking:
264265
- datasets/track/index.md
265266
- Guides:

tests/test_cli.py

+4-2
@@ -13,12 +13,14 @@
1313
('detect', 'yolov8n', 'coco8.yaml'),
1414
('segment', 'yolov8n-seg', 'coco8-seg.yaml'),
1515
('classify', 'yolov8n-cls', 'imagenet10'),
16-
('pose', 'yolov8n-pose', 'coco8-pose.yaml'), ] # (task, model, data)
16+
('pose', 'yolov8n-pose', 'coco8-pose.yaml'),
17+
('obb', 'yolov8n-obb', 'dota8.yaml'), ] # (task, model, data)
1718
EXPORT_ARGS = [
1819
('yolov8n', 'torchscript'),
1920
('yolov8n-seg', 'torchscript'),
2021
('yolov8n-cls', 'torchscript'),
21-
('yolov8n-pose', 'torchscript'), ] # (model, format)
22+
('yolov8n-pose', 'torchscript'),
23+
('yolov8n-obb', 'torchscript'), ] # (model, format)
2224

2325

2426
def run(cmd):
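
The `(task, model, data)` tuples above are what tie the new `obb` task to `dota8.yaml` in the CLI tests. As a rough sketch of how such a tuple could expand into a command line, the helper below follows the `yolo obb train data=... model=... ` form shown in the documentation changes of this commit. `train_command` is a hypothetical name, and `imgsz=32` and `epochs=1` are assumed values for a fast smoke test, not necessarily what `tests/test_cli.py` uses.

```python
# (task, model, data) tuples, as listed in the diff above.
TASK_ARGS = [
    ("detect", "yolov8n", "coco8.yaml"),
    ("segment", "yolov8n-seg", "coco8-seg.yaml"),
    ("classify", "yolov8n-cls", "imagenet10"),
    ("pose", "yolov8n-pose", "coco8-pose.yaml"),
    ("obb", "yolov8n-obb", "dota8.yaml"),
]


def train_command(task, model, data, imgsz=32, epochs=1):
    """Build a short training command for one task (hypothetical helper)."""
    return f"yolo {task} train data={data} model={model}.pt imgsz={imgsz} epochs={epochs}"


if __name__ == "__main__":
    for args in TASK_ARGS:
        print(train_command(*args))
```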

tests/test_python.py

+3
@@ -77,6 +77,7 @@ def test_predict_img():
7777
seg_model = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt')
7878
cls_model = YOLO(WEIGHTS_DIR / 'yolov8n-cls.pt')
7979
pose_model = YOLO(WEIGHTS_DIR / 'yolov8n-pose.pt')
80+
obb_model = YOLO(WEIGHTS_DIR / 'yolov8n-obb.pt')
8081
im = cv2.imread(str(SOURCE))
8182
assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1 # PIL
8283
assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1 # ndarray
@@ -105,6 +106,8 @@ def test_predict_img():
105106
assert len(results) == t.shape[0]
106107
results = pose_model(t, imgsz=32)
107108
assert len(results) == t.shape[0]
109+
results = obb_model(t, imgsz=32)
110+
assert len(results) == t.shape[0]
108111

109112

110113
def test_predict_grey_and_4ch():

ultralytics/cfg/datasets/DOTAv1.5.yaml

+1-1
@@ -5,7 +5,7 @@
55
# parent
66
# ├── ultralytics
77
# └── datasets
8-
# └── dota2 ← downloads here (2GB)
8+
# └── dota1.5 ← downloads here (2GB)
99

1010
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
1111
path: ../datasets/DOTAv1.5 # dataset root dir

ultralytics/cfg/datasets/DOTAv1.yaml

+1-1
@@ -5,7 +5,7 @@
55
# parent
66
# ├── ultralytics
77
# └── datasets
8-
# └── dota2 ← downloads here (2GB)
8+
# └── dota1 ← downloads here (2GB)
99

1010
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
1111
path: ../datasets/DOTAv1 # dataset root dir

ultralytics/cfg/datasets/dota8.yaml

+34
@@ -0,0 +1,34 @@
1+
# Ultralytics YOLO 🚀, AGPL-3.0 license
2+
# DOTA8 dataset 8 images from split DOTAv1 dataset by Ultralytics
3+
# Documentation: https://docs.ultralytics.com/datasets/obb/dota8/
4+
# Example usage: yolo train model=yolov8n-obb.pt data=dota8.yaml
5+
# parent
6+
# ├── ultralytics
7+
# └── datasets
8+
# └── dota8 ← downloads here (1MB)
9+
10+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11+
path: ../datasets/dota8 # dataset root dir
12+
train: images/train # train images (relative to 'path') 4 images
13+
val: images/val # val images (relative to 'path') 4 images
14+
15+
# Classes for DOTA 1.0
16+
names:
17+
0: plane
18+
1: ship
19+
2: storage tank
20+
3: baseball diamond
21+
4: tennis court
22+
5: basketball court
23+
6: ground track field
24+
7: harbor
25+
8: bridge
26+
9: large vehicle
27+
10: small vehicle
28+
11: helicopter
29+
12: roundabout
30+
13: soccer ball field
31+
14: swimming pool
32+
33+
# Download script/URL (optional)
34+
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/dota8.zip

ultralytics/engine/results.py

+3
@@ -323,6 +323,9 @@ def save_crop(self, save_dir, file_name=Path('im.jpg')):
323323
if self.probs is not None:
324324
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
325325
return
326+
if self.obb is not None:
327+
LOGGER.warning('WARNING ⚠️ OBB task does not support `save_crop`.')
328+
return
326329
for d in self.boxes:
327330
save_one_box(d.xyxy,
328331
self.orig_img.copy(),

ultralytics/models/yolo/obb/val.py

+11
@@ -106,6 +106,17 @@ def pred_to_json(self, predn, filename):
106106
'rbox': [round(x, 3) for x in r],
107107
'poly': [round(x, 3) for x in b]})
108108

109+
def save_one_txt(self, predn, save_conf, shape, file):
110+
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
111+
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
112+
for *xywh, conf, cls, angle in predn.tolist():
113+
xywha = torch.tensor([*xywh, angle]).view(1, 5)
114+
xywha[:, :4] /= gn
115+
xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist() # normalized corner points
116+
line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format
117+
with open(file, 'a') as f:
118+
f.write(('%g ' * len(line)).rstrip() % line + '\n')
119+
109120
def eval_json(self, stats):
110121
"""Evaluates YOLO output in JSON format and returns performance statistics."""
111122
if self.args.save_json and self.is_dota and len(self.jdict):
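
`save_one_txt` above writes each prediction as a class id followed by eight corner coordinates. As a rough illustration of that conversion, the sketch below turns a center/size/angle box into four corners with plain trigonometry and formats the line with `%g`. It assumes the angle is in radians and is not the implementation of `ops.xywhr2xyxyxyxy`; the corner order and the helper names `xywhr_to_corners` and `format_label` are chosen for illustration only.

```python
import math


def xywhr_to_corners(cx, cy, w, h, angle):
    """Return [x1, y1, ..., x4, y4] for a box rotated by `angle` radians about its center."""
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    # Half-extent vectors along the box's width and height axes.
    wx, wy = w / 2 * cos_a, w / 2 * sin_a
    hx, hy = -h / 2 * sin_a, h / 2 * cos_a
    return [
        cx + wx + hx, cy + wy + hy,
        cx + wx - hx, cy + wy - hy,
        cx - wx - hx, cy - wy - hy,
        cx - wx + hx, cy - wy + hy,
    ]


def format_label(cls, corners, conf=None):
    """Format one label line: class, eight coordinates, optional confidence."""
    values = [cls, *corners]
    if conf is not None:
        values.append(conf)
    return " ".join("%g" % v for v in values)
```

For an axis-aligned box (`angle = 0`) the four corners reduce to the usual `cx ± w/2`, `cy ± h/2` combinations, which is a quick way to sanity-check the output format.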

ultralytics/models/yolo/segment/train.py

+1-1
@@ -51,7 +51,7 @@ def plot_training_samples(self, batch, ni):
5151
batch['batch_idx'],
5252
batch['cls'].squeeze(-1),
5353
batch['bboxes'],
54-
batch['masks'],
54+
masks=batch['masks'],
5555
paths=batch['im_file'],
5656
fname=self.save_dir / f'train_batch{ni}.jpg',
5757
on_plot=self.on_plot)
