445 lines
13 KiB
Python
445 lines
13 KiB
Python
# article dependencies
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from tqdm import tqdm
|
|
from torchvision.utils import make_grid
|
|
import random
|
|
import json
|
|
import os
|
|
from PIL import Image
|
|
import resource
|
|
|
|
# Cap the address space available to this process at `memory_limit_gb` GiB.
memory_limit_gb = 24
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
_requested_limit = memory_limit_gb * 1024 * 1024 * 1024
if hard != resource.RLIM_INFINITY:
    # A soft limit above the hard limit makes setrlimit raise ValueError,
    # so never ask for more than the hard limit allows.
    _requested_limit = min(_requested_limit, hard)
resource.setrlimit(resource.RLIMIT_AS, (_requested_limit, hard))
|
|
|
|
# ---------
# CONSTANTS
# ---------
# Default patch height and width (in pixels) used when tiling the images.
IMG_H = 160
IMG_W = 320
# Default image directory; a relative path, so it depends on the working directory.
DATASET_PATH = "../../datasets/sat_data/woodbridge/images/"
|
|
|
|
# configuring device: use the first CUDA GPU when one is available, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")
|
|
|
|
|
|
def print_memory_usage_gpu():
    """Print the memory PyTorch has allocated and reserved on CUDA device 0.

    Both figures are converted from bytes to GB (1024**3 bytes) and rounded
    to one decimal place. Nothing is returned.
    """
    bytes_per_gb = 1024**3
    print(
        "GPU memory allocated:",
        round(torch.cuda.memory_allocated(0) / bytes_per_gb, 1),
        "GB",
    )
    # torch.cuda.memory_cached() is deprecated; memory_reserved() is its
    # replacement and reports the same quantity.
    print(
        "GPU memory cached: ", round(torch.cuda.memory_reserved(0) / bytes_per_gb, 1), "GB"
    )
|
|
|
|
|
|
class GEImagePreprocess:
    """Tile the images of a directory into grayscale patches for training.

    Every image is cut into non-overlapping ``patch_w`` x ``patch_h`` tiles.
    Each tile is converted to 8-bit grayscale, turned into a float32 NumPy
    array scaled to [0, 1], and appended to either the training or the
    validation list.

    Args:
        path: Directory that contains the image files.
        patch_w: Patch width in pixels.
        patch_h: Patch height in pixels.
    """

    def __init__(
        self,
        path=DATASET_PATH,
        patch_w=IMG_W,
        patch_h=IMG_H,
    ):
        super().__init__()
        self.path = path
        self.training_set = []
        self.validation_set = []
        self.patch_w = patch_w
        self.patch_h = patch_h

    def load_images(self):
        """Load and tile every file found in ``self.path``.

        Returns:
            The tuple ``(training_set, validation_set)``; both are the lists
            held by this instance, so repeated calls keep appending to them.
        """
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # order of the resulting patches reproducible between runs.
        file_names = sorted(os.listdir(self.path))
        for file_name in tqdm(file_names, desc="Loading images"):
            # The context manager closes the underlying file once the patches
            # have been extracted, instead of leaking one handle per image.
            with Image.open(os.path.join(self.path, file_name)) as img:
                self.preprocess_image(img)

        return self.training_set, self.validation_set

    def preprocess_image(self, image):
        """Cut ``image`` into patches and append them to the two sets.

        Pixels to the right of / below the last full patch are dropped. A
        patch goes to the validation set when its column index plus its row
        index is a multiple of 10, otherwise to the training set.

        Args:
            image: A PIL image.
        """
        width, height = image.size
        num_patches_w = width // self.patch_w
        num_patches_h = height // self.patch_h

        for i in range(num_patches_w):
            for j in range(num_patches_h):
                left = i * self.patch_w
                upper = j * self.patch_h
                patch = image.crop(
                    (left, upper, left + self.patch_w, upper + self.patch_h)
                )
                patch = patch.convert("L")
                patch = np.array(patch).astype(np.float32) / 255
                if (i + j) % 10 == 0:
                    self.validation_set.append(patch)
                else:
                    self.training_set.append(patch)
|
|
|
|
|
|
# Patches from the default dataset directory.
training_images, validation_images = GEImagePreprocess().load_images()

# Patches from the second dataset are appended to the same two lists.
fountainhead = GEImagePreprocess(path="../../datasets/sat_data/fountainhead/images/")
tr, val = fountainhead.load_images()
training_images += tr
validation_images += val
|
|
|
|
# defining dataset class
class GEDataset(Dataset):
    """Dataset over an indexable collection of samples with an optional transform.

    Args:
        data: Any object supporting ``len()`` and integer indexing.
        transforms: Optional callable applied to each sample on access.
    """

    def __init__(self, data, transforms=None):
        self.data = data
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transforms`` when one is set."""
        image = self.data[idx]

        # Identity check: `!= None` would invoke a custom __ne__ on the transform.
        if self.transforms is not None:
            image = self.transforms(image)
        return image
|
|
|
|
|
|
# creating pytorch datasets
def _patch_transform():
    """Build the pipeline shared by all splits: to tensor, then normalise with mean 0.5 and std 0.5."""
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
    )


training_data = GEDataset(training_images, transforms=_patch_transform())

validation_data = GEDataset(validation_images, transforms=_patch_transform())

# The test split is built from the same patches as the validation split.
test_data = GEDataset(validation_images, transforms=_patch_transform())
|
|
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder mapping 160x320 images to a latent vector.

    Five ``Conv2d`` layers with kernel size 2 each downsample the input, the
    result is flattened, and a final ``Linear`` layer projects it to
    ``latent_dim`` features.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Channels produced by the first convolution; later
            layers use 2x, 4x, 8x and 8x this value.
        latent_dim: Size of the output vector.
        stride: Stride of every convolution. The flattened size assumes the
            default of 2 (160x320 halved five times gives 5x10).
        act_fn: Activation module inserted after each convolution. The same
            instance is placed at every position.
        debug: When true, ``forward`` prints each layer's output shape.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=128,
        latent_dim=1000,
        stride=2,
        act_fn=nn.LeakyReLU(),
        debug=False,
    ):
        super().__init__()
        self.debug = debug
        self.in_channels = in_channels

        widths = [
            in_channels,
            out_channels,
            out_channels * 2,
            out_channels * 4,
            out_channels * 8,
        ]
        layers = []
        # First four stages: convolution -> batch norm -> activation.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=2,
                    stride=stride,
                ),
                nn.BatchNorm2d(c_out),
                act_fn,
            ]
        # Last stage keeps the original order: activation before batch norm.
        layers += [
            nn.Conv2d(
                in_channels=out_channels * 8,
                out_channels=out_channels * 8,
                kernel_size=2,
                stride=stride,
            ),
            act_fn,
            nn.BatchNorm2d(out_channels * 8),
            nn.Flatten(),
            # Derived from out_channels instead of the hard-coded 51200
            # (= 128 * 8 * 5 * 10), which only matched the default width.
            nn.Linear(out_channels * 8 * 5 * 10, latent_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` after reshaping it to ``(-1, in_channels, 160, 320)``.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        x = x.view(-1, self.in_channels, 160, 320)
        # Layers are applied one by one so debug mode can report each shape.
        for layer in self.net:
            x = layer(x)
            if self.debug:
                print(layer.__class__.__name__, "output shape:\t", x.shape)
        encoded_latent_image = x
        return encoded_latent_image
|
|
|
|
|
|
# defining decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to a 160x320 image.

    A ``Linear`` layer expands the latent vector, which is reshaped to
    ``(batch, out_channels * 8, 5, 10)`` and passed through five
    ``ConvTranspose2d`` stages, each followed by the activation and a batch
    norm.

    Args:
        in_channels: Number of channels of the reconstructed image.
        out_channels: Base channel width; the stages use 8x, 8x, 4x, 2x and
            1x this value as their input widths.
        latent_dim: Size of the incoming latent vector.
        stride: Stride of every transposed convolution.
        kernel_size: Kernel size of every transposed convolution.
        act_fn: Activation module inserted after each transposed
            convolution. The same instance is placed at every position.
        debug: When true, ``forward`` prints each layer's output shape.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=128,
        latent_dim=1000,
        stride=2,
        kernel_size=2,
        act_fn=nn.LeakyReLU(),
        debug=False,
    ):
        super().__init__()
        self.debug = debug

        self.out_channels = out_channels

        self.linear = nn.Sequential(
            # Derived from out_channels instead of the hard-coded 51200
            # (= 128 * 8 * 5 * 10) so it always matches the view in forward().
            nn.Linear(latent_dim, out_channels * 8 * 5 * 10),
        )

        widths = [
            out_channels * 8,
            out_channels * 8,
            out_channels * 4,
            out_channels * 2,
            out_channels,
            in_channels,
        ]
        stages = []
        # Every stage: transposed convolution -> activation -> batch norm.
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.ConvTranspose2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=kernel_size,
                    stride=stride,
                ),
                act_fn,
                nn.BatchNorm2d(c_out),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Decode the latent batch ``x`` into images.

        Returns:
            The output of the last batch-norm layer; with the default stride
            and kernel size its shape is ``(batch, in_channels, 160, 320)``.
        """
        output = self.linear(x)
        output = output.view(len(output), self.out_channels * 8, 5, 10)
        # Layers are applied one by one so debug mode can report each shape.
        for layer in self.conv:
            output = layer(output)
            if self.debug:
                print(layer.__class__.__name__, "output shape:\t", output.shape)
        return output
|
|
|
|
|
|
# defining autoencoder
class Autoencoder(nn.Module):
    """Chain an encoder and a decoder into one reconstruction module."""

    def __init__(self, encoder, decoder):
        super().__init__()
        # Module.to() moves the module in place and returns the module itself,
        # so registering its return value registers the original object.
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)

    def forward(self, x):
        """Return the decoder's output for the encoded ``x``."""
        return self.decoder(self.encoder(x))
|
|
|
|
|
|
class ConvolutionalAutoencoder:
    """Training and inference wrapper around an autoencoder network.

    Args:
        autoencoder: Module exposing ``encoder`` and ``decoder`` attributes
            whose forward pass returns the reconstruction of its input.
    """

    def __init__(self, autoencoder):
        self.network = autoencoder
        self.optimizer = torch.optim.RMSprop(
            self.network.parameters(), lr=0.01, alpha=0.99, eps=1e-08,
            weight_decay=0, momentum=0, centered=False,
        )

    def train(
        self, loss_function, epochs, batch_size, training_set, validation_set, test_set
    ):
        """Re-initialise the network and train it for ``epochs`` epochs.

        Each epoch runs one pass over the training set, one pass over the
        validation set, and prints the losses of the last batch of each. On
        every fifth epoch (starting with the first) a grid of
        original/reconstructed images from the first test batch is saved as
        ``epoch_<n>.png``.

        Args:
            loss_function: Callable ``(output, target) -> loss tensor``.
            epochs: Number of epochs.
            batch_size: Batch size of the training and validation loaders.
            training_set: Dataset of training images.
            validation_set: Dataset of validation images.
            test_set: Dataset used for the visualisations (batches of 10).

        Returns:
            Dict with the keys ``training_loss_per_batch``,
            ``validation_loss_per_batch`` and ``visualizations``.

        Note:
            The network is left in evaluation mode when this method returns.
        """
        # creating log
        log_dict = {
            "training_loss_per_batch": [],
            "validation_loss_per_batch": [],
            "visualizations": [],
        }

        # defining weight initialization function
        def init_weights(module):
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                torch.nn.init.xavier_uniform_(module.weight)
                module.bias.data.fill_(0.01)

        # initializing network weights
        self.network.apply(init_weights)

        # creating dataloaders
        train_loader = DataLoader(training_set, batch_size)
        val_loader = DataLoader(validation_set, batch_size)
        test_loader = DataLoader(test_set, 10)

        self.network.to(device)

        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")

            # ------------
            # TRAINING
            # ------------
            print("training...")
            # Re-enter training mode every epoch because validation below
            # switches the network to evaluation mode.
            self.network.train()
            for images in tqdm(train_loader):
                # sending images to device
                images = images.to(device)
                # reconstructing images
                output = self.network(images)
                # computing loss against the input reshaped like the output
                loss = loss_function(output, images.view_as(output))
                # zeroing gradients
                self.optimizer.zero_grad()
                # calculating gradients
                loss.backward()
                # optimizing weights
                self.optimizer.step()

                log_dict["training_loss_per_batch"].append(loss.item())

            # --------------
            # VALIDATION
            # --------------
            print("validating...")
            # Evaluation mode makes BatchNorm use its running statistics and
            # stops it from updating them with validation data.
            self.network.eval()
            with torch.no_grad():
                for val_images in tqdm(val_loader):
                    # sending validation images to device
                    val_images = val_images.to(device)
                    # reconstructing images
                    output = self.network(val_images)
                    # computing validation loss
                    val_loss = loss_function(output, val_images.view_as(output))

                    log_dict["validation_loss_per_batch"].append(val_loss.item())

            # --------------
            # VISUALISATION
            # --------------
            print(
                f"training_loss: {round(loss.item(), 4)} validation_loss: {round(val_loss.item(), 4)}"
            )
            if epoch % 5 == 0:
                grid = self._save_reconstruction_grid(
                    test_loader, epoch, loss.item(), val_loss.item()
                )
                if grid is not None:
                    log_dict["visualizations"].append(grid)

        return log_dict

    def _save_reconstruction_grid(self, test_loader, epoch, train_loss, val_loss):
        """Save a grid of originals and reconstructions for the first test batch.

        Returns the grid tensor, or ``None`` when the loader yields no batch.
        """
        test_images = next(iter(test_loader), None)
        if test_images is None:
            return None

        with torch.no_grad():
            # reconstructing test images, then back to the CPU for plotting
            reconstructed_imgs = self.network(test_images.to(device)).cpu()
        test_images = test_images.cpu()

        # Interleave each original with its reconstruction.
        imgs = torch.stack(
            [test_images.view_as(reconstructed_imgs), reconstructed_imgs], dim=1
        ).flatten(0, 1)
        grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
        grid = grid.permute(1, 2, 0)

        fig = plt.figure(dpi=170)
        plt.title(
            f"Original/Reconstructed, training loss: {round(train_loss, 4)} validation loss: {round(val_loss, 4)}"
        )
        plt.imshow(grid)
        plt.axis("off")
        plt.savefig(f"epoch_{epoch+1}.png")
        # Close the figure so figures do not pile up in memory across epochs.
        plt.close(fig)
        return grid

    def autoencode(self, x):
        """Return the network's reconstruction of ``x``."""
        return self.network(x)

    def encode(self, x):
        """Return the encoder's output for ``x``."""
        encoder = self.network.encoder
        return encoder(x)

    def decode(self, x):
        """Return the decoder's output for ``x``."""
        decoder = self.network.decoder
        return decoder(x)
|
|
|
|
|
|
# training model
autoencoder = Autoencoder(Encoder(), Decoder())
model = ConvolutionalAutoencoder(autoencoder)

# Mean-squared reconstruction error, 30 epochs, batches of 14 patches.
log_dict = model.train(
    loss_function=nn.MSELoss(),
    epochs=30,
    batch_size=14,
    training_set=training_data,
    validation_set=validation_data,
    test_set=test_data,
)
|