Skip to content

Commit ee40b24

Browse files
author
Martha Morrissey
authored
Merge pull request #149 from developmentseed/train_test_val
Add option to split data into train/test/validate sets
2 parents 9337612 + 28e2c29 commit ee40b24

File tree

5 files changed

+101
-16
lines changed

5 files changed

+101
-16
lines changed

docs/parameters.rst

+11
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,16 @@ Here is the full list of configuration parameters you can specify in a ``config.
4949
``'segmentation'``
5050
Output is an array of shape ``(256, 256)`` with values matching the class index label at that position. The classes are applied sequentially according to ``config.json`` so later classes will be written over earlier class labels if there is overlap.
5151

52+
**seed**: int
53+
Random generator seed. Optional, use to make results reproducible.
54+
55+
**split_vals**: list
56+
Default: `[0.8, 0.2]`
57+
Proportion of data to put in each category listed in `split_names`. Must be a list of floats that sum to one and match the length of `split_names`. For train, validate, and test data, a list like `[0.7, 0.2, 0.1]` is suggested.
58+
59+
**split_names**: list
60+
Default: `['train', 'test']`
61+
List of names for each subset of the data. Length of list must match length of `split_vals`.
62+
5263
**imagery_offset**: list of ints
5364
An optional list of integers representing the number of pixels to offset imagery. For example ``[15, -5]`` will move the images 15 pixels right and 5 pixels up relative to the requested tile bounds.

label_maker/package.py

+38-15
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from label_maker.utils import is_tif
1010

1111

12-
def package_directory(dest_folder, classes, imagery, ml_type, seed=False, train_size=0.8, **kwargs):
12+
def package_directory(dest_folder, classes, imagery, ml_type, seed=False, split_names=['train', 'test'],
13+
split_vals=[0.8, .2], **kwargs):
1314
"""Generate an .npz file containing arrays for training machine learning algorithms
1415
1516
Parameters
@@ -28,16 +29,26 @@ def package_directory(dest_folder, classes, imagery, ml_type, seed=False, train_
2829
ml_type: str
2930
Defines the type of machine learning. One of "classification", "object-detection", or "segmentation"
3031
seed: int
31-
Random generator seed. Optional, use to make results reproducable.
32-
train_size: float
33-
Portion of the data to use in training, the remainder is used as test data (default 0.8)
32+
Random generator seed. Optional, use to make results reproducible.
33+
split_vals: list
34+
Default: [0.8, 0.2]
35+
Proportion of data to put in each category listed in split_names.
36+
Must be floats and must sum to one.
37+
split_names: list
38+
Default: ['train', 'test']
39+
List of names for each subset of the data.
3440
**kwargs: dict
3541
Other properties from CLI config passed as keywords to other utility functions
3642
"""
3743
# if a seed is given, use it
3844
if seed:
3945
np.random.seed(seed)
4046

47+
if len(split_names) != len(split_vals):
48+
raise ValueError('`split_names` and `split_vals` must be the same length. Please update your config.')
49+
if not np.isclose(sum(split_vals), 1):
50+
raise ValueError('`split_vals` must sum to one. Please update your config.')
51+
4152
# open labels file, create tile array
4253
labels_file = op.join(dest_folder, 'labels.npz')
4354
labels = np.load(labels_file)
@@ -60,7 +71,7 @@ def package_directory(dest_folder, classes, imagery, ml_type, seed=False, train_
6071
# open the images and load those plus the labels into the final arrays
6172
o = urlparse(imagery)
6273
_, image_format = op.splitext(o.path)
63-
if is_tif(imagery): # if a TIF is provided, use jpg as tile format
74+
if is_tif(imagery): # if a TIF is provided, use jpg as tile format
6475
image_format = '.jpg'
6576
for tile in tiles:
6677
image_file = op.join(dest_folder, 'tiles', '{}{}'.format(tile, image_format))
@@ -86,16 +97,28 @@ def package_directory(dest_folder, classes, imagery, ml_type, seed=False, train_
8697
elif ml_type == 'segmentation':
8798
y_vals.append(labels[tile][..., np.newaxis]) # Add grayscale channel
8899

89-
# split into train and test
90-
split_index = int(len(x_vals) * train_size)
91-
92-
# convert lists to numpy arrays
100+
# Convert lists to numpy arrays
93101
x_vals = np.array(x_vals, dtype=np.uint8)
94102
y_vals = np.array(y_vals, dtype=np.uint8)
95103

96-
print('Saving packaged file to {}'.format(op.join(dest_folder, 'data.npz')))
97-
np.savez(op.join(dest_folder, 'data.npz'),
98-
x_train=x_vals[:split_index, ...],
99-
y_train=y_vals[:split_index, ...],
100-
x_test=x_vals[split_index:, ...],
101-
y_test=y_vals[split_index:, ...])
104+
# Get number of data samples per split from the float proportions
105+
split_n_samps = [len(x_vals) * val for val in split_vals]
106+
107+
if np.any(split_n_samps == 0):
108+
raise ValueError('split must not generate zero samples per partition, change ratio of values in config file.')
109+
110+
# Convert into a cumulative sum to get indices
111+
split_inds = np.cumsum(split_n_samps).astype(np.integer)
112+
113+
# Exclude last index as `np.split` handles splitting without that value
114+
split_arrs_x = np.split(x_vals, split_inds[:-1])
115+
split_arrs_y = np.split(y_vals, split_inds[:-1])
116+
117+
save_dict = {}
118+
119+
for si, split_name in enumerate(split_names):
120+
save_dict[f'x_{split_name}'] = split_arrs_x[si]
121+
save_dict[f'y_{split_name}'] = split_arrs_y[si]
122+
123+
np.savez(op.join(dest_folder, 'data.npz'), **save_dict)
124+
print('Saving packaged file to {}'.format(op.join(dest_folder, 'data.npz')))

label_maker/validate.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,7 @@
3030
'background_ratio': {'type': 'float'},
3131
'ml_type': {'allowed': ['classification', 'object-detection', 'segmentation'], 'required': True},
3232
'seed': {'type': 'integer'},
33-
'imagery_offset': {'type': 'list', 'schema': {'type': 'integer'}, 'minlength': 2, 'maxlength': 2}
33+
'imagery_offset': {'type': 'list', 'schema': {'type': 'integer'}, 'minlength': 2, 'maxlength': 2},
34+
'split_vals': {'type': 'list', 'schema': {'type': 'float'}},
35+
'split_names': {'type': 'list', 'schema': {'type': 'string'}}
3436
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{"country": "portugal",
2+
"bounding_box": [
3+
-9.4575,
4+
38.8467,
5+
-9.4510,
6+
38.8513
7+
],
8+
"zoom": 17,
9+
"classes": [
10+
{ "name": "Water Tower", "filter": ["==", "man_made", "water_tower"] },
11+
{ "name": "Building", "filter": ["has", "building"] },
12+
{ "name": "Farmland", "filter": ["==", "landuse", "farmland"] },
13+
{ "name": "Ruins", "filter": ["==", "historic", "ruins"] },
14+
{ "name": "Parking", "filter": ["==", "amenity", "parking"] },
15+
{ "name": "Roads", "filter": ["has", "highway"] }
16+
],
17+
"imagery": "https://api.mapbox.com/v4/mapbox.satellite/{z}/{x}/{y}.jpg?access_token=ACCESS_TOKEN",
18+
"background_ratio": 1,
19+
"ml_type": "classification",
20+
"seed": 19,
21+
"split_names": ["train", "test", "val"],
22+
"split_vals": [0.7, 0.2, 0.1]
23+
}

test/integration/test_classification_package.py

+26
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,24 @@
77

88
import numpy as np
99

10+
1011
class TestClassificationPackage(unittest.TestCase):
1112
"""Tests for classification package creation"""
13+
1214
@classmethod
1315
def setUpClass(cls):
1416
makedirs('integration-cl')
1517
copyfile('test/fixtures/integration/labels-cl.npz', 'integration-cl/labels.npz')
1618
copytree('test/fixtures/integration/tiles', 'integration-cl/tiles')
1719

20+
makedirs('integration-cl-split')
21+
copyfile('test/fixtures/integration/labels-cl.npz', 'integration-cl-split/labels.npz')
22+
copytree('test/fixtures/integration/tiles', 'integration-cl-split/tiles')
23+
1824
@classmethod
1925
def tearDownClass(cls):
2026
rmtree('integration-cl')
27+
rmtree('integration-cl-split')
2128

2229
def test_cli(self):
2330
"""Verify data.npz produced by CLI"""
@@ -48,3 +55,22 @@ def test_cli(self):
4855
[0, 0, 0, 0, 0, 0, 1]]
4956
)
5057
self.assertTrue(np.array_equal(data['y_test'], expected_y_test))
58+
59+
def test_cli_3way_split(self):
60+
"""Verify data.npz produced by CLI when split into train/test/val"""
61+
62+
cmd = 'label-maker package --dest integration-cl-split --config test/fixtures/integration/config_3way.integration.json'
63+
cmd = cmd.split(' ')
64+
subprocess.run(cmd, universal_newlines=True)
65+
66+
data = np.load('integration-cl-split/data.npz')
67+
68+
# validate our image data with shapes
69+
self.assertEqual(data['x_train'].shape, (5, 256, 256, 3))
70+
self.assertEqual(data['x_test'].shape, (2, 256, 256, 3))
71+
self.assertEqual(data['x_val'].shape, (1, 256, 256, 3))
72+
73+
# validate label data with shapes
74+
self.assertEqual(data['y_train'].shape, (5, 7))
75+
self.assertEqual(data['y_test'].shape, (2, 7))
76+
self.assertEqual(data['y_val'].shape, (1, 7))

0 commit comments

Comments
 (0)