glue-factory-custom/gluefactory/datasets/sat_utils.py

110 lines
3.1 KiB
Python
Raw Normal View History

2024-02-14 00:02:09 +01:00
#! /usr/bin/env python3
import gc
import os
import random
import warnings
from contextlib import ExitStack

import mercantile
import numpy as np
import rasterio
from PIL import Image
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rasterio.merge import merge
from rasterio.transform import from_bounds

def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
    """
    Return the tiles forming the 3x3 block centred on ``tile``.

    The surrounding tiles come first, in the order yielded by
    ``mercantile.neighbors`` with any repeats dropped; ``tile`` itself is
    appended last. How many surrounding tiles there are depends on what
    ``mercantile.neighbors`` yields for ``tile``.
    """
    # dict.fromkeys keeps first-seen order while discarding duplicates.
    return [*dict.fromkeys(mercantile.neighbors(tile)), tile]
def _tile_to_memory_geotiff(
    tile: mercantile.Tile,
    sat_year: str,
    satellite_dataset_dir: str,
    stack: ExitStack,
):
    """
    Load one tile JPEG and return it as an open, georeferenced in-memory GeoTIFF.

    Both the backing ``MemoryFile`` and the dataset opened from it are
    registered on ``stack``. They stay alive until the caller leaves the
    stack, which then closes the dataset first and its ``MemoryFile`` second.

    Raises:
        FileNotFoundError: if the tile image does not exist on disk.
    """
    tile_path = (
        f"{satellite_dataset_dir}/{sat_year}/"
        f"{tile.z}_{tile.x}_{tile.y}.jpg"
    )
    if not os.path.exists(tile_path):
        raise FileNotFoundError(
            f"Tile {tile.z}_{tile.x}_{tile.y} not found."
        )
    # Read pixels and size from a single handle and close the JPEG right away
    # (the previous version never closed this handle).
    with rasterio.open(tile_path) as src:
        data = src.read()
        width, height = src.width, src.height
    # NOTE(review): mercantile.bounds returns lng/lat degrees, yet the dataset
    # below is tagged EPSG:3857 (metres); mercantile.xy_bounds would give
    # matching Web Mercator bounds. Left unchanged because callers may rely on
    # the current degree-valued transform in the returned metadata — confirm.
    west, south, east, north = mercantile.bounds(tile)
    memfile = stack.enter_context(MemoryFile())
    with memfile.open(
        driver="GTiff",
        height=height,
        width=width,
        count=3,
        dtype="uint8",
        crs="EPSG:3857",
        transform=from_bounds(west, south, east, north, width, height),
    ) as dst:
        dst.write(data)
    # Reopen for reading; the stack keeps it open until the merge is done.
    return stack.enter_context(memfile.open())


def get_tiff_map(
    tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str
) -> tuple[np.ndarray, dict]:
    """
    Build a mosaic of ``tile`` and its neighbours from one year's tile JPEGs.

    Each tile returned by ``get_3x3_neighbors`` is read from
    ``{satellite_dataset_dir}/{sat_year}/{z}_{x}_{y}.jpg``, wrapped in an
    in-memory GeoTIFF positioned with ``mercantile.bounds``, and the tiles are
    combined with ``rasterio.merge.merge``.

    Args:
        tile: centre tile of the mosaic.
        sat_year: sub-directory (imagery year) to read tiles from.
        satellite_dataset_dir: root directory of the tile dataset.

    Returns:
        ``(mosaic, out_meta)``. ``mosaic`` is the merged array indexed
        (band, row, col). ``out_meta`` is the first tile's rasterio metadata,
        updated with the mosaic's height, width and transform, driver
        "GTiff" and crs "EPSG:3857".

    Raises:
        FileNotFoundError: if any required tile image is missing. Tiles opened
            before the failure are still closed.
    """
    neighbors = get_3x3_neighbors(tile)
    with ExitStack() as stack:
        with warnings.catch_warnings():
            # The raw .jpg tiles carry no georeferencing; silence rasterio's
            # warning about that while they are read.
            warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
            tile_data = [
                _tile_to_memory_geotiff(
                    neighbor, sat_year, satellite_dataset_dir, stack
                )
                for neighbor in neighbors
            ]
        mosaic, out_trans = merge(tile_data)
        # Copy the metadata while the dataset is still open.
        out_meta = tile_data[0].meta.copy()
    # Leaving the ExitStack closed every dataset, then its MemoryFile.
    out_meta.update(
        {
            "driver": "GTiff",
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "transform": out_trans,
            "crs": "EPSG:3857",
        }
    )
    # Kept from the original implementation: force a collection so released
    # raster buffers are reclaimed promptly between calls.
    gc.collect()
    return mosaic, out_meta
def get_random_tiff_patch(
    lat: float,
    lon: float,
    satellite_dataset_dir: str,
) -> np.ndarray:
    """
    Return the tile mosaic around a coordinate for a randomly picked year.

    The process has three steps:
    1. Find the zoom-17 tile containing (lat, lon).
    2. Draw one of the years "2023", "2021", "2019" or "2016" with the
       module-level ``random`` generator.
    3. Return the mosaic array that ``get_tiff_map`` builds for that tile and
       year, discarding its metadata.

    NOTE(review): despite the name, no patch is cropped here. The whole
    mosaic is returned.

    Raises:
        FileNotFoundError: propagated from ``get_tiff_map`` when a tile image
            is missing.
    """
    # Locate the tile before drawing the year, as before, so the RNG is
    # consumed in the same order.
    center = get_tile_from_coord(lat, lon, 17)
    year = random.choice(["2023", "2021", "2019", "2016"])
    mosaic_and_meta = get_tiff_map(center, year, satellite_dataset_dir)
    return mosaic_and_meta[0]
def get_tile_from_coord(
    lat: float, lng: float, zoom_level: int
) -> mercantile.Tile:
    """
    Look up the map tile that contains a point.

    Args:
        lat: latitude of the point.
        lng: longitude of the point.
        zoom_level: zoom level of the tile to return.

    Returns:
        The ``mercantile.Tile`` covering (lat, lng) at ``zoom_level``.
    """
    # mercantile.tile takes longitude first, the reverse of this function's
    # (lat, lng) argument order.
    return mercantile.tile(lng, lat, zoom_level)