Skip to content

Commit

Permalink
[Feature] Add slider in cd (PaddlePaddle#92)
Browse files Browse the repository at this point in the history
* [Feature] Add cd slider

* [Fix] Tuple instead of list

* [Fix] Spell repair

* [Fix] Spell repair
  • Loading branch information
geoyee authored Jul 13, 2022
1 parent ec9d58b commit 05e1eee
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
83 changes: 83 additions & 0 deletions paddlers/tasks/change_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import math
import os
import os.path as osp
from collections import OrderedDict
from operator import attrgetter
Expand Down Expand Up @@ -545,6 +546,88 @@ def predict(self, img_file, transforms=None):
}
return prediction

def slider_predict(self, img_file, save_dir, block_size, overlap=36, transforms=None):
"""
Do inference.
Args:
Args:
img_file(List[str]):
List of image paths.
save_dir(str):
Directory that contains saved geotiff file.
block_size(List[int] or Tuple[int], int):
The size of block.
overlap(List[int] or Tuple[int], int):
The overlap between two blocks. Defaults to 36.
transforms(paddlers.transforms.Compose or None, optional):
Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
"""
try:
from osgeo import gdal
except:
import gdal

if len(img_file) != 2:
raise ValueError("`img_file` must be a list of length 2.")
if isinstance(block_size, int):
block_size = (block_size, block_size)
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.")

src1_data = gdal.Open(img_file[0])
src2_data = gdal.Open(img_file[1])
width = src1_data.RasterXSize
height = src1_data.RasterYSize
bands = src1_data.RasterCount

driver = gdal.GetDriverByName("GTiff")
file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[0] + ".tif"
if not osp.exists(save_dir):
os.makedirs(save_dir)
save_file = osp.join(save_dir, file_name)
dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
dst_data.SetGeoTransform(src1_data.GetGeoTransform())
dst_data.SetProjection(src1_data.GetProjection())
band = dst_data.GetRasterBand(1)
band.WriteArray(255 * np.ones((height, width), dtype="uint8"))

step = np.array(block_size) - np.array(overlap)
for yoff in range(0, height, step[1]):
for xoff in range(0, width, step[0]):
xsize, ysize = block_size
if xoff + xsize > width:
xsize = int(width - xoff)
if yoff + ysize > height:
ysize = int(height - yoff)
im1 = src1_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
im2 = src2_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
# fill
h, w = im1.shape[:2]
im1_fill = np.zeros((block_size[1], block_size[0], bands), dtype=im1.dtype)
im2_fill = im1_fill.copy()
im1_fill[:h, :w, :] = im1
im2_fill[:h, :w, :] = im2
im_fill = (im1_fill, im2_fill)
# predict
pred = self.predict(im_fill, transforms)["label_map"].astype("uint8")
# overlap
rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
temp = pred[:h, :w].copy()
temp[mask == False] = 0
band.WriteArray(temp, int(xoff), int(yoff))
dst_data.FlushCache()
dst_data = None
print("GeoTiff saved in {}.".format(save_file))

def _preprocess(self, images, transforms, to_tensor=True):
arrange_transforms(
model_type=self.model_type, transforms=transforms, mode='test')
Expand Down
6 changes: 3 additions & 3 deletions paddlers/tasks/segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def slider_predict(self, img_file, save_dir, block_size, overlap=36, transforms=
img_file(str):
Image path.
save_dir(str):
Folder of geotiff saved.
Directory that contains saved geotiff file.
block_size(List[int] or Tuple[int], int):
The size of block.
overlap(List[int] or Tuple[int], int):
Expand All @@ -545,13 +545,13 @@ def slider_predict(self, img_file, save_dir, block_size, overlap=36, transforms=
elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
block_size = tuple(block_size)
else:
raise ValueError("`block_size` must be a tuple/list of length 2 or a integer.")
raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.")
if isinstance(overlap, int):
overlap = (overlap, overlap)
elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
overlap = tuple(overlap)
else:
raise ValueError("`overlap` must be a tuple/list of length 2 or a integer.")
raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.")

src_data = gdal.Open(img_file)
width = src_data.RasterXSize
Expand Down

0 comments on commit 05e1eee

Please sign in to comment.