164 lines
4.9 KiB
Python
164 lines
4.9 KiB
Python
#! /usr/bin/env python3
|
|
|
|
import os
|
|
import mercantile
|
|
import rasterio
|
|
import numpy as np
|
|
import random
|
|
import warnings
|
|
from rasterio.errors import NotGeoreferencedWarning
|
|
from rasterio.io import MemoryFile
|
|
from rasterio.transform import from_bounds
|
|
from rasterio.merge import merge
|
|
from PIL import Image
|
|
import gc
|
|
import random
|
|
from osgeo import gdal, osr
|
|
from affine import Affine
|
|
|
|
|
|
def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
    """
    Collect every tile adjacent to one of ``tile``'s adjacent tiles.

    ``mercantile.neighbors`` yields the (up to) eight tiles touching a tile, so
    taking the neighbours of each neighbour spans the 5x5 block centred on
    ``tile``. The centre tile is part of the result as well, because it
    neighbours each of its own neighbours. Duplicates are removed while the
    first-seen order is kept.
    """
    ring = mercantile.neighbors(tile)
    second_ring = (outer for inner in ring for outer in mercantile.neighbors(inner))
    # dict keys preserve insertion order, which gives an order-preserving de-dup.
    return list(dict.fromkeys(second_ring))
|
|
|
|
def _tile_to_mem_dataset(neighbor, sat_year, satellite_dataset_dir, mem_driver, srs_wkt):
    """Load one JPEG tile from disk and wrap it in a georeferenced in-memory GDAL dataset."""
    west, south, east, north = mercantile.bounds(neighbor)
    tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
    if not os.path.exists(tile_path):
        raise FileNotFoundError(f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found.")

    # Close the file handle as soon as the pixels are decoded. convert("RGB")
    # guarantees exactly three bands, even for grayscale or CMYK JPEGs.
    with Image.open(tile_path) as img:
        img_array = np.array(img.convert("RGB"))

    height, width = img_array.shape[:2]
    dataset = mem_driver.Create("", width, height, 3, gdal.GDT_Byte)
    for band_index in range(3):
        dataset.GetRasterBand(band_index + 1).WriteArray(img_array[:, :, band_index])

    # North-up geotransform spanning the tile's lon/lat bounds. Pixel spacing is
    # taken as linear in latitude, which approximates the Web Mercator tile grid.
    dataset.SetGeoTransform(
        (west, (east - west) / width, 0, north, 0, -(north - south) / height)
    )
    dataset.SetProjection(srs_wkt)
    return dataset


def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str) -> tuple[np.ndarray, dict]:
    """
    Build a georeferenced RGB mosaic of the 5x5 tile block centred on ``tile``.

    Each tile is read from ``{satellite_dataset_dir}/{sat_year}/{z}_{x}_{y}.jpg``,
    placed in an in-memory GDAL dataset, and the datasets are merged through a VRT.

    Args:
        tile: Centre tile of the mosaic.
        sat_year: Sub-directory (imagery year) to read tiles from.
        satellite_dataset_dir: Root directory of the tile dataset.

    Returns:
        ``(mosaic, out_meta)``. ``mosaic`` is the array returned by
        ``ReadAsArray`` on the VRT, laid out as (bands, height, width).
        ``out_meta`` holds ``driver``, ``height``, ``width``, ``transform``
        (an ``Affine``) and ``crs`` (WKT string).

    Raises:
        FileNotFoundError: If any tile of the 5x5 block is missing on disk.
    """
    mem_driver = gdal.GetDriverByName("MEM")

    # mercantile.bounds returns longitude/latitude in degrees, so the
    # geotransforms below are in degrees and the matching CRS is EPSG:4326.
    # (Labelling them EPSG:3857 would declare metres and misplace the raster.)
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    srs_wkt = srs.ExportToWkt()

    tile_data = []
    # Python warnings emitted while decoding tiles are silenced.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        for neighbor in get_5x5_neighbors(tile):
            tile_data.append(
                _tile_to_mem_dataset(neighbor, sat_year, satellite_dataset_dir, mem_driver, srs_wkt)
            )

    # Mosaic the tiles through a virtual raster and read it into one array.
    vrt = gdal.BuildVRT("", tile_data)
    mosaic = vrt.ReadAsArray()

    out_meta = {
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": Affine.from_gdal(*vrt.GetGeoTransform()),
        "crs": vrt.GetProjection(),
    }

    # Release the VRT before the source datasets it references, then drop the
    # sources themselves so their pixel buffers can be freed.
    vrt = None
    tile_data.clear()
    gc.collect()

    return mosaic, out_meta
|
|
|
|
def get_random_tiff_patch(
    lat: float,
    lon: float,
    patch_width: int,
    patch_height: int,
    sat_year: str,
    satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
    *,
    margin: int = 120,
    zoom_level: int = 17,
) -> tuple[np.ndarray, int, int, int, int, rasterio.transform.Affine]:
    """
    Cut a randomly positioned patch containing (lat, lon) out of the tile mosaic.

    The patch origin is drawn so that the target pixel lies strictly more than
    ``margin`` pixels from every patch edge. The origin is then clamped so the
    patch stays inside the mosaic; near the mosaic border this clamping can move
    the target closer than ``margin`` to an edge.

    Args:
        lat: Latitude of the target point, in degrees.
        lon: Longitude of the target point, in degrees.
        patch_width: Patch width in pixels.
        patch_height: Patch height in pixels.
        sat_year: Imagery year sub-directory passed to ``get_tiff_map``.
        satellite_dataset_dir: Root directory of the tile dataset.
        margin: Minimum distance (exclusive) in pixels between the target and
            the patch edges before clamping.
        zoom_level: Zoom level of the tile containing the target.

    Returns:
        ``(patch, x, y, x_offset, y_offset, patch_transform)``:

        - ``patch`` is a (bands, patch_height, patch_width) array. It is a copy,
          so the full mosaic can be freed.
        - ``x``, ``y`` give the target's pixel position inside the patch.
        - ``x_offset``, ``y_offset`` give the patch origin in mosaic pixels.
        - ``patch_transform`` maps patch pixels to lon/lat.

    Raises:
        ValueError: If the patch is too small for ``margin`` or larger than the
            mosaic.
        FileNotFoundError: Propagated from ``get_tiff_map`` for missing tiles.
    """
    # The offset range below is non-empty only when size >= 2 * margin + 2.
    min_size = 2 * margin + 2
    if patch_width < min_size or patch_height < min_size:
        raise ValueError(
            f"Patch {patch_width}x{patch_height} is too small for margin {margin}; "
            f"both sides must be at least {min_size} pixels."
        )

    tile = get_tile_from_coord(lat, lon, zoom_level)
    mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
    transform = out_meta["transform"]

    mosaic_height, mosaic_width = mosaic.shape[-2:]
    if patch_width > mosaic_width or patch_height > mosaic_height:
        raise ValueError(
            f"Patch {patch_width}x{patch_height} exceeds mosaic size "
            f"{mosaic_width}x{mosaic_height}."
        )

    x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)

    # Draw the origin so the target lands in (margin, size - margin) inside the patch.
    x_offset = random.randint(x_pixel - patch_width + margin + 1, x_pixel - margin - 1)
    y_offset = random.randint(y_pixel - patch_height + margin + 1, y_pixel - margin - 1)

    # Keep the window inside the mosaic; plain ints instead of numpy scalars.
    x_offset = int(np.clip(x_offset, 0, mosaic_width - patch_width))
    y_offset = int(np.clip(y_offset, 0, mosaic_height - patch_height))

    # Target position relative to the (possibly clamped) patch origin.
    x, y = x_pixel - x_offset, y_pixel - y_offset

    # Copy so the returned patch does not keep the whole mosaic alive as its base.
    patch = mosaic[
        :, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
    ].copy()

    # Shift the mosaic transform to the patch origin (col = x_offset, row = y_offset).
    patch_transform = transform * Affine.translation(x_offset, y_offset)

    return patch, x, y, x_offset, y_offset, patch_transform
|
|
|
|
def get_tile_from_coord(lat: float, lng: float, zoom_level: int) -> mercantile.Tile:
    """
    Look up the slippy-map tile that covers a geographic point.

    Args:
        lat: Latitude of the point.
        lng: Longitude of the point.
        zoom_level: Zoom level of the tile to return.

    Returns:
        The ``mercantile.Tile`` containing (lng, lat) at ``zoom_level``.
    """
    # mercantile takes longitude first, then latitude.
    return mercantile.tile(lng, lat, zoom_level)
|
|
|
|
def geo_to_pixel_coordinates(
    lat: float,
    lon: float,
    transform: rasterio.transform.Affine,
    op=round,
) -> tuple[int, int]:
    """
    Convert a (lat, lon) pair to integer pixel coordinates under ``transform``.

    The point (x=lon, y=lat) is mapped through the inverse of ``transform``,
    which maps pixel (column, row) to world (x, y). Each fractional coordinate
    is then converted to an int with ``op``.

    Args:
        lat: Latitude, in the units of the transform's y axis.
        lon: Longitude, in the units of the transform's x axis.
        transform: Affine mapping pixel (column, row) to world (x, y).
        op: Callable turning a float coordinate into an int. Defaults to the
            builtin ``round``, which rounds to the nearest integer with ties
            going to even. Pass ``math.floor`` to get the index of the pixel
            that contains the point.

    Returns:
        ``(x_pixel, y_pixel)``, i.e. (column, row).
    """
    col, row = ~transform * (lon, lat)
    return op(col), op(row)
|