From a5404ea284a022d8d8cb1a7595e123e825be5bf9 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Fri, 28 Jul 2023 16:20:24 +1200 Subject: [PATCH 1/5] Config for irrigation_scenes and custom SpatioTemporalDataset loader Initial mmsegmentation configuration file for the irrigation_scenes dataset on https://huggingface.co/datasets/ibm-nasa-geospatial/hls_irrigation_scenes. As this is a time-series dataset with data from four months stored in four different folders, a custom SpatioTemporalDataset class (subclassed from GeospatialDataset) and LoadSpatioTemporalImagesFromFile class (subclassed from LoadGeospatialImageFromFile) was created to perform the data loading. Training with only the first 3 months (June, July, August) for now. Also updated the fine-tuning-examples/README.md to mention how to run the irrigation_scenes setup. --- fine-tuning-examples/README.md | 35 +- .../configs/irrigation_scenes_config.py | 334 ++++++++++++++++++ .../geospatial_fm/datasets.py | 54 ++- .../geospatial_fm/geospatial_pipelines.py | 82 ++++- 4 files changed, 484 insertions(+), 21 deletions(-) create mode 100644 fine-tuning-examples/configs/irrigation_scenes_config.py diff --git a/fine-tuning-examples/README.md b/fine-tuning-examples/README.md index 3381f4f..dd21fe5 100644 --- a/fine-tuning-examples/README.md +++ b/fine-tuning-examples/README.md @@ -24,33 +24,30 @@ We provide a simple architecture in [the configuration file](./configs/config.py 3. `conda activate ` 4. `cd fine-tuning-examples` 5. Install torch and torchvision: `pip install torch==1.7.1 torchvision==0.8.2` (May vary with your system. Please check at https://pytorch.org/get-started/locally/) -6. `pip install .` +6. `pip install --editable .` 7. `pip install -U openmim` 8. `mim install mmcv-full==1.5.0` (This may take a while for torch > 1.7.1, as wheel must be built) ### Data -Download the flood detection dataset from [Sen1Floods11](https://github.com/cloudtostreet/Sen1Floods11). +- Download the flood detection dataset from [Sen1Floods11](https://github.com/cloudtostreet/Sen1Floods11). +- Download the fire scars detection dataset from [Hugging Face](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars). +- Download the irrigation_scenes dataset from [HuggingFace](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_irrigation_scenes) +## Running the code +1. Complete the configs with your setup specifications. Parts that must be completed are marked with `#TO BE DEFINED BY USER`. They relate to where you downloaded the dataset, pretrained model weights, test set (e.g. regular one or Bolivia out of bag data) and where you are going to save the experiment outputs. -Download the fire scars detection dataset from [Hugging Face](https://huggingface.co/datasets/nasa-impact/hls_burn_scars). +2. + a. With the conda env created above activated and from the `fine-tuning-examples` folder, run either of the commands below: + mim train mmsegmentation --launcher pytorch configs/sen1floods11_config.py + mim train mmsegmentation --launcher pytorch configs/firescars_config.py + mim train mmsegmentation --launcher pytorch configs/irrigation_scenes_config.py -## Running the code -1. Complete the configs with your setup specifications. Parts that must be completed are marked with `#TO BE DEFINED BY USER`. They relate to where you downloaded the dataset, pretrained model weights, test set (e.g. regular one or Bolivia out of bag data) and where you are going to save the experiment outputs. + b. To run testing: + + mim test mmsegmentation configs/sen1floods11_config.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU" + mim test mmsegmentation configs/firescars_config.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU" -2. - a. With the conda env created above activated and from the `fine-tuning-examples` folder, run: - - `mim train mmsegmentation --launcher pytorch configs/sen1floods11_config.py` or - - `mim train mmsegmentation --launcher pytorch configs/firescars_config.py` - - b. To run testing: - - `mim test mmsegmentation configs/sen1floods11_config.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"` or - - `mim test mmsegmentation configs/firescars_config.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"` - ## Additional documentation -This project builds on [MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/) and [MMCV](https://mmcv.readthedocs.io/en/v1.5.0/). For additional documentation, consult their docs (please note this is currently version 0.30.0 of MMSegmentation and version 1.5.0 of MMCV, not latest). \ No newline at end of file +This project builds on [MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/) and [MMCV](https://mmcv.readthedocs.io/en/v1.5.0/). For additional documentation, consult their docs (please note this is currently version 0.30.0 of MMSegmentation and version 1.5.0 of MMCV, not latest). diff --git a/fine-tuning-examples/configs/irrigation_scenes_config.py b/fine-tuning-examples/configs/irrigation_scenes_config.py new file mode 100644 index 0000000..2ef433d --- /dev/null +++ b/fine-tuning-examples/configs/irrigation_scenes_config.py @@ -0,0 +1,334 @@ +import os + +# base options +dist_params = dict(backend="nccl") +log_level = "INFO" +load_from = None +resume_from = None +cudnn_benchmark = True + +custom_imports = dict(imports=["geospatial_fm"]) + + +### Configs +# Data +# TO BE DEFINED BY USER: Data root to firescar downloaded dataset +data_root = "../data/irrigation_scenes/" + +dataset_type = "SpatioTemporalDataset" +num_classes = 1 +num_frames = int(os.getenv("NUM_FRAMES", 3)) +img_size = int(os.getenv("IMG_SIZE", 224)) +num_workers = int(os.getenv("DATA_LOADER_NUM_WORKERS", 2)) +samples_per_gpu = 1 +CLASSES = (0, 1) + +img_norm_cfg = dict( + means=[0.166, 0.166, 0.166, 0.166, 0.166, 0.166], + stds=[0.114, 0.114, 0.114, 0.114, 0.114, 0.114], +) +# Sentinel-2 Bands 2,3,4,8A,11,12 (Blue, Green, Red, NIR_Narrow, SWIR1, SWIR2) +bands = [0, 1, 2, 3, 4, 5] + +tile_size = img_size +orig_nsize = 512 +crop_size = (tile_size, tile_size) + +img_suffix = ".tif" +seg_map_suffix = ".tif" + + +# ignore_index = -1 +# image_nodata = -9999 +# image_nodata_replace = 0 +image_to_float32 = True + +# Model +# TO BE DEFINED BY USER: path to pretrained backbone weights +pretrained_weights_path = "../pretrain_ckpts/Prithvi_100M.pt" +num_layers = 12 +patch_size = 16 +embed_dim = 768 +num_heads = 12 +tubelet_size = 1 + +# TRAINING +# epochs=50 +# eval_epoch_interval = 5 + +# TO BE DEFINED BY USER: Save directory +experiment = "test_1" +project_dir = "../finetune_weights/irrigation_scenes" +work_dir = os.path.join(project_dir, experiment) +save_path = work_dir + +gpu_ids = [0] + +splits = { + "train": data_root + "training_chips/training_data.txt", + "val": data_root + "validation_chips/validation_data.txt", + "test": data_root + "validation_chips/validation_data.txt", +} + +# Pipelines +train_pipeline = [ + dict(type="LoadSpatioTemporalImagesFromFile", to_float32=image_to_float32), + dict( + type="LoadGeospatialAnnotations", + reduce_zero_label=False, + nodata=255, + nodata_replace=2, + ), + dict(type="RandomFlip", prob=0.5), + dict(type="ToTensor", keys=["img", "gt_semantic_seg"]), + dict(type="TorchNormalize", **img_norm_cfg), + dict(type="TorchRandomCrop", crop_size=crop_size), + dict( + type="Reshape", + keys=["img"], + new_shape=(len(bands), num_frames, tile_size, tile_size), + ), + dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)), + dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"), + dict(type="Collect", keys=["img", "gt_semantic_seg"]), +] + +val_pipeline = [ + dict(type="LoadSpatioTemporalImagesFromFile", to_float32=image_to_float32), + dict( + type="LoadGeospatialAnnotations", + reduce_zero_label=False, + nodata=255, + nodata_replace=2, + ), + dict(type="ToTensor", keys=["img", "gt_semantic_seg"]), + dict(type="TorchNormalize", **img_norm_cfg), + dict(type="TorchRandomCrop", crop_size=crop_size), + dict( + type="Reshape", + keys=["img"], + new_shape=(len(bands), num_frames, tile_size, tile_size), + ), + dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)), + dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"), + dict( + type="Collect", + keys=["img", "gt_semantic_seg"], + meta_keys=[ + "img_info", + "ann_info", + "seg_fields", + "img_prefix", + "seg_prefix", + "filename", + "ori_filename", + "img", + "img_shape", + "ori_shape", + "pad_shape", + "scale_factor", + "img_norm_cfg", + "gt_semantic_seg", + ], + ), +] + +test_pipeline = [ + dict(type="LoadSpatioTemporalImagesFromFile", to_float32=image_to_float32), + dict(type="ToTensor", keys=["img"]), + dict(type="TorchNormalize", **img_norm_cfg), + dict( + type="Reshape", + keys=["img"], + new_shape=(len(bands), num_frames, -1, -1), + look_up={"2": 1, "3": 2}, + ), + dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"), + dict( + type="CollectTestList", + keys=["img"], + meta_keys=[ + "img_info", + "seg_fields", + "img_prefix", + "seg_prefix", + "filename", + "ori_filename", + "img", + "img_shape", + "ori_shape", + "pad_shape", + "scale_factor", + "img_norm_cfg", + ], + ), +] + +CLASSES = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + +data = dict( + samples_per_gpu=samples_per_gpu, + workers_per_gpu=4, + train=dict( + type=dataset_type, + # CLASSES=CLASSES, + reduce_zero_label=True, + data_root=data_root, + img_dir="month1", + ann_dir="masks", + pipeline=train_pipeline, + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + # split=splits["train"], + ), + val=dict( + type=dataset_type, + # CLASSES=CLASSES, + reduce_zero_label=True, + data_root=data_root, + img_dir="month1", + ann_dir="masks", + pipeline=val_pipeline, + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + # split=splits["val"], + ), + test=dict( + type=dataset_type, + # CLASSES=CLASSES, + reduce_zero_label=True, + data_root=data_root, + img_dir="month1", + ann_dir="masks", + pipeline=test_pipeline, + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + # split=splits["test"], + ), +) +# gt_seg_map_loader_cfg=dict(nodata=-1, nodata_replace=2))) + +# AdamW optimizer, no weight decay for position embedding & layer norm in backbone +optimizer = dict(type="Adam", lr=1.5e-5, betas=(0.9, 0.999), weight_decay=0.05) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy="poly", + warmup="linear", + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, + min_lr=0.0, + by_epoch=False, +) + +log_config = dict( + interval=20, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type="TensorboardLoggerHook", by_epoch=False), + ], +) + +checkpoint_config = dict(by_epoch=True, interval=10, out_dir=save_path) + +evaluation = dict( + interval=1180, metric="mIoU", pre_eval=True, save_best="mIoU", by_epoch=False +) +reduce_train_set = dict(reduce_train_set=False) +reduce_factor = dict(reduce_factor=1) + +optimizer_config = dict(grad_clip=None) + +runner = dict(type="IterBasedRunner", max_iters=10000) +workflow = [("train", 1)] + +norm_cfg = dict(type="BN", requires_grad=True) + +loss_weights_multi = [ + 1.5652886, + 0.46067129, + 0.59387921, + 0.48431193, + 0.65555127, + 0.73865282, + 0.77616475, + 3.46336277, + 1.01650963, + 1.87640752, + 1.52960976, + 1.49788817, + 57.55048277, + 1.97697006, + 2.34793961, + 0.83456613, +] + +# loss_func = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1, class_weight=loss_weights_multi) +loss_func = dict( + type="CrossEntropyLoss", + use_sigmoid=False, + class_weight=loss_weights_multi, + avg_non_ignore=True, +) + + +output_embed_dim = embed_dim * num_frames + +model = dict( + type="TemporalEncoderDecoder", + frozen_backbone=False, + backbone=dict( + type="TemporalViTEncoder", + pretrained=pretrained_weights_path, + img_size=img_size, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=1, + in_chans=len(bands), + embed_dim=embed_dim, + depth=num_layers, + num_heads=num_heads, + mlp_ratio=4.0, + norm_pix_loss=False, + ), + neck=dict( + type="ConvTransformerTokensToEmbeddingNeck", + embed_dim=embed_dim * num_frames, + output_embed_dim=output_embed_dim, + drop_cls_token=True, + Hp=img_size // patch_size, + Wp=img_size // patch_size, + ), + decode_head=dict( + num_classes=len(loss_weights_multi), + in_channels=output_embed_dim, + type="FCNHead", + in_index=-1, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=loss_func, + ), + auxiliary_head=dict( + num_classes=len(loss_weights_multi), + in_channels=output_embed_dim, + type="FCNHead", + in_index=-1, + channels=256, + num_convs=2, + concat_input=False, + dropout_ratio=0.1, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=loss_func, + ), + train_cfg=dict(), + test_cfg=dict( + mode="slide", + stride=(int(tile_size / 2), int(tile_size / 2)), + crop_size=(tile_size, tile_size), + ), +) diff --git a/fine-tuning-examples/geospatial_fm/datasets.py b/fine-tuning-examples/geospatial_fm/datasets.py index 76a63eb..c3b1173 100644 --- a/fine-tuning-examples/geospatial_fm/datasets.py +++ b/fine-tuning-examples/geospatial_fm/datasets.py @@ -22,4 +22,56 @@ def __init__(self, CLASSES=(0, 1), PALETTE=None, **kwargs): # ignore_index=2, **kwargs) - self.gt_seg_map_loader = LoadGeospatialAnnotations(reduce_zero_label=reduce_zero_label, **gt_seg_map_loader_cfg) \ No newline at end of file + self.gt_seg_map_loader = LoadGeospatialAnnotations( + reduce_zero_label=reduce_zero_label, **gt_seg_map_loader_cfg + ) + + +@DATASETS.register_module() +class SpatioTemporalDataset(GeospatialDataset): + """ + Time-series dataset for irrigation data at + https://huggingface.co/datasets/ibm-nasa-geospatial/hls_irrigation_scenes + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, split): + """Load annotation from directory. + + Args: + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. + ann_dir (str|None): Path to annotation directory. + seg_map_suffix (str|None): Suffix of segmentation maps. + split (str|None): Split txt file. If split is specified, only file + with suffix in the splits will be loaded. Otherwise, all images + in img_dir/ann_dir will be loaded. Default: None + + Returns: + list[dict]: All image info of dataset. + """ + + img_infos = [] + if split is not None: + raise NotImplementedError + else: + for img in self.file_client.list_dir_or_file( + dir_path=img_dir, list_dir=False, suffix=img_suffix, recursive=True + ): + # Get 'T10SFG_chip22.tif' basename from 'scene_m01_T10SFG_chip22.tif' + basename = "_".join(img.split(sep="_")[2:]) + img_info = dict( + filename_t1=f"scene_m01_{basename}", + filename_t2=f"scene_m02_{basename}", + filename_t3=f"scene_m03_{basename}", + filename_t4=f"scene_m04_{basename}", + ) + if ann_dir is not None: + seg_map = f"mask_{basename.replace(img_suffix, seg_map_suffix)}" + img_info["ann"] = dict(seg_map=seg_map) + img_infos.append(img_info) + img_infos = sorted(img_infos, key=lambda x: x["filename_t1"]) + + return img_infos diff --git a/fine-tuning-examples/geospatial_fm/geospatial_pipelines.py b/fine-tuning-examples/geospatial_fm/geospatial_pipelines.py index 25a97e1..c87ba27 100644 --- a/fine-tuning-examples/geospatial_fm/geospatial_pipelines.py +++ b/fine-tuning-examples/geospatial_fm/geospatial_pipelines.py @@ -287,12 +287,92 @@ def __repr__(self): return repr_str +@PIPELINES.register_module() +class LoadSpatioTemporalImagesFromFile(LoadGeospatialImageFromFile): + """ + Load a time-series dataset from multiple files. + + Currently hardcoded to assume that GeoTIFF files are structured in four + different 'monthX' folders like so: + + - month1/ + - scene_m01_XXXXXX_chip01.tif + - scene_m01_XXXXXX_chip02.tif + - month2/ + - scene_m02_XXXXXX_chip01.tif + - scene_m02_XXXXXX_chip02.tif + - month3/ + - scene_m03_XXXXXX_chip01.tif + - scene_m03_XXXXXX_chip02.tif + - month4/ + - scene_m04_XXXXXX_chip01.tif + - scene_m04_XXXXXX_chip02.tif + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __call__(self, results): + """ + Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + if results.get("img_prefix") is not None: + img_prefix = results["img_prefix"] + assert img_prefix.endswith("month1") + filenames = [ + osp.join(img_prefix, results["img_info"]["filename_t1"]), # June + osp.join( + img_prefix.replace("month1", "month2"), # July + results["img_info"]["filename_t2"], + ), + osp.join( + img_prefix.replace("month1", "month3"), # August + results["img_info"]["filename_t3"], + ), + # osp.join( + # img_prefix.replace("month1", "month4"), # September + # results["img_info"]["filename_t4"], + # ), + ] + else: + raise NotImplementedError + + img = np.stack(arrays=list(map(open_tiff, filenames)), axis=0) + # assert img.shape == (1, 6, 512, 512) # Time, Channels, Height, Width + if self.to_float32: + img = img.astype(dtype=np.float32) + if self.nodata is not None: + img = np.where(img == self.nodata, self.nodata_replace, img) + + results["filename"] = filenames[0] + results["ori_filename"] = results["img_info"]["filename_t1"] + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape + # Set initial values for default meta_keys + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 + results["flip"] = False + num_channels = 1 if len(img.shape) < 3 else img.shape[0] + results["img_norm_cfg"] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False, + ) + return results + + @PIPELINES.register_module() class LoadGeospatialAnnotations(object): """Load annotations for semantic segmentation. Args: - to_uint8 (bool): Whether to convert the loaded label to a uint8 reduce_zero_label (bool): Whether reduce all label value by 1. Usually used for datasets where 0 is background label. Default: False. From bdc65dab1f028216e7c1c055fd152ab0fd16955c Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:34:09 +1200 Subject: [PATCH 2/5] Move filepaths up a directory Config folder has moved from the fine-tuning-examples folder up to the root directory in 464e9f2f10a7563a5f12ca5d761e4d22f9d2c87c/#8, so no need to do `../` anymore. --- configs/irrigation_scenes_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/irrigation_scenes_config.py b/configs/irrigation_scenes_config.py index 2ef433d..c2ce7e8 100644 --- a/configs/irrigation_scenes_config.py +++ b/configs/irrigation_scenes_config.py @@ -13,7 +13,7 @@ ### Configs # Data # TO BE DEFINED BY USER: Data root to firescar downloaded dataset -data_root = "../data/irrigation_scenes/" +data_root = "data/irrigation_scenes/" dataset_type = "SpatioTemporalDataset" num_classes = 1 @@ -45,7 +45,7 @@ # Model # TO BE DEFINED BY USER: path to pretrained backbone weights -pretrained_weights_path = "../pretrain_ckpts/Prithvi_100M.pt" +pretrained_weights_path = "pretrain_ckpts/Prithvi_100M.pt" num_layers = 12 patch_size = 16 embed_dim = 768 @@ -58,7 +58,7 @@ # TO BE DEFINED BY USER: Save directory experiment = "test_1" -project_dir = "../finetune_weights/irrigation_scenes" +project_dir = "finetune_weights/irrigation_scenes" work_dir = os.path.join(project_dir, experiment) save_path = work_dir From 8339e35ae8eedd7201465e9a28e7f456d03da7f7 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:41:18 +1200 Subject: [PATCH 3/5] Handle channel first and channel last logic The old open_tiff function used rasterio.open which stacked the bands/channels in the first position (CHW), but moving to tiffile.imread in 86e9ba97f23fc4b6aa3797f6d3670bd20651f4c1 changed the stacking to the last position (HWC). Need to use channel last (NHWC) for the RandomFlip function since it is somewhat hardcoded to flip on axis 1, and then use TorchPermute to change to channel first (NCHW) so that TorchNormalize (using torchvision which expects BCHW) works. --- configs/irrigation_scenes_config.py | 24 +++++++++++++++++++++--- geospatial_fm/geospatial_pipelines.py | 6 ++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/configs/irrigation_scenes_config.py b/configs/irrigation_scenes_config.py index c2ce7e8..004e109 100644 --- a/configs/irrigation_scenes_config.py +++ b/configs/irrigation_scenes_config.py @@ -72,15 +72,24 @@ # Pipelines train_pipeline = [ - dict(type="LoadSpatioTemporalImagesFromFile", to_float32=image_to_float32), + dict( + type="LoadSpatioTemporalImagesFromFile", + to_float32=image_to_float32, + channels_last=True, + ), dict( type="LoadGeospatialAnnotations", reduce_zero_label=False, nodata=255, nodata_replace=2, ), - dict(type="RandomFlip", prob=0.5), + dict(type="RandomFlip", prob=0.5), # flip on axis 1, assume channel last NHWC dict(type="ToTensor", keys=["img", "gt_semantic_seg"]), + dict( + type="TorchPermute", + keys=["img"], + order=(0, 3, 1, 2), # channel last to channels first NCHW + ), dict(type="TorchNormalize", **img_norm_cfg), dict(type="TorchRandomCrop", crop_size=crop_size), dict( @@ -94,7 +103,11 @@ ] val_pipeline = [ - dict(type="LoadSpatioTemporalImagesFromFile", to_float32=image_to_float32), + dict( + type="LoadSpatioTemporalImagesFromFile", + to_float32=image_to_float32, + channels_last=True, + ), dict( type="LoadGeospatialAnnotations", reduce_zero_label=False, @@ -102,6 +115,11 @@ nodata_replace=2, ), dict(type="ToTensor", keys=["img", "gt_semantic_seg"]), + dict( + type="TorchPermute", + keys=["img"], + order=(0, 3, 1, 2), # channel last to channels first NCHW + ), dict(type="TorchNormalize", **img_norm_cfg), dict(type="TorchRandomCrop", crop_size=crop_size), dict( diff --git a/geospatial_fm/geospatial_pipelines.py b/geospatial_fm/geospatial_pipelines.py index 3f652ca..df2dd6e 100644 --- a/geospatial_fm/geospatial_pipelines.py +++ b/geospatial_fm/geospatial_pipelines.py @@ -3,7 +3,6 @@ """ import numpy as np import os.path as osp -import rasterio import torch import torchvision.transforms.functional as F @@ -373,7 +372,10 @@ def __call__(self, results): raise NotImplementedError img = np.stack(arrays=list(map(open_tiff, filenames)), axis=0) - # assert img.shape == (1, 6, 512, 512) # Time, Channels, Height, Width + assert img.shape == (3, 512, 512, 6) # Time, Height, Width, Channels + if not self.channels_last: + img = np.transpose(a=img, axes=(0, 2, 3, 1)) + assert img.shape == (3, 6, 512, 512) # Time, Channels, Height, Width if self.to_float32: img = img.astype(dtype=np.float32) if self.nodata is not None: From 4d9c8972ce98dfac63a372e7979f68ec1f7b7e2e Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 2 Aug 2023 15:49:43 +1200 Subject: [PATCH 4/5] Patch LoadGeospatialAnnotations to insert ann_info value from img_info Hacky way to avoid `KeyError: 'ann_info'` by setting `results["ann_info"]["seg_map"]` to `results["img_info"]["ann"]["seg_map"]`. Also edited docstring of the LoadGeospatialAnnotations class slightly. Cherry-picked from https://github.com/NASA-IMPACT/hls-foundation/commit/e5fb7abbf7c009149665b1224423a26a892467c9. --- geospatial_fm/geospatial_pipelines.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/geospatial_fm/geospatial_pipelines.py b/geospatial_fm/geospatial_pipelines.py index df2dd6e..ca55e6b 100644 --- a/geospatial_fm/geospatial_pipelines.py +++ b/geospatial_fm/geospatial_pipelines.py @@ -408,9 +408,8 @@ class LoadGeospatialAnnotations(object): Usually used for datasets where 0 is background label. Default: False. nodata (float/int): no data value to substitute to nodata_replace - nodata_replace (float/int): value to use to replace no data - - + nodata_replace (float/int): The value used to replace nodata values + with. Default: -1. """ def __init__( @@ -424,7 +423,10 @@ def __init__( self.nodata_replace = nodata_replace def __call__(self, results): - if results.get("seg_prefix", None) is not None: + if results.get("ann_info", {}).get("seg_map") is None: + results["ann_info"] = {"seg_map": results["img_info"]["ann"]["seg_map"]} + + if results.get("seg_prefix") is not None: filename = osp.join(results["seg_prefix"], results["ann_info"]["seg_map"]) else: filename = results["ann_info"]["seg_map"] From bf1cbfa9110f899dab0b74c8f0cb7c13e41f1a96 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 3 Aug 2023 16:44:03 +1200 Subject: [PATCH 5/5] Update test_pipeline config also Making sure that the test_pipeline is consistent with the training and validation pipeline. --- configs/irrigation_scenes_config.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/configs/irrigation_scenes_config.py b/configs/irrigation_scenes_config.py index 004e109..ae16488 100644 --- a/configs/irrigation_scenes_config.py +++ b/configs/irrigation_scenes_config.py @@ -152,14 +152,23 @@ ] test_pipeline = [ - dict(type="LoadSpatioTemporalImagesFromFile", to_float32=image_to_float32), + dict( + type="LoadSpatioTemporalImagesFromFile", + to_float32=image_to_float32, + channels_last=True, + ), dict(type="ToTensor", keys=["img"]), + dict( + type="TorchPermute", + keys=["img"], + order=(0, 3, 1, 2), # channel last to channels first NCHW + ), dict(type="TorchNormalize", **img_norm_cfg), + dict(type="TorchRandomCrop", crop_size=crop_size), # TODO remove hardcoded 224 size dict( type="Reshape", keys=["img"], - new_shape=(len(bands), num_frames, -1, -1), - look_up={"2": 1, "3": 2}, + new_shape=(len(bands), num_frames, tile_size, tile_size), ), dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"), dict(