glue-factory-custom/gluefactory/datasets/s_utils.py

164 lines
4.9 KiB
Python
Raw Normal View History

2024-02-14 00:02:09 +01:00
#! /usr/bin/env python3
import os
import mercantile
import rasterio
import numpy as np
import random
import warnings
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rasterio.transform import from_bounds
from rasterio.merge import merge
from PIL import Image
import gc
import random
from osgeo import gdal, osr
from affine import Affine
def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
    """
    Collect the neighbours of every neighbour of ``tile``.

    Each tile appears once, in the order it is first encountered. For an
    interior tile this is expected to be the 5x5 block centred on ``tile``
    (``tile`` itself included, because it is a neighbour of each of its own
    neighbours).
    """
    # A dict keeps insertion order, so it doubles as an ordered set.
    collected = {}
    for ring_tile in mercantile.neighbors(tile):
        for candidate in mercantile.neighbors(ring_tile):
            collected.setdefault(candidate, None)
    return list(collected)
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str) -> tuple[np.ndarray, dict]:
    """
    Build a mosaic of the tiles surrounding ``tile`` using GDAL.

    Every tile returned by ``get_5x5_neighbors(tile)`` is read from
    ``{satellite_dataset_dir}/{sat_year}/{z}_{x}_{y}.jpg``, wrapped in a
    3-band in-memory GDAL dataset georeferenced from ``mercantile.bounds``,
    and all datasets are merged through an in-memory VRT.

    Args:
        tile: Centre tile of the mosaic.
        sat_year: Name of the sub-directory holding the tiles to use.
        satellite_dataset_dir: Root directory of the tile images.

    Returns:
        ``(mosaic, out_meta)`` where ``mosaic`` is the array read from the
        VRT (band axis first) and ``out_meta`` is a dict with the keys
        ``driver``, ``height``, ``width``, ``transform`` (an ``Affine``) and
        ``crs`` (the projection WKT of the VRT).

    Raises:
        FileNotFoundError: If the image of any required tile is missing.
    """
    tile_data = []

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")

        # Loop-invariant GDAL objects: build them once instead of per tile.
        mem_driver = gdal.GetDriverByName("MEM")
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(3857)
        projection_wkt = srs.ExportToWkt()

        for neighbor in get_5x5_neighbors(tile):
            tile_name = f"{neighbor.z}_{neighbor.x}_{neighbor.y}"
            tile_path = f"{satellite_dataset_dir}/{sat_year}/{tile_name}.jpg"
            if not os.path.exists(tile_path):
                raise FileNotFoundError(f"Tile {tile_name} not found.")

            # Close the file handle promptly, and force three channels so
            # grayscale / CMYK JPEGs do not break the per-band writes below.
            with Image.open(tile_path) as img:
                img_array = np.array(img.convert("RGB"))
            height, width = img_array.shape[:2]

            # Create an in-memory GDAL dataset
            dataset = mem_driver.Create("", width, height, 3, gdal.GDT_Byte)
            for band_index in range(3):
                dataset.GetRasterBand(band_index + 1).WriteArray(img_array[:, :, band_index])

            # Set GeoTransform and Projection
            # NOTE(review): mercantile.bounds yields lon/lat degrees, so this
            # geotransform is in degrees while the dataset is tagged
            # EPSG:3857. geo_to_pixel_coordinates feeds (lon, lat) to the
            # inverse transform, so the degree units are relied upon -
            # confirm whether the CRS tag should be EPSG:4326 instead.
            west, south, east, north = mercantile.bounds(neighbor)
            dataset.SetGeoTransform(
                (west, (east - west) / width, 0, north, 0, -(north - south) / height)
            )
            dataset.SetProjection(projection_wkt)
            tile_data.append(dataset)

    # Merge tiles using GDAL
    vrt = gdal.BuildVRT("", tile_data, options=gdal.BuildVRTOptions())
    mosaic = vrt.ReadAsArray()

    # Get metadata
    out_meta = {
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": Affine.from_gdal(*vrt.GetGeoTransform()),
        "crs": vrt.GetProjection(),
    }

    # Clean up
    for dataset in tile_data:
        dataset.FlushCache()
    vrt = None
    gc.collect()

    return mosaic, out_meta
def _pick_patch_offset(pixel: int, patch_size: int, mosaic_size: int, margin: int) -> int:
    """Draw a random patch origin along one axis and clamp it to the mosaic."""
    if patch_size > mosaic_size:
        raise ValueError(
            f"Patch size {patch_size} exceeds mosaic size {mosaic_size}."
        )
    # Keep `pixel` at least `margin + 1` pixels away from both patch borders.
    lowest = pixel - patch_size + margin + 1
    highest = pixel - margin - 1
    if lowest > highest:
        raise ValueError(
            f"Patch size {patch_size} is too small for margin {margin}: "
            f"at least {2 * margin + 2} pixels are required."
        )
    offset = random.randint(lowest, highest)
    # Clamp so the patch lies fully inside the mosaic.
    return min(max(offset, 0), mosaic_size - patch_size)


def get_random_tiff_patch(
    lat: float,
    lon: float,
    patch_width: int,
    patch_height: int,
    sat_year: str,
    satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
    margin: int = 120,
) -> tuple[np.ndarray, int, int, int, int, rasterio.transform.Affine]:
    """
    Returns a random patch from the satellite image.

    The mosaic around the zoom-17 tile containing ``(lat, lon)`` is built
    with ``get_tiff_map`` and a ``patch_width`` x ``patch_height`` window is
    cut out at a random position that contains the requested coordinate.

    Args:
        lat: Latitude of the point the patch must contain.
        lon: Longitude of the point the patch must contain.
        patch_width: Patch width in pixels.
        patch_height: Patch height in pixels.
        sat_year: Name of the sub-directory holding the tiles to use.
        satellite_dataset_dir: Root directory of the tile images.
        margin: Minimum distance in pixels (exclusive) between the point and
            the patch borders when the offset is drawn. Clamping the patch
            to the mosaic afterwards may reduce that distance. Defaults to
            the previously hard-coded value of 120.

    Returns:
        ``(patch, x, y, x_offset, y_offset, patch_transform)``: the patch
        (band axis first), the pixel position of the point inside the patch,
        the patch origin inside the mosaic, and the affine transform of the
        patch.

    Raises:
        ValueError: If the patch is larger than the mosaic, or too small to
            keep the requested margin on both sides of the point.
        FileNotFoundError: Propagated from ``get_tiff_map`` when a tile
            image is missing.
    """
    tile = get_tile_from_coord(lat, lon, 17)
    mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
    transform = out_meta["transform"]
    del out_meta

    x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)

    # Randomly select an offset within the valid range, clamped to the mosaic
    x_offset = _pick_patch_offset(x_pixel, patch_width, mosaic.shape[-1], margin)
    y_offset = _pick_patch_offset(y_pixel, patch_height, mosaic.shape[-2], margin)

    # Position of the point inside the patch, after clamping of the offsets
    x, y = x_pixel - x_offset, y_pixel - y_offset

    patch = mosaic[
        :, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
    ]

    # Shift the mosaic origin by (x_offset, y_offset) pixels; the scale and
    # rotation terms are unchanged.
    patch_transform = rasterio.transform.Affine(
        transform.a,
        transform.b,
        transform.c + x_offset * transform.a + y_offset * transform.b,
        transform.d,
        transform.e,
        transform.f + x_offset * transform.d + y_offset * transform.e,
    )

    gc.collect()
    return patch, x, y, x_offset, y_offset, patch_transform
def get_tile_from_coord(
    lat: float, lng: float, zoom_level: int
) -> mercantile.Tile:
    """
    Look up the tile that contains the given coordinates at ``zoom_level``.

    Note that ``mercantile.tile`` takes longitude first, so the arguments
    are swapped relative to this function's ``(lat, lng)`` order.
    """
    return mercantile.tile(lng, lat, zoom_level)
def geo_to_pixel_coordinates(
    lat: float, lon: float, transform: rasterio.transform.Affine
) -> (int, int):
    """
    Map a (lat, lon) pair to the nearest pixel coordinates.

    The inverse of ``transform`` is applied to ``(lon, lat)`` and both
    resulting components are rounded to integers.
    """
    pixel_from_geo = ~transform
    column, row = pixel_from_geo * (lon, lat)
    return round(column), round(row)