2. Mound Classification in DTMs from Mars Arabia Terra

2.1. Installation guide

Clone the repository: https://github.com/epn-ml/Workshop-GMAP
Create the conda environment with: conda env create -f env.yml
Download the data: https://figshare.com/articles/dataset/Mound_Segmentation_Data/21180661
Download the saved models: https://figshare.com/articles/software/saved_models_zip/21180445
%pylab inline
#Geo-libraries
from osgeo import gdal
import rasterio
import rasterio as rio
from rasterio.plot import show
import rasterio.plot as rplot
import earthpy.spatial as es
from rasterio.plot import show
from rasterio import mask
import osgeo
from osgeo import osr
import richdem as rd
#import gdal
#import elevation 
import geopandas as gpd
from rasterio.windows import Window


#visualisation
import pandas as pd
from scipy import ndimage
import matplotlib.pyplot as plt
from PIL import Image

#matrix computation/ML
import numpy as np
import segmentation_models_pytorch as smp
import torch
import subprocess
import albumentations as albu
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR
from sklearn.metrics import confusion_matrix

#for normalising
from sklearn.preprocessing import MinMaxScaler

#utils
import os
from pathlib import Path
import shutil
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from torch import FloatTensor
import warnings
import copy
warnings.simplefilter(action = 'ignore', category = FutureWarning)
Populating the interactive namespace from numpy and matplotlib
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
/var/folders/x3/2bzh843n0tv469w6l6sd8sq00000gn/T/ipykernel_23922/3907882804.py in <module>
      1 get_ipython().run_line_magic('pylab', 'inline')
      2 #Geo-libraries
----> 3 from osgeo import gdal
      4 import rasterio
      5 import rasterio as rio

ModuleNotFoundError: No module named 'osgeo'

2.2. Load Data

dem_path = 'training_site/DEM/Low-res_DEM/'
file_path = dem_path + "Firsoff_HiRISE_5_low_res.tif-adj.tif"
gdal_data = gdal.Open(file_path)

#load raster
gdal_band = gdal_data.GetRasterBand(1)
nodataval = gdal_band.GetNoDataValue()

# convert to a numpy array
data_array=[]
data_array = gdal_band.ReadAsArray().astype(np.float64)
# data_array

# replace missing values if necessary
if np.any(data_array == nodataval):
    data_array[data_array == nodataval] = np.nan
plt.figure(figsize = (7, 12))
plt.imshow(data_array)
plt.colorbar()
plt.show()
../../../_images/Workshop-Mound-Segmentation_5_0.png

2.2.1. Plot out data with Matplotlib’s ‘contour’

fig = plt.figure(figsize = (5, 5))
ax = fig.add_subplot(111)
plt.contour(data_array)
plt.title("Elevation Contours of Firsoff HiRISE 5 (low res)")
cbar = plt.colorbar()
plt.gca().set_aspect('equal')
plt.show()
../../../_images/Workshop-Mound-Segmentation_7_0.png
#Plot image in grayscale
plt.figure(figsize=(5,5))
plt.imshow(data_array, cmap = 'gray')
plt.colorbar()
plt.title("Firsoff HiRISE 5 (low res) gray scale")
plt.show()


#Plot a contourf graph
plt.figure(figsize=(5,5))
plt.title("Elevation Contourf of Firsoff HiRISE 5 (low res)")
plt.contourf(data_array)
plt.colorbar()
plt.gca().set_aspect('equal')
plt.gca().invert_yaxis()
plt.show()
../../../_images/Workshop-Mound-Segmentation_8_0.png ../../../_images/Workshop-Mound-Segmentation_8_1.png
dem = gdal.Open(file_path)
slp = gdal.DEMProcessing(dem_path + "slope.tif", dem, "slope", computeEdges=True)
slparray = slp.GetRasterBand(1).ReadAsArray()

# Don't use matplotlib to visualize the slope
# Use richdem as shown below
plt.figure()
plt.imshow(slparray)
plt.colorbar()
plt.show()
../../../_images/Workshop-Mound-Segmentation_10_0.png

2.2.2. Creating a Hillshade

This is a technique where a lighting effect is added to a map based on elevation variations within the landscape.

hillshade = gdal.DEMProcessing(dem_path + "hillshade1.tif", dem, "hillshade")
hillshade_array = hillshade.GetRasterBand(1).ReadAsArray()

plt.figure()
plt.imshow(hillshade_array)
plt.colorbar()
plt.show()
../../../_images/Workshop-Mound-Segmentation_12_0.png

2.2.3. Visualise slope

This identifies the steepness at each cell of a raster surface.

rd_dem = rd.LoadGDAL(file_path)
rd_slope = rd.TerrainAttribute(rd_dem, attrib="slope_degrees")
rd.rdShow(rd_slope, axes=False, cmap='magma', figsize=(8, 5.5))
plt.show()

# To save the result
# rd.SaveGDAL("rd_slope.tif", rd_slope)
A Slope calculation (degrees)
C Horn, B.K.P., 1981. Hill shading and the reflectance map. Proceedings of the IEEE 69, 14–47. doi:10.1109/PROC.1981.11918



[                                                  ] (1% - 0.0s - 1 threads)
[=                                                 ] (2% - 0.0s - 1 threads)
[=                                                 ] (3% - 0.0s - 1 threads)
[==                                                ] (4% - 0.0s - 1 threads)
[==                                                ] (5% - 0.0s - 1 threads)
[===                                               ] (6% - 0.0s - 1 threads)
[===                                               ] (7% - 0.0s - 1 threads)
[====                                              ] (8% - 0.0s - 1 threads)
[====                                              ] (9% - 0.0s - 1 threads)
[=====                                             ] (10% - 0.0s - 1 threads)
[=====                                             ] (11% - 0.0s - 1 threads)
[======                                            ] (12% - 0.0s - 1 threads)
[======                                            ] (13% - 0.0s - 1 threads)
[=======                                           ] (14% - 0.0s - 1 threads)
[=======                                           ] (15% - 0.0s - 1 threads)
[========                                          ] (16% - 0.0s - 1 threads)
[========                                          ] (17% - 0.0s - 1 threads)
[=========                                         ] (18% - 0.0s - 1 threads)
[=========                                         ] (19% - 0.0s - 1 threads)
[==========                                        ] (20% - 0.0s - 1 threads)
[==========                                        ] (21% - 0.0s - 1 threads)
[===========                                       ] (22% - 0.0s - 1 threads)
[===========                                       ] (23% - 0.0s - 1 threads)
[============                                      ] (24% - 0.0s - 1 threads)
[============                                      ] (25% - 0.0s - 1 threads)
[=============                                     ] (26% - 0.0s - 1 threads)
[=============                                     ] (27% - 0.0s - 1 threads)
[==============                                    ] (28% - 0.0s - 1 threads)
[==============                                    ] (29% - 0.0s - 1 threads)
[===============                                   ] (30% - 0.0s - 1 threads)
[===============                                   ] (31% - 0.0s - 1 threads)
[================                                  ] (32% - 0.0s - 1 threads)
[================                                  ] (33% - 0.0s - 1 threads)
[=================                                 ] (34% - 0.0s - 1 threads)
[=================                                 ] (35% - 0.0s - 1 threads)
[==================                                ] (36% - 0.0s - 1 threads)
[==================                                ] (37% - 0.0s - 1 threads)
[===================                               ] (38% - 0.0s - 1 threads)
[===================                               ] (39% - 0.0s - 1 threads)
[====================                              ] (40% - 0.0s - 1 threads)
[====================                              ] (41% - 0.0s - 1 threads)
[=====================                             ] (42% - 0.0s - 1 threads)
[=====================                             ] (43% - 0.0s - 1 threads)
[======================                            ] (44% - 0.0s - 1 threads)
[======================                            ] (45% - 0.0s - 1 threads)
[=======================                           ] (46% - 0.0s - 1 threads)
[=======================                           ] (47% - 0.0s - 1 threads)
[========================                          ] (48% - 0.0s - 1 threads)
[========================                          ] (49% - 0.0s - 1 threads)
[=========================                         ] (50% - 0.0s - 1 threads)
[=========================                         ] (51% - 0.0s - 1 threads)
[==========================                        ] (52% - 0.0s - 1 threads)
[==========================                        ] (53% - 0.0s - 1 threads)
[===========================                       ] (54% - 0.0s - 1 threads)
[===========================                       ] (55% - 0.0s - 1 threads)
[============================                      ] (56% - 0.0s - 1 threads)
[============================                      ] (57% - 0.0s - 1 threads)
[=============================                     ] (58% - 0.0s - 1 threads)
[=============================                     ] (59% - 0.0s - 1 threads)
[==============================                    ] (60% - 0.0s - 1 threads)
[==============================                    ] (61% - 0.0s - 1 threads)
[===============================                   ] (62% - 0.0s - 1 threads)
[===============================                   ] (63% - 0.0s - 1 threads)
[================================                  ] (64% - 0.0s - 1 threads)
[================================                  ] (65% - 0.0s - 1 threads)
[=================================                 ] (66% - 0.0s - 1 threads)
[=================================                 ] (67% - 0.0s - 1 threads)
[==================================                ] (68% - 0.0s - 1 threads)
[==================================                ] (69% - 0.0s - 1 threads)
[===================================               ] (70% - 0.0s - 1 threads)
[===================================               ] (71% - 0.0s - 1 threads)
[====================================              ] (72% - 0.0s - 1 threads)
[====================================              ] (73% - 0.0s - 1 threads)
[=====================================             ] (74% - 0.0s - 1 threads)
[=====================================             ] (75% - 0.0s - 1 threads)
[======================================            ] (76% - 0.0s - 1 threads)
[======================================            ] (77% - 0.0s - 1 threads)
[=======================================           ] (78% - 0.0s - 1 threads)
[=======================================           ] (79% - 0.0s - 1 threads)
[========================================          ] (80% - 0.0s - 1 threads)
[========================================          ] (81% - 0.0s - 1 threads)
[=========================================         ] (82% - 0.0s - 1 threads)
[=========================================         ] (83% - 0.0s - 1 threads)
[==========================================        ] (84% - 0.0s - 1 threads)
[==========================================        ] (85% - 0.0s - 1 threads)
[===========================================       ] (86% - 0.0s - 1 threads)
[===========================================       ] (87% - 0.0s - 1 threads)
[============================================      ] (88% - 0.0s - 1 threads)
[============================================      ] (89% - 0.0s - 1 threads)
[=============================================     ] (90% - 0.0s - 1 threads)
[=============================================     ] (91% - 0.0s - 1 threads)
[==============================================    ] (92% - 0.0s - 1 threads)
[==============================================    ] (93% - 0.0s - 1 threads)
[===============================================   ] (94% - 0.0s - 1 threads)
[===============================================   ] (95% - 0.0s - 1 threads)
[================================================  ] (96% - 0.0s - 1 threads)
[================================================  ] (97% - 0.0s - 1 threads)
[================================================= ] (98% - 0.0s - 1 threads)
[================================================= ] (99% - 0.0s - 1 threads)
t Wall-time = 0.00775209
../../../_images/Workshop-Mound-Segmentation_14_1.png
aspect = rd.TerrainAttribute(rd_dem, attrib='aspect')
rd.rdShow(aspect, axes=False, cmap='jet', figsize=(8, 5.5))
plt.show()
A Aspect attribute calculation
C Horn, B.K.P., 1981. Hill shading and the reflectance map. Proceedings of the IEEE 69, 14–47. doi:10.1109/PROC.1981.11918



[                                                  ] (1% - 0.0s - 1 threads)
[=                                                 ] (2% - 0.0s - 1 threads)
[=                                                 ] (3% - 0.0s - 1 threads)
[==                                                ] (4% - 0.0s - 1 threads)
[==                                                ] (5% - 0.0s - 1 threads)
[===                                               ] (6% - 0.0s - 1 threads)
[===                                               ] (7% - 0.0s - 1 threads)
[====                                              ] (8% - 0.0s - 1 threads)
[====                                              ] (9% - 0.0s - 1 threads)
[=====                                             ] (10% - 0.0s - 1 threads)
[=====                                             ] (11% - 0.0s - 1 threads)
[======                                            ] (12% - 0.0s - 1 threads)
[======                                            ] (13% - 0.0s - 1 threads)
[=======                                           ] (14% - 0.0s - 1 threads)
[=======                                           ] (15% - 0.0s - 1 threads)
[========                                          ] (16% - 0.0s - 1 threads)
[========                                          ] (17% - 0.0s - 1 threads)
[=========                                         ] (18% - 0.0s - 1 threads)
[=========                                         ] (19% - 0.0s - 1 threads)
[==========                                        ] (20% - 0.0s - 1 threads)
[==========                                        ] (21% - 0.0s - 1 threads)
[===========                                       ] (22% - 0.0s - 1 threads)
[===========                                       ] (23% - 0.0s - 1 threads)
[============                                      ] (24% - 0.0s - 1 threads)
[============                                      ] (25% - 0.0s - 1 threads)
[=============                                     ] (26% - 0.0s - 1 threads)
[=============                                     ] (27% - 0.0s - 1 threads)
[==============                                    ] (28% - 0.0s - 1 threads)
[==============                                    ] (29% - 0.0s - 1 threads)
[===============                                   ] (30% - 0.0s - 1 threads)
[===============                                   ] (31% - 0.0s - 1 threads)
[================                                  ] (32% - 0.0s - 1 threads)
[================                                  ] (33% - 0.0s - 1 threads)
[=================                                 ] (34% - 0.0s - 1 threads)
[=================                                 ] (35% - 0.0s - 1 threads)
[==================                                ] (36% - 0.0s - 1 threads)
[==================                                ] (37% - 0.0s - 1 threads)
[===================                               ] (38% - 0.0s - 1 threads)
[===================                               ] (39% - 0.0s - 1 threads)
[====================                              ] (40% - 0.0s - 1 threads)
[====================                              ] (41% - 0.0s - 1 threads)
[=====================                             ] (42% - 0.0s - 1 threads)
[=====================                             ] (43% - 0.0s - 1 threads)
[======================                            ] (44% - 0.0s - 1 threads)
[======================                            ] (45% - 0.0s - 1 threads)
[=======================                           ] (46% - 0.0s - 1 threads)
[=======================                           ] (47% - 0.0s - 1 threads)
[========================                          ] (48% - 0.0s - 1 threads)
[========================                          ] (49% - 0.0s - 1 threads)
[=========================                         ] (50% - 0.0s - 1 threads)
[=========================                         ] (51% - 0.0s - 1 threads)
[==========================                        ] (52% - 0.0s - 1 threads)
[==========================                        ] (53% - 0.0s - 1 threads)
[===========================                       ] (54% - 0.0s - 1 threads)
[===========================                       ] (55% - 0.0s - 1 threads)
[============================                      ] (56% - 0.0s - 1 threads)
[============================                      ] (57% - 0.0s - 1 threads)
[=============================                     ] (58% - 0.0s - 1 threads)
[=============================                     ] (59% - 0.0s - 1 threads)
[==============================                    ] (60% - 0.0s - 1 threads)
[==============================                    ] (61% - 0.0s - 1 threads)
[===============================                   ] (62% - 0.0s - 1 threads)
[===============================                   ] (63% - 0.0s - 1 threads)
[================================                  ] (64% - 0.0s - 1 threads)
[================================                  ] (65% - 0.0s - 1 threads)
[=================================                 ] (66% - 0.0s - 1 threads)
[=================================                 ] (67% - 0.0s - 1 threads)
[==================================                ] (68% - 0.0s - 1 threads)
[==================================                ] (69% - 0.0s - 1 threads)
[===================================               ] (70% - 0.0s - 1 threads)
[===================================               ] (71% - 0.0s - 1 threads)
[====================================              ] (72% - 0.0s - 1 threads)
[====================================              ] (73% - 0.0s - 1 threads)
[=====================================             ] (74% - 0.0s - 1 threads)
[=====================================             ] (75% - 0.0s - 1 threads)
[======================================            ] (76% - 0.0s - 1 threads)
[======================================            ] (77% - 0.0s - 1 threads)
[=======================================           ] (78% - 0.0s - 1 threads)
[=======================================           ] (79% - 0.0s - 1 threads)
[========================================          ] (80% - 0.0s - 1 threads)
[========================================          ] (81% - 0.0s - 1 threads)
[=========================================         ] (82% - 0.0s - 1 threads)
[=========================================         ] (83% - 0.0s - 1 threads)
[==========================================        ] (84% - 0.0s - 1 threads)
[==========================================        ] (85% - 0.0s - 1 threads)
[===========================================       ] (86% - 0.0s - 1 threads)
[===========================================       ] (87% - 0.0s - 1 threads)
[============================================      ] (88% - 0.0s - 1 threads)
[============================================      ] (89% - 0.0s - 1 threads)
[=============================================     ] (90% - 0.0s - 1 threads)
[=============================================     ] (91% - 0.0s - 1 threads)
[==============================================    ] (92% - 0.0s - 1 threads)
[==============================================    ] (93% - 0.0s - 1 threads)
[===============================================   ] (94% - 0.0s - 1 threads)
[===============================================   ] (95% - 0.0s - 1 threads)
[================================================  ] (96% - 0.0s - 1 threads)
[================================================  ] (97% - 0.0s - 1 threads)
[================================================= ] (98% - 0.0s - 1 threads)
[================================================= ] (99% - 0.0s - 1 threads)
t Wall-time = 0.0103459
../../../_images/Workshop-Mound-Segmentation_15_1.png
dataset = gdal.Open(file_path)

# Getting som info about our DEM
print("Driver: {}/{}".format(dataset.GetDriver().ShortName,
                            dataset.GetDriver().LongName))
print("Size is {} x {} x {}".format(dataset.RasterXSize,
                                    dataset.RasterYSize,
                                    dataset.RasterCount))
print("Projection is {}".format(dataset.GetProjection()))
geotransform = dataset.GetGeoTransform()
if geotransform:
    print("Origin = ({}, {})".format(geotransform[0], geotransform[3]))
    print("Pixel Size = ({}, {})".format(geotransform[1], geotransform[5]))
Driver: GTiff/GeoTIFF
Size is 358 x 632 x 1
Projection is PROJCS["unnamed",GEOGCS["Geographic Coordinate System",DATUM["D_MARS",SPHEROID["MARS",3396190,0]],PRIMEM["Reference Meridian",0],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]]],PROJECTION["Equirectangular"],PARAMETER["standard_parallel_1",0],PARAMETER["central_meridian",0],PARAMETER["false_easting",0],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]
Origin = (-534141.0, 135873.0)
Pixel Size = (18.0, -18.0)
# data_array_interp = np.interp(data_array, (-3200.0, -1400), (0, 255)).astype(np.uint8)
# plt.figure()
# plt.imshow(data_array_interp)
# plt.colorbar()
# plt.show()
### Data and label sources

mounds = '../training_site/mapping/polygons/mounds.shp'
dem = '../training_site/DEM/High-res_DEM/Firsoff_HiRISE_5.tif-adj.tif'
dem_lowres = '../training_site/DEM/Low-res_DEM/Firsoff_HiRISE_5_low_res.tif-adj.tif'
vents = "../training_site/mapping/points/vent_points.shp" 


### remove invalid shape geometries

mound_data = gpd.read_file(mounds)

toremove = []
for id, gg in enumerate(mound_data.geometry): # seems like we have some invalid geometries, dropping them
    if not gg.is_valid:
        print(f"not valid {id}")
        toremove.append(id)
        
mound_data = mound_data.drop(toremove)
mound_data.drop_duplicates(inplace=True)
not valid 56
not valid 92
not valid 100
not valid 163
not valid 200
not valid 329
not valid 330
### Loading DEM source.
dsource= rio.open(dem)
dtm = dsource.read()
meta = dsource.meta
### Reprojecting the labels coordinate system to that of the original DEM.
mounds_corrected  = mound_data.to_crs(meta['crs'])
### Masking the labels. 0 for non mounds, 1 for mounds. Can be used for segmentation.
out_img, out_transform = mask.mask(dataset=dsource, shapes = mounds_corrected.geometry, nodata=np.nan) # do the crop
out_img = out_img[0]
nodata = meta["nodata"]
invalid_elevation = dtm[0] == nodata
isnan = np.isnan(out_img)
out_img[isnan] = 0
out_img[~isnan] = 1
out_img[invalid_elevation] = np.nan
# overall mask 0 non mounds, 1 mounds, np.nan unknown
figure()
imshow(out_img)
<matplotlib.image.AxesImage at 0x7f65b3111370>
../../../_images/Workshop-Mound-Segmentation_19_1.png

2.3. Split the image into tiles

2.3.1. GDAL implementation

def get_extent(dataset):

    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    transform = dataset.GetGeoTransform()
    minx = transform[0]
    maxx = transform[0] + cols * transform[1] + rows * transform[2]

    miny = transform[3] + cols * transform[4] + rows * transform[5]
    maxy = transform[3]

    return {
            "minX": str(minx), "maxX": str(maxx),
            "minY": str(miny), "maxY": str(maxy),
            "cols": str(cols), "rows": str(rows)
            }

def create_tiles(minx, miny, maxx, maxy, n):
    width = maxx - minx
    height = maxy - miny

    matrix = []

    for j in range(n, 0, -1):
        for i in range(0, n):

            ulx = minx + (width/n) * i # 10/5 * 1
            uly = miny + (height/n) * j # 10/5 * 1

            lrx = minx + (width/n) * (i + 1)
            lry = miny + (height/n) * (j - 1)
            matrix.append([[ulx, uly], [lrx, lry]])

    return matrix

def split(file_name, n):
    raw_file_name = os.path.splitext(os.path.basename(file_name))[0].replace("_downsample", "")
    driver = gdal.GetDriverByName('GTiff')
    dataset = gdal.Open(file_name)
    band = dataset.GetRasterBand(1)
    transform = dataset.GetGeoTransform()

    extent = get_extent(dataset)
    cols = int(extent["cols"])
    rows = int(extent["rows"])

    minx = float(extent["minX"])
    maxx = float(extent["maxX"])
    miny = float(extent["minY"])
    maxy = float(extent["maxY"])

    width = maxx - minx
    height = maxy - miny

    output_path = os.path.join("data", raw_file_name)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    #print("GCD", gcd(round(width, 0), round(height, 0)))
    print("Width", width)
    print("height", height)


    tiles = create_tiles(minx, miny, maxx, maxy, n)
    transform = dataset.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = -transform[5]

    #print(xOrigin, yOrigin)

    tile_num = 0
    ds = []
    for tile in tiles:

        minx = tile[0][0]
        maxx = tile[1][0]
        miny = tile[1][1]
        maxy = tile[0][1]

        p1 = (minx, maxy)
        p2 = (maxx, miny)

        i1 = int((p1[0] - xOrigin) / pixelWidth)
        j1 = int((yOrigin - p1[1])  / pixelHeight)
        i2 = int((p2[0] - xOrigin) / pixelWidth)
        j2 = int((yOrigin - p2[1]) / pixelHeight)

        #print(i1, j1)
        #print(i2, j2)

        new_cols = i2-i1
        new_rows = j2-j1

        data = band.ReadAsArray(i1, j1, new_cols, new_rows)

        #print data

        new_x = xOrigin + i1*pixelWidth
        new_y = yOrigin - j1*pixelHeight

        #print(new_x, new_y)

        new_transform = (new_x, transform[1], transform[2], new_y, transform[4], transform[5])

        output_file_base = raw_file_name + "_" + str(tile_num) + ".tif"
        output_file = os.path.join("data", raw_file_name, output_file_base)

        dst_ds = driver.Create(output_file,
                               new_cols,
                               new_rows,
                               1,
                               gdal.GDT_Float32)

        #writting output raster
        dst_ds.GetRasterBand(1).WriteArray(data)

        tif_metadata = {
            "minX": str(minx), "maxX": str(maxx),
            "minY": str(miny), "maxY": str(maxy)
        }
        dst_ds.SetMetadata(tif_metadata)

        #setting extension of output raster
        # top left x, w-e pixel resolution, rotation, top left y, rotation, n-s pixel resolution
        dst_ds.SetGeoTransform(new_transform)

        wkt = dataset.GetProjection()

        # setting spatial reference of output raster
        srs = osr.SpatialReference()
        srs.ImportFromWkt(wkt)
        dst_ds.SetProjection( srs.ExportToWkt() )
     
        ds.append(dst_ds)
    return ds
        #Close output raster dataset
#         dst_ds = None

#         tile_num += 1
    
    dataset = None
def tilize(dem_source, n):
    out = split(dem_source, n)
    for i in range(len(out)):
        ds_richdem = rd.rdarray(out[i].GetRasterBand(1).ReadAsArray(), no_data = meta['nodata'])
        driver = gdal.GetDriverByName('GTiff')
        driver.CreateCopy("tiles/{}.tif".format(i), out[i], 0)

2.3.2. Rasterio implementation

with rio.open(dem) as src:
    i = 0
    for ji, window in src.block_windows(1):
        r = src.read(window = window, masked = False)
        print(r.shape)
        rio.plot.show(r)
        kwargs = src.meta.copy()
        kwargs.update({
        "driver": "GTIff",
        'height': window.height,
        'width': window.width,
        'transform': rio.windows.transform(window, src.transform)})
        if not os.path.exists('tiles/'):
            os.makedirs('tiles/')
        
        if r.shape == (1,256,256):
            with open('tiles/tile_{}.npy'.format(i), 'wb') as f:
                np.save(f, src.read(window=window))
            with rio.open('tiles/tile_{}.tif'.format(i), 'w', **kwargs) as dst:
                 dst.write(src.read(window=window))
            i += 1

2.4. Annotation script

def numbers(s):
    return int(re.search(r"\d+", s).group(0))


for file in os.listdir('tiles'):
    if file.endswith('.tif'):
        fn = numbers(str(file))
        dsource= rio.open('tiles/{}'.format(file))
        dtm = dsource.read()
        meta = dsource.meta
        out_img, out_transform = mask.mask(dataset=dsource, shapes = mounds_corrected.geometry, crop = False) 
        out_img = out_img[0]
        nodata = meta["nodata"]
        invalid_elevation = dtm[0] == nodata
        isnan = np.isnan(out_img)
        out_img[isnan] = 0
        out_img[~isnan] = 1
        out_img[invalid_elevation] = np.nan
        # overall mask 0 non mounds, 1 mounds, np.nan unknown
        figure()
        imshow(out_img)
        
            
        if not os.path.exists('labels/'):
            os.makedirs('labels/')

        out = np.expand_dims(out_img, axis = 0)
        with open('labels/tile_{}.npy'.format(fn), 'wb') as f:
                np.save(f, out)
        kwargs = meta.copy()
        with rio.open('labels/tile_{}.tif'.format(fn), 'w', **kwargs) as dst:
             dst.write(out)
  

Data Cache contains 3 channel feature sets already created

# Root Directory
Image_Path = 'data_cache/'
# Inputs/Sources
dataframes_path = Image_Path + "DataFrames/"
dataset_dirs_path = dataframes_path + 'three_channel_dataset.csv'
import albumentations
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from torch import FloatTensor


class Dataset(BaseDataset):    
    CLASSES = ['non-mound', 'mound']
    
    def __init__(
          self, 
          dirs,
          classes=None,
          augmentation=None):
        self.df = pd.read_csv(dirs)
        self.ids = self.df.index.to_list()

        ### convert str names to class values on masks
        self.class_values = CLASSES = [0, 1]
        self.augmentation = augmentation
    
    def __getitem__(self, i):
        ### read data
        if self.df.at[i, 'origin'][-3:] == 'tif':
            ### load DEM
            with rasterio.open(self.df.at[i, 'origin']) as dem:
                dem_array = dem.read(1)
            ### load hillshade
            with rasterio.open(self.df.at[i, 'hillshade']) as hillshade:
                hillshade_array = hillshade.read(1)
            ### load slope
            with rasterio.open(self.df.at[i, 'slope']) as slope:
                slope_array = slope.read(1)
            ### load mask label
            with rasterio.open(self.df.at[i, 'mask']) as mask:
                mask_array = mask.read(1)
            mask_array = np.where(mask_array == 0.5, 0, mask_array)
            # apply augmentations
            if self.augmentation:
                ### only reshaping, so applying each function separately is ok.
                sample1 = self.augmentation(image=dem_array)
                dem_array= sample1['image']

                sample2 = self.augmentation(image=slope_array)
                slope_array = sample2['image']           

                sample3 = self.augmentation(image=hillshade_array)
                hillshade_array = sample3['image']

            input_array = [dem_array, slope_array, hillshade_array]
            input_array = np.stack(input_array, axis=0).astype('float')
            input_array = FloatTensor(input_array)
            mask_array = FloatTensor(np.expand_dims(mask_array,axis=(0)))
        
        elif self.df.at[i, 'origin'][-3:] == 'jpg':
            ###load the image and corresponding mask label
            image = Image.open(self.df.at[i, 'origin'])
            mask = Image.open(self.df.at[i, 'mask']).convert('L')
            
            ###convert image to numpy array
            input_array = np.asarray(image)
            input_array.transpose(2, 0, 1).astype('float32')
            mask_array = np.asarray(mask)
            mask_array = scaler.fit_transform(mask_array)
            mask_array = mask_array[ : ,0]            
            input_array = np.resize(input_array, (3, 224, 192))
            mask_array = np.resize(input_array, (224, 192))
            mask_array [mask_array >= 0.5] = 1 
            mask_array [mask_array < 0.5] = 0
            mask_array = FloatTensor(np.expand_dims(mask_array,axis=(0)))    
            input_array = FloatTensor(input_array)

        return input_array, mask_array
      
    def __len__(self):
        return len(self.ids)
### Sample reshaping (X,Y) --> (224,192)
def get_training_augmentation():
    train_transform = [albumentations.PadIfNeeded(min_height=224, min_width=192, always_apply=True, border_mode=0)]

    return albumentations.Compose(train_transform)

def get_validation_augmentation():
    """Add paddings to make image shape divisible by 32"""
    test_transform = [
        albumentations.PadIfNeeded(224, 192)
    ]
    return albumentations.Compose(test_transform)
!ls data_cache/Annotated\ Tiles\ Sets/
'16x32 Tiles'  '32x64 Tiles'  '4x8 Tiles'
'2x4 Tiles'    '3x6 Tiles'    '8x16 Tiles'
ENCODER = 'se_resnext50_32x4d'
# ENCODER = 'resnet18'
### use imagenet weights
ENCODER_WEIGHTS = 'imagenet' 
# ENCODER_WEIGHTS = 'ssl'
### Sigmoid activation to obtain logits in final layer
ACTIVATION = 'sigmoid' 
CLASSES = [1, 0]
### altenative UNET instead of FPN
model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=None, #ENCODER_WEIGHTS, 
    classes=1, 
    in_channels=3,
    activation=ACTIVATION)
#train_dirs_path = 'data_cache/DataFrames/three_channel_train.csv'
### Uncomment to exclude GAN images for training
CLASSES = [1, 0]

train_dirs_path = 'data_cache/DataFrames/three_channel_train(NO GAN).csv'
validate_dirs_path = 'data_cache/DataFrames/three_channel_valid.csv'
test_dirs_path = 'data_cache/DataFrames/three_channel_test.csv'

train_dataset = Dataset(train_dirs_path,
                        augmentation=get_training_augmentation(),                         
                        classes=CLASSES)


valid_dataset = Dataset(validate_dirs_path,
                        augmentation=get_validation_augmentation(),                         
                        classes=CLASSES)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers = 8)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, num_workers = 8)
train_dataset[2][0].shape
torch.Size([3, 224, 192])
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.Fscore(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=1e-3),
])

optimizer.param_groups[0]['lr']
0.001
## utility fn to for updating learnt weights as per the scheduler.
def update_learning_rate(scheduler, optimizer):
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    print('learning rate = %.7f' % lr)

2.5. Define device for computation

device  = torch.device("cuda:0")
# create epoch runners 
# it is a simple loop of iterating over dataloader`s samples
train_epoch = smp.utils.train.TrainEpoch(
    model.to(device), 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model.to(device), 
    loss=loss, 
    metrics=metrics, 
    verbose=True,
)
# train model for 100 epochs
train_loss=[]
valid_loss=[]
train_fscore=[]
valid_fscore=[]
min_loss = float("inf")
max_score = 0
state_not_changed = 0
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
#scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1,)
for i in range(0, 100):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    train_loss.append(train_logs['dice_loss'])
    valid_loss.append(valid_logs['dice_loss'])
    train_fscore.append(train_logs['fscore'])
    valid_fscore.append(valid_logs['fscore'])
    state_not_changed+=1
    scheduler1.step()
#     scheduler2.step()
    # do something (save model, change lr, etc.)
    if min_loss > valid_logs['dice_loss']:
        min_loss = valid_logs['dice_loss']
        torch.save(model, './best_loss_model.pth')
        print('best_loss_model saved!')
        state_not_changed = 0
        
    if max_score < valid_logs['fscore']:
        max_score = valid_logs['fscore']
        torch.save(model, './best_score_model.pth')
        print('best_score_model saved!')
        state_not_changed = 0
    
    
    if i % 10 == 0 and i != 0:
        print('Current learning rate = '+str(optimizer.param_groups[0]['lr'])+'!')
#         optimizer.param_groups[0]['lr'] /= 10
#         print('Decrease decoder learning rate to '+str(optimizer.param_groups[0]['lr'])+'!')
#         state_not_changed = 0
        
    if state_not_changed == 15:
        print('The model did not improve during the last 15 epochs')
        break
plt.plot(train_loss, label='Training loss')
plt.plot(val_loss, label = 'Validation loss')
plt.legend(frameon=False)
plt.show()
with torch.no_grad():
            confusion_mat = []
            for sample in valid_loader:
            # get the inputs
                image = sample['image'].float()
                mask = sample['target_mask'].float()
                image = image.to(device)
                
                output = model(image)
                #probability_class = torch.exp(output) 
                prediction = torch.argmax(output,1).to("cpu")
                y_true = torch.argmax(mask,1).to("cpu")
              
                cm = confusion_matrix(prediction.view(-1), y_true.view(-1))
                confusion_mat.append( cm)
                #print(' cm:{}'.format(cm))
                
                # performance metrics for the model
                score_accuracy = accuracy_score(prediction.view(-1), y_true.view(-1))
                report = classification_report(prediction.view(-1), y_true.view(-1)) 
                
                fpr, tpr, _ = roc_curve(prediction.view(-1), y_true.view(-1))
                roc_auc = auc(fpr,tpr)
                print('score_accuracy:{}'.format(score_accuracy))
                print('Receiver operating characteristic (ROC):{}'.format(roc_auc))
            
confusion_matrix = np.sum(confusion_mat,axis = 0)
df_cm = pd.DataFrame(confusion_matrix, index = [i for i in ['non-mound', 'mound']],
                  columns = [i for i in ['non-mound','mound']])

import seaborn as sn
sn.heatmap(df_cm, annot = True,)
<AxesSubplot:>
../../../_images/Workshop-Mound-Segmentation_43_1.png
def predict_image(image):
    image_tensor = np.expand_dims(image,(2,3))
    #print(image_tensor.shape)
    image_tensor = np.transpose(image_tensor,(3,2,0,1))
    input = torch.from_numpy(image_tensor).to(device, dtype = torch.float)
    output = model(input)
    output = torch.argmax(output,1)
    index = output.float().to("cpu")
    return np.array(index)

def numbers(s):
    return int(re.search(r"\d+", s).group(0))

def reconstruct_map(tile_folder):
    fn_list = []
    for file in os.listdir(tile_folder):
        if file.endswith('.tif'):
            fn = numbers(str(file))
            fn_list.append(file)
    fn_list = np.sort(fn_list)
    sources = [rio.open(os.path.join(tile_folder,fn)) for fn in fn_list]
    merged_image, out_transform = merge(sources)
    return merged_image, out_transform
# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(10, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()
# test dataset without transformations for image visualization
test_dataset = Dataset(test_dirs_path, 
                           classes=CLASSES)
best_model = torch.load('FPN_3c_Without.pth')
idx = [25, 30, 36, 38, 42, 53, 54]
#for i in range(len(test_dataset)):
for i in idx:
#     n = np.random.choice(len(test_dataset))
    n=i
    image_vis = test_dataset[n][0]
    image, gt_mask = test_dataset[n]
    
    x_tensor = image.unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())

    visualize(
        image=image_vis[2], 
        ground_truth_mask=gt_mask[0], 
        predicted_mask=pr_mask
    )
../../../_images/Workshop-Mound-Segmentation_48_0.png ../../../_images/Workshop-Mound-Segmentation_48_1.png ../../../_images/Workshop-Mound-Segmentation_48_2.png ../../../_images/Workshop-Mound-Segmentation_48_3.png ../../../_images/Workshop-Mound-Segmentation_48_4.png ../../../_images/Workshop-Mound-Segmentation_48_5.png ../../../_images/Workshop-Mound-Segmentation_48_6.png
def recall(tp, fn):
    if tp+fn == 0:
        return 'invalid'
    return float(tp)/float(tp+fn)

def precision(tp, fp):
    if tp+fp == 0:
        return 'invalid'
    return float(tp)/float(tp+fp)

def f1_score(tp, fp, fn):
    if tp+fp+fn == 0:
        return 'invalid'
    return float(tp)/(tp + (float(fp + fn)/2))

def for_metric(fn, tn):
    if fn+tn == 0:
        return 'invalid'
    return float(fn)/(fn+tn)
threshold = 0.5 ## Feel free to play around with the threshold value to find what works best
results = pd.DataFrame(columns=['id', 'tp', 'fp', 'tn', 'fn', 'precision', 'recall'])
# for i, sample in enumerate(test_dataset):
# idx = [75, 256, 269, 321, 406, 415, 504]
for i in range(len(test_dataset)):
    correct = 0
    sample = test_dataset[i]
    im = sample[0].unsqueeze(0).float()#.to(device)
    mask = sample[1].float()#.to(device)
    out = best_model(im).detach()
    #compute output class
    pred = out
    pred[out >= threshold] = 1
    pred[out < threshold] = 0
    correct += (pred == mask).sum().item()
    total = len(mask.view(-1))
    accuracy = float(correct/total)
#     tp, fp, fn, tn = smp.metrics.get_stats(out, mask, mode='binary', threshold=0.5)
    # Sklearn doc
    tn, fp, fn, tp = confusion_matrix(mask.view(-1), pred.view(-1), labels=[0., 1.]).ravel()
    
    tmp = pd.DataFrame(columns=['id', 'tp', 'fp', 'tn', 'fn', 'precision', 'recall'])
    tmp['id'] = [i]
    tmp['tp'] = [tp]
    tmp['fp'] = [fp]
    tmp['tn'] = [tn]
    tmp['fn'] = [fn]
    tmp['recall'] = [recall(tp=tp, fn=fn)]
    tmp['precision'] = [precision(tp=tp, fp=fp)]
    results = results.append(tmp)
ttp = np.sum(results['tp'].values)
tfp = np.sum(results['fp'].values)
ttn = np.sum(results['tn'].values)
tfn = np.sum(results['fn'].values)
print(f'tp = {ttp}, fp = {tfp}, tn={ttn}, fn ={tfn}, precision = {precision(tp=ttp, fp=tfp)}, recall = {recall(tp=ttp, fn=tfn)} ')
print(f'F-Score = {f1_score(tp= ttp, fp= tfp, fn= tfn)}')
tp = 1291642, fp = 384376, tn=16219260, fn =297106, precision = 0.7706611742833311, recall = 0.8129936276867068 
F-Score = 0.7912616095609915