110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
|
#! /usr/bin/env python3
|
||
|
|
||
|
import os
|
||
|
import mercantile
|
||
|
import rasterio
|
||
|
import numpy as np
|
||
|
import random
|
||
|
import warnings
|
||
|
from rasterio.errors import NotGeoreferencedWarning
|
||
|
from rasterio.io import MemoryFile
|
||
|
from rasterio.transform import from_bounds
|
||
|
from rasterio.merge import merge
|
||
|
from PIL import Image
|
||
|
import gc
|
||
|
|
||
|
|
||
|
def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
    """Return ``tile`` together with the tiles surrounding it.

    The surrounding tiles come first, in the order ``mercantile.neighbors``
    yields them with any repeats removed; ``tile`` itself is always the last
    element. If mercantile omits some neighbours (presumably at the edge of the
    tile grid — not verified here), the list is shorter than nine entries.
    """
    # dict.fromkeys keeps the first occurrence of each tile and preserves order.
    ring = list(dict.fromkeys(mercantile.neighbors(tile)))
    return [*ring, tile]
|
||
|
|
||
|
def _open_georeferenced_tile(neighbor: mercantile.Tile, tile_path: str, stack):
    """Copy one JPEG tile into an in-memory GeoTIFF and return it opened for reading.

    The image is georeferenced to ``neighbor``'s Web-Mercator (EPSG:3857)
    bounds. Both the ``MemoryFile`` and the returned dataset are registered on
    ``stack`` (a ``contextlib.ExitStack``), so they remain valid until the
    caller's stack unwinds and are then closed dataset-first.
    """
    # One rasterio open gives both the pixels and the size; the context
    # manager guarantees the underlying file handle is released.
    with rasterio.open(tile_path) as src:
        data = src.read()
    count, height, width = data.shape

    # xy_bounds returns metres in EPSG:3857, matching the CRS declared below.
    # (mercantile.bounds would return longitude/latitude degrees instead.)
    west, south, east, north = mercantile.xy_bounds(neighbor)

    memfile = stack.enter_context(MemoryFile())
    with memfile.open(
        driver="GTiff",
        height=height,
        width=width,
        count=count,
        dtype=data.dtype,
        crs="EPSG:3857",
        transform=from_bounds(west, south, east, north, width, height),
    ) as dst:
        dst.write(data)

    # Reopen for reading; the MemoryFile must outlive this dataset, which the
    # stack's LIFO unwinding guarantees.
    return stack.enter_context(memfile.open())


def get_tiff_map(
    tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str
) -> tuple[np.ndarray, dict]:
    """Return a georeferenced mosaic of ``tile`` and its neighbouring tiles.

    Each tile is read from
    ``{satellite_dataset_dir}/{sat_year}/{z}_{x}_{y}.jpg``, georeferenced in
    EPSG:3857 and merged with ``rasterio.merge.merge``.

    Args:
        tile: Centre tile of the mosaic.
        sat_year: Sub-directory (imagery year) to read the tiles from.
        satellite_dataset_dir: Root directory of the tile dataset.

    Returns:
        ``(mosaic, out_meta)``: the merged array as returned by ``merge`` and a
        rasterio metadata dict (GTiff driver, EPSG:3857, mosaic size and
        transform) suitable for writing the mosaic out.

    Raises:
        FileNotFoundError: if any of the required tile images is missing.
    """
    from contextlib import ExitStack

    neighbors = get_3x3_neighbors(tile)

    with warnings.catch_warnings(), ExitStack() as stack:
        # The source JPEGs carry no georeferencing, so rasterio warns when
        # opening them; the transform is supplied explicitly instead.
        warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

        datasets = []
        for neighbor in neighbors:
            tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
            if not os.path.exists(tile_path):
                raise FileNotFoundError(
                    f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
                )
            datasets.append(_open_georeferenced_tile(neighbor, tile_path, stack))

        # Merge and read metadata while every dataset and MemoryFile is still
        # open; the stack closes them all afterwards, even on error.
        mosaic, out_trans = merge(datasets)
        out_meta = datasets[0].meta.copy()

    out_meta.update(
        {
            "driver": "GTiff",
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "transform": out_trans,
            "crs": "EPSG:3857",
        }
    )
    return mosaic, out_meta
|
||
|
|
||
|
def get_random_tiff_patch(
    lat: float,
    lon: float,
    satellite_dataset_dir: str,
) -> np.ndarray:
    """Return the tile mosaic around a coordinate for a randomly chosen year.

    The zoom-17 tile containing (``lat``, ``lon``) is located, one imagery year
    is drawn uniformly from 2023, 2021, 2019 and 2016, and the mosaic built by
    ``get_tiff_map`` for that tile and year is returned (its metadata is
    discarded).
    """
    center = get_tile_from_coord(lat, lon, 17)
    year = random.choice(["2023", "2021", "2019", "2016"])

    mosaic, _meta = get_tiff_map(center, year, satellite_dataset_dir)
    return mosaic
|
||
|
|
||
|
def get_tile_from_coord(lat: float, lng: float, zoom_level: int) -> mercantile.Tile:
    """Return the tile at ``zoom_level`` that contains the point (``lat``, ``lng``).

    ``mercantile.tile`` expects longitude before latitude, so the two
    coordinates are passed in swapped order.
    """
    return mercantile.tile(lng, lat, zoom_level)
|