-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgdal_classes.py
71 lines (56 loc) · 2.57 KB
/
gdal_classes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import gdal
import numpy as np
class GDALRaster(object):
    """Thin wrapper around a GDAL dataset that caches its origin and extent.

    Attributes:
        obj: The dataset object returned by ``gdal.Open``.
        left: First geotransform term (x of the raster's upper-left corner).
        top: Fourth geotransform term (y of the raster's upper-left corner).
        cell_size: Second geotransform term (pixel width). It is used for
            both axes below, so square pixels are assumed.
        shape: ``np.array([RasterXSize, RasterYSize])`` -- columns, rows.
        x_size, y_size: ``shape * cell_size``, i.e. the raster's width and
            height in the same units as ``cell_size``.
    """

    def __init__(self, path):
        """Open the raster at *path* and read its georeferencing.

        Raises:
            RuntimeError: If GDAL cannot open *path*.
        """
        self.obj = gdal.Open(path)
        if self.obj is None:
            # Unless gdal.UseExceptions() is active, Open() reports failure by
            # returning None; fail here with a clear message instead of an
            # AttributeError on the next line.
            raise RuntimeError("GDAL could not open raster: {}".format(path))

        # Geotransform layout: (left, pixel width, row rotation,
        #                       top, column rotation, pixel height).
        # Only the origin and the pixel width are kept.
        self.left, self.cell_size, _, self.top, *_ = self.obj.GetGeoTransform()
        self.shape = np.array([self.obj.RasterXSize, self.obj.RasterYSize])
        self.x_size, self.y_size = self.shape * self.cell_size
class RasterSelection(object):
    """A rectangular window of band 1 of a raster, held in memory.

    Attributes:
        left, top: Map coordinates of the window's upper-left pixel corner.
        cell_size: Pixel size copied from the source raster.
        x_pixels, y_pixels: Window width and height in pixels.
        bounds: ``[x_offset, y_offset, width, height]`` in pixels, as passed
            to ``ReadAsArray``.
        driver: GDAL driver of the source dataset, reused for output.
        projection: Spatial reference string of the source dataset.
        array: The pixel values read from band 1.
    """

    # Guards against float error when a distance that should be a whole
    # number of cells comes out a hair below it (e.g. 9.999999999 -> 10).
    _SNAP_TOLERANCE = 1e-9

    def __init__(self, raster, x_offset=0, y_offset=0, x_max=0, y_max=0):
        """Read a window from *raster*.

        Args:
            raster: Object exposing ``obj`` (a GDAL dataset), ``left``,
                ``top``, ``cell_size``, ``x_size`` and ``y_size``.
            x_offset, y_offset: Distance of the window's upper-left corner
                from the raster's upper-left corner, in map units.
            x_max, y_max: Width and height of the window in map units
                (sizes, despite the names).

        If all four window arguments are falsy the whole raster is selected.
        """
        if not any((x_offset, y_offset, x_max, y_max)):  # whole raster
            x_max, y_max = raster.x_size, raster.y_size

        self.cell_size = raster.cell_size

        # Snap BOTH edges of the window to the pixel grid with the same rule.
        # Deriving the size from the snapped edges (rather than truncating
        # the size on its own) makes adjacent windows tile the raster with no
        # dropped rows/columns when the window size is not a whole number of
        # cells.
        col_start = self._to_pixel(x_offset)
        row_start = self._to_pixel(y_offset)
        col_stop = self._to_pixel(x_offset + x_max)
        row_stop = self._to_pixel(y_offset + y_max)

        self.x_pixels = col_stop - col_start
        self.y_pixels = row_stop - row_start
        self.bounds = [col_start, row_start, self.x_pixels, self.y_pixels]

        # Georeference the window at the pixel the data is actually read
        # from, not at the unsnapped offset, so the output is not shifted.
        self.left = raster.left + col_start * self.cell_size
        self.top = raster.top - row_start * self.cell_size

        # Reuse the source dataset's driver and spatial reference for output.
        self.driver = raster.obj.GetDriver()
        self.projection = raster.obj.GetProjection()

        # Fetch the array
        band = raster.obj.GetRasterBand(1)  # 1 band raster
        self.array = band.ReadAsArray(*self.bounds)

    def _to_pixel(self, distance):
        """Convert a map-unit distance to a whole number of cells (truncated)."""
        return int(distance / self.cell_size + self._SNAP_TOLERANCE)

    def write(self, out_path):
        """Write the window to *out_path* as a single-band Int32 raster.

        Zero is registered as the no-data value.

        Raises:
            RuntimeError: If the driver cannot create *out_path*.
        """
        out_raster = self.driver.Create(out_path, self.x_pixels, self.y_pixels, 1, gdal.GDT_Int32)
        if out_raster is None:
            # Create() returns None on failure unless GDAL exceptions are on.
            raise RuntimeError("GDAL could not create raster: {}".format(out_path))

        # North-up geotransform: no rotation, negative pixel height.
        out_raster.SetGeoTransform((self.left, self.cell_size, 0, self.top, 0, -self.cell_size))
        if self.projection:
            # Without this the output carries no coordinate reference system.
            out_raster.SetProjection(self.projection)

        out_band = out_raster.GetRasterBand(1)
        out_band.SetNoDataValue(0)
        out_band.WriteArray(self.array, 0, 0)
        out_band.FlushCache()
def make_tiles(x_size, y_size, tile_size):
    """Yield ``(x, y, width, height)`` windows that cover an extent.

    The extent runs from 0 to ``int(x_size)`` and 0 to ``int(y_size)``.
    Windows are ``tile_size`` on each side, except those in the last column
    or row, which are cut short at the edge of the extent. Tiles are
    produced column by column, each column from its first row to its last.
    """
    def edges(extent):
        # Tile boundaries along one axis, always ending at the extent itself.
        limit = int(extent)
        return [*range(0, limit, tile_size), limit]

    columns = edges(x_size)
    rows = edges(y_size)
    for x_start, x_stop in zip(columns, columns[1:]):
        for y_start, y_stop in zip(rows, rows[1:]):
            yield (x_start, y_start, x_stop - x_start, y_stop - y_start)
# Path to the raster file
cdl_path = os.path.join("GIS", "nass_de", "cdl_30m_r_de_2016_utm18.tif")

# Initialize the CDL raster
cdl_raster = GDALRaster(cdl_path)

# Initialize a set of tiles to break up the raster.
# NOTE(review): the previous comment here read "25 km", but 250000 is 250 km
# if the raster's map units are metres (25 km would be 25000) -- confirm which
# one was intended. The value is left unchanged.
tile_size = 250000
tiles = make_tiles(cdl_raster.x_size, cdl_raster.y_size, tile_size)

# GDAL does not create missing parent folders, so make sure the output
# directory exists before anything is written to it.
os.makedirs(os.path.join("GIS", "tiles"), exist_ok=True)

# Save tiles
out_tile = os.path.join("GIS", "tiles", "tile_{}.tif")
for counter, tile in enumerate(tiles):
    print(counter)
    sample = RasterSelection(cdl_raster, *tile)
    sample.write(out_tile.format(counter))

# Make a raster of corn pixels: re-read the whole raster, keep cells whose
# value is 1 and zero out everything else (0 is written as the no-data value).
corn_raster = os.path.join("GIS", "tiles", "delaware_corn.tif")
sample = RasterSelection(cdl_raster)
sample.array[sample.array != 1] = 0
sample.write(corn_raster)