-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathclassify.blog_post.py
100 lines (80 loc) · 3.81 KB
/
classify.blog_post.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import os
from osgeo import gdal
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
# A list of "random" colors (for a nicer output), given as "#RRGGBB" hex
# strings.
# NOTE(review): COLORS is not referenced anywhere in the visible part of this
# file -- presumably meant as a palette for the class values; confirm or
# remove.
COLORS = ["#000000", "#FFFF00", "#1CE6FF", "#FF34FF", "#FF4A46", "#008941"]
def create_mask_from_vector(vector_data_path, cols, rows, geo_transform,
                            projection, target_value=1):
    """Rasterize the given vector (wrapper for gdal.RasterizeLayer).

    Opens ``vector_data_path`` as a vector dataset, takes its first layer and
    burns ``target_value`` into band 1 of a new single-band, in-memory UInt16
    raster of ``cols`` x ``rows`` pixels.

    Args:
        vector_data_path: Path handed to ``gdal.OpenEx`` in vector mode.
        cols: Width of the output raster, in pixels.
        rows: Height of the output raster, in pixels.
        geo_transform: Value passed to ``SetGeoTransform`` on the output.
        projection: Value passed to ``SetProjection`` on the output.
        target_value: Burn value handed to ``gdal.RasterizeLayer``.

    Returns:
        The in-memory GDAL dataset that holds the rasterized mask.

    Raises:
        RuntimeError: If the vector file cannot be opened or has no layer.
    """
    data_source = gdal.OpenEx(vector_data_path, gdal.OF_VECTOR)
    # Unless GDAL exceptions are enabled, a failed open returns None; fail
    # here with a clear message instead of an AttributeError further down.
    if data_source is None:
        raise RuntimeError(
            "Could not open vector file: %s" % vector_data_path)
    layer = data_source.GetLayer(0)
    if layer is None:
        raise RuntimeError(
            "Vector file has no layers: %s" % vector_data_path)

    driver = gdal.GetDriverByName('MEM')  # In memory dataset
    target_ds = driver.Create('', cols, rows, 1, gdal.GDT_UInt16)
    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(projection)
    gdal.RasterizeLayer(target_ds, [1], layer, burn_values=[target_value])
    return target_ds
def vectors_to_raster(file_paths, rows, cols, geo_transform, projection):
    """Rasterize all the given vector files into a single labeled image.

    The vector at position ``i`` of ``file_paths`` is rasterized with the
    label ``i + 1``; pixels not touched by any vector keep the value 0.

    Args:
        file_paths: Sequence of vector file paths, one per class.
        rows: Height of the output array, in pixels.
        cols: Width of the output array, in pixels.
        geo_transform: Forwarded to ``create_mask_from_vector``.
        projection: Forwarded to ``create_mask_from_vector``.

    Returns:
        A ``(rows, cols)`` NumPy array (default float dtype) of class labels.
    """
    labeled_pixels = np.zeros((rows, cols))
    for label, path in enumerate(file_paths, start=1):
        ds = create_mask_from_vector(path, cols, rows, geo_transform,
                                     projection, target_value=label)
        mask = ds.GetRasterBand(1).ReadAsArray()
        # Assign instead of accumulating: summing the masks would turn a
        # pixel covered by classes i and j into the bogus label i + j.
        # Where vectors overlap, the later file in ``file_paths`` wins.
        labeled_pixels[mask != 0] = label
        ds = None  # Release the in-memory dataset
    return labeled_pixels
def write_geotiff(fname, data, geo_transform, projection):
    """Write ``data`` to ``fname`` as a single-band, Byte-typed GeoTIFF.

    ``data`` must expose a two-element ``shape`` (rows, cols); the given
    geo-transform and projection are stored on the created dataset.
    """
    n_rows, n_cols = data.shape
    gtiff = gdal.GetDriverByName('GTiff')
    out_ds = gtiff.Create(fname, n_cols, n_rows, 1, gdal.GDT_Byte)
    out_ds.SetGeoTransform(geo_transform)
    out_ds.SetProjection(projection)
    out_ds.GetRasterBand(1).WriteArray(data)
    out_ds = None  # Drop the reference so the file gets closed
# Input/output locations. Training and validation directories are expected to
# contain one shapefile per class, with matching file names in both.
raster_data_path = "data/image/2298119ene2016recorteTT.tif"
output_fname = "classification.tiff"
train_data_path = "data/train"
validation_data_path = "data/test"

# Read every band of the input raster into a (rows, cols, n_bands) array.
raster_dataset = gdal.Open(raster_data_path, gdal.GA_ReadOnly)
if raster_dataset is None:
    # gdal.Open returns None on failure unless GDAL exceptions are enabled.
    raise RuntimeError("Could not open raster: %s" % raster_data_path)
geo_transform = raster_dataset.GetGeoTransform()
proj = raster_dataset.GetProjectionRef()
bands_data = []
for b in range(1, raster_dataset.RasterCount+1):
    band = raster_dataset.GetRasterBand(b)
    bands_data.append(band.ReadAsArray())
bands_data = np.dstack(bands_data)
rows, cols, n_bands = bands_data.shape

# One class per training shapefile. Sort the listing so the class -> label
# numbering does not depend on the arbitrary order of os.listdir, and use
# splitext so file names that contain extra dots keep their full stem.
files = sorted(f for f in os.listdir(train_data_path) if f.endswith('.shp'))
if not files:
    raise RuntimeError("No .shp training files found in: %s" % train_data_path)
classes = [os.path.splitext(f)[0] for f in files]
shapefiles = [os.path.join(train_data_path, f) for f in files]

# Pixels with a non-zero label are the training samples.
labeled_pixels = vectors_to_raster(shapefiles, rows, cols, geo_transform, proj)
is_train = np.nonzero(labeled_pixels)
training_labels = labeled_pixels[is_train]
training_samples = bands_data[is_train]

classifier = RandomForestClassifier(n_jobs=4, n_estimators=10)
classifier.fit(training_samples, training_labels)

# Classify every pixel of the image and write the result as a GeoTIFF.
n_samples = rows*cols
flat_pixels = bands_data.reshape((n_samples, n_bands))
result = classifier.predict(flat_pixels)
classification = result.reshape((rows, cols))
write_geotiff(output_fname, classification, geo_transform, proj)

# Validate against the shapefiles of the same names in the validation
# directory, rasterized in the same class order as the training data.
shapefiles = [os.path.join(validation_data_path, "%s.shp"%c) for c in classes]
verification_pixels = vectors_to_raster(shapefiles, rows, cols, geo_transform, proj)
for_verification = np.nonzero(verification_pixels)
verification_labels = verification_pixels[for_verification]
predicted_labels = classification[for_verification]

# Pass the full label set explicitly: without it, a class that is absent from
# the validation pixels shrinks the confusion matrix and makes
# classification_report fail because target_names no longer lines up.
class_labels = np.arange(1, len(classes) + 1).astype(verification_labels.dtype)

print("Confussion matrix:\n%s" %
      metrics.confusion_matrix(verification_labels, predicted_labels,
                               labels=class_labels))
target_names = ['Class %s' % s for s in classes]
print("Classification report:\n%s" %
      metrics.classification_report(verification_labels, predicted_labels,
                                    labels=class_labels,
                                    target_names=target_names))
print("Classification accuracy: %f" %
      metrics.accuracy_score(verification_labels, predicted_labels))