#! /usr/bin/env python3
"""Build georeferenced mosaics from satellite tiles and cut random patches."""
# NOTE: the import order below is intentionally left as it was; libraries that
# link against GDAL (rasterio, osgeo) can be sensitive to load order. Imports
# not used in this part of the file are kept because other code may need them.
import os
import mercantile
import rasterio
import numpy as np
import random
import warnings
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rasterio.transform import from_bounds
from rasterio.merge import merge
from PIL import Image
import gc
from osgeo import gdal, osr
from affine import Affine


def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
    """Return the neighbours of every neighbour of *tile*, without duplicates.

    For a tile away from the edge of the tile grid this is the 5x5 block of
    tiles centred on *tile* (the tile itself is included, since it is a
    neighbour of each of its own neighbours). Tiles are returned in
    first-seen order.
    """
    # A dict keeps insertion order and gives O(1) membership checks, unlike
    # the repeated linear scans of a list.
    seen: dict[mercantile.Tile, None] = {}
    for main_neighbour in mercantile.neighbors(tile):
        for sub_neighbour in mercantile.neighbors(main_neighbour):
            seen.setdefault(sub_neighbour, None)
    return list(seen)


def get_tiff_map(
    tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str
) -> tuple[np.ndarray, dict]:
    """
    Returns a TIFF map of the given tile using GDAL.

    The JPEG tiles around *tile* (see ``get_5x5_neighbors``) are loaded from
    ``{satellite_dataset_dir}/{sat_year}/{z}_{x}_{y}.jpg``, wrapped in
    in-memory GDAL datasets and mosaicked through a VRT.

    Returns:
        ``(mosaic, out_meta)`` where ``mosaic`` is the array read from the
        VRT, indexed ``(band, row, col)``, and ``out_meta`` is a dict with the
        keys ``driver``, ``height``, ``width``, ``transform`` (an ``Affine``)
        and ``crs`` (a WKT string).

    Raises:
        FileNotFoundError: if any required tile image is missing.
        RuntimeError: if GDAL cannot build the mosaic.
    """
    tile_data = []
    neighbors = get_5x5_neighbors(tile)

    # The projection is identical for every tile, so build the WKT once.
    # NOTE(review): mercantile.bounds() yields lon/lat degrees, yet the CRS is
    # declared as EPSG:3857 (metres). geo_to_pixel_coordinates() relies on the
    # transform being in degrees, so the label is left untouched here --
    # confirm which of the two is intended before writing this to disk.
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(3857)
    projection_wkt = srs.ExportToWkt()
    mem_driver = gdal.GetDriverByName('MEM')

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        for neighbor in neighbors:
            west, south, east, north = mercantile.bounds(neighbor)
            tile_path = (
                f"{satellite_dataset_dir}/{sat_year}/"
                f"{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
            )
            if not os.path.exists(tile_path):
                raise FileNotFoundError(
                    f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
                )

            # Close the file handle as soon as the pixels are in memory, and
            # force three channels so non-RGB JPEGs do not break the band
            # loop below.
            with Image.open(tile_path) as img:
                img_array = np.array(img.convert("RGB"))
            height, width = img_array.shape[:2]

            # Create an in-memory GDAL dataset
            dataset = mem_driver.Create('', width, height, 3, gdal.GDT_Byte)
            for i in range(3):
                dataset.GetRasterBand(i + 1).WriteArray(img_array[:, :, i])

            # Set GeoTransform and Projection (north-up: negative row step)
            geotransform = (
                west,
                (east - west) / width,
                0,
                north,
                0,
                -(north - south) / height,
            )
            dataset.SetGeoTransform(geotransform)
            dataset.SetProjection(projection_wkt)

            tile_data.append(dataset)

    # Merge tiles using GDAL
    vrt_options = gdal.BuildVRTOptions()
    vrt = gdal.BuildVRT('', tile_data, options=vrt_options)
    if vrt is None:
        # BuildVRT returns None on failure when GDAL exceptions are disabled
        # (for instance when there are no tiles to merge).
        raise RuntimeError(f"Could not build a mosaic around tile {tile}.")
    mosaic = vrt.ReadAsArray()

    # Get metadata
    out_trans = Affine.from_gdal(*vrt.GetGeoTransform())
    out_crs = vrt.GetProjection()
    out_meta = {
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": out_trans,
        "crs": out_crs,
    }

    # Clean up
    for td in tile_data:
        td.FlushCache()
    vrt = None
    gc.collect()

    return mosaic, out_meta


def get_random_tiff_patch(
    lat: float,
    lon: float,
    patch_width: int,
    patch_height: int,
    sat_year: str,
    satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
) -> tuple[np.ndarray, int, int, int, int, rasterio.transform.Affine]:
    """
    Returns a random patch from the satellite image.

    The mosaic around the zoom-17 tile containing ``(lat, lon)`` is built and
    a ``patch_width`` x ``patch_height`` window is cut from it. The window is
    drawn at random so that the point lies at least ``KS`` pixels inside it,
    then shifted if necessary to stay within the mosaic.

    Returns:
        ``(patch, x, y, x_offset, y_offset, patch_transform)`` where ``x, y``
        is the pixel position of the point inside the patch, ``x_offset,
        y_offset`` is the top-left corner of the patch inside the mosaic, and
        ``patch_transform`` is the mosaic transform shifted to that corner.

    Raises:
        ValueError: if the patch is too small to honour the margin or larger
            than the mosaic.
        FileNotFoundError: if a required tile image is missing.
    """
    # TODO
    # Temporary constant, replace with a better solution
    KS = 120

    # random.randint() needs a non-empty range; fail early (before any tile
    # I/O) with an explicit message instead of its "empty range" error.
    min_side = 2 * KS + 2
    if patch_width < min_side or patch_height < min_side:
        raise ValueError(
            f"patch_width and patch_height must be at least {min_side} "
            f"pixels, got {patch_width}x{patch_height}."
        )

    tile = get_tile_from_coord(lat, lon, 17)
    mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
    transform = out_meta["transform"]
    del out_meta

    max_x_offset = mosaic.shape[-1] - patch_width
    max_y_offset = mosaic.shape[-2] - patch_height
    if max_x_offset < 0 or max_y_offset < 0:
        # Clamping against a negative bound would silently return a patch of
        # the wrong shape.
        raise ValueError(
            f"Patch of {patch_width}x{patch_height} pixels does not fit in "
            f"the {mosaic.shape[-1]}x{mosaic.shape[-2]} mosaic."
        )

    x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)

    # Randomly select an offset within the valid range
    x_offset = random.randint(x_pixel - patch_width + KS + 1, x_pixel - KS - 1)
    y_offset = random.randint(y_pixel - patch_height + KS + 1, y_pixel - KS - 1)

    # Clamp to the mosaic; plain ints are kept so the return values match the
    # annotated types.
    x_offset = min(max(x_offset, 0), max_x_offset)
    y_offset = min(max(y_offset, 0), max_y_offset)

    # Update x, y to reflect the clamping of x_offset and y_offset
    x, y = x_pixel - x_offset, y_pixel - y_offset

    patch = mosaic[
        :, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
    ]

    # Same transform as the mosaic with the origin moved to the patch corner.
    patch_transform = rasterio.transform.Affine(
        transform.a,
        transform.b,
        transform.c + x_offset * transform.a + y_offset * transform.b,
        transform.d,
        transform.e,
        transform.f + x_offset * transform.d + y_offset * transform.e,
    )
    gc.collect()

    return patch, x, y, x_offset, y_offset, patch_transform


def get_tile_from_coord(
    lat: float, lng: float, zoom_level: int
) -> mercantile.Tile:
    """
    Returns the tile containing the given coordinates.

    Note the argument order: latitude first, then longitude (mercantile
    itself takes longitude first).
    """
    return mercantile.tile(lng, lat, zoom_level)


def geo_to_pixel_coordinates(
    lat: float, lon: float, transform: rasterio.transform.Affine
) -> tuple[int, int]:
    """
    Converts a pair of (lat, lon) coordinates to pixel coordinates.

    The inverse of *transform* is applied to ``(lon, lat)`` and both results
    are rounded to the nearest integer, giving ``(x_pixel, y_pixel)``.
    """
    # NOTE(review): round() snaps to the nearest pixel corner; if the index of
    # the pixel containing the point is wanted, floor would be the usual
    # choice -- confirm with callers before changing.
    x_pixel, y_pixel = ~transform * (lon, lat)
    return round(x_pixel), round(y_pixel)