Update Decoder factorization
parent
d2829e27fa
commit
669e648719
|
@ -1 +1,2 @@
|
|||
.venv/*
|
||||
visualizations/*
|
||||
|
|
179
code/nn.py
179
code/nn.py
|
@ -1,4 +1,5 @@
|
|||
# article dependencies
|
||||
""" Autoencoder for satellite images """
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
@ -9,12 +10,14 @@ import numpy as np
|
|||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
from torchvision.utils import make_grid
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
from PIL import Image
|
||||
import resource
|
||||
import math
|
||||
|
||||
# -------------
|
||||
# MEMORY SAFETY
|
||||
# -------------
|
||||
memory_limit_gb = 24
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
|
||||
resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, hard))
|
||||
|
@ -23,8 +26,8 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, ha
|
|||
# CONSTANTS
|
||||
# --------
|
||||
IMG_H = 160
|
||||
IMG_W = 320
|
||||
DATASET_PATH = "../../datasets/sat_data/woodbridge/images/"
|
||||
IMG_W = IMG_H * 2
|
||||
DATASET_PATH = "../../diplomska/datasets/sat_data/woodbridge/images/"
|
||||
|
||||
# configuring device
|
||||
if torch.cuda.is_available():
|
||||
|
@ -57,6 +60,7 @@ class GEImagePreprocess:
|
|||
self.path = path
|
||||
self.training_set = []
|
||||
self.validation_set = []
|
||||
self.test_set = []
|
||||
self.patch_w = patch_w
|
||||
self.patch_h = patch_h
|
||||
|
||||
|
@ -66,7 +70,7 @@ class GEImagePreprocess:
|
|||
img = Image.open(self.path + images[image])
|
||||
img = self.preprocess_image(img)
|
||||
|
||||
return self.training_set, self.validation_set
|
||||
return self.training_set, self.validation_set, self.test_set
|
||||
|
||||
def preprocess_image(self, image):
|
||||
width, height = image.size
|
||||
|
@ -88,17 +92,12 @@ class GEImagePreprocess:
|
|||
patch = patch / 255
|
||||
if (i + j) % 10 == 0:
|
||||
self.validation_set.append(patch)
|
||||
if (i + j) % 10 == 1:
|
||||
self.test_set.append(patch)
|
||||
else:
|
||||
self.training_set.append(patch)
|
||||
|
||||
|
||||
training_images, validation_images = GEImagePreprocess().load_images()
|
||||
tr, val = GEImagePreprocess(path='../../datasets/sat_data/fountainhead/images/').load_images()
|
||||
|
||||
training_images.extend(tr)
|
||||
validation_images.extend(val)
|
||||
|
||||
# defining dataset class
|
||||
class GEDataset(Dataset):
|
||||
def __init__(self, data, transforms=None):
|
||||
self.data = data
|
||||
|
@ -115,32 +114,6 @@ class GEDataset(Dataset):
|
|||
return image
|
||||
|
||||
|
||||
# creating pytorch datasets
|
||||
training_data = GEDataset(
|
||||
training_images,
|
||||
transforms=transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
|
||||
),
|
||||
)
|
||||
|
||||
validation_data = GEDataset(
|
||||
validation_images,
|
||||
transforms=transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
|
||||
),
|
||||
)
|
||||
|
||||
test_data = GEDataset(
|
||||
validation_images,
|
||||
transforms=transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5), (0.5)),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -196,11 +169,11 @@ class Encoder(nn.Module):
|
|||
act_fn,
|
||||
nn.BatchNorm2d(out_channels * 8),
|
||||
nn.Flatten(),
|
||||
nn.Linear(51200, latent_dim),
|
||||
nn.Linear(IMG_H * IMG_W, latent_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 1, 160, 320)
|
||||
x = x.view(-1, 1, IMG_H, IMG_W)
|
||||
# Print also the function name
|
||||
for layer in self.net:
|
||||
x = layer(x)
|
||||
|
@ -210,7 +183,6 @@ class Encoder(nn.Module):
|
|||
return encoded_latent_image
|
||||
|
||||
|
||||
# defining decoder
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -227,8 +199,10 @@ class Decoder(nn.Module):
|
|||
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.v, self.u = self.factor()
|
||||
|
||||
self.linear = nn.Sequential(
|
||||
nn.Linear(latent_dim, 51200),
|
||||
nn.Linear(latent_dim, IMG_H * IMG_W),
|
||||
)
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
|
@ -274,9 +248,16 @@ class Decoder(nn.Module):
|
|||
nn.BatchNorm2d(in_channels),
|
||||
)
|
||||
|
||||
def factor(self):
    """Choose the spatial size (v, u) of the decoder's first feature map.

    The decoder's linear layer emits ``IMG_H * IMG_W`` values per sample,
    which ``forward`` reshapes to ``(batch, out_channels * 8, v, u)``.
    That reshape only succeeds when ``out_channels * 8 * v * u`` equals
    ``IMG_H * IMG_W`` exactly, so the pair returned here is always an
    exact factorization of the per-channel element count.

    ``v`` starts at ``floor(sqrt(per_channel / 2))`` so that ``u`` is about
    ``2 * v``, and is lowered until it divides the per-channel count. When
    the starting value already divides it, that value is returned
    unchanged.

    Returns:
        tuple[int, int]: ``(v, u)`` with ``v * u == IMG_H * IMG_W // (out_channels * 8)``.

    Raises:
        ValueError: if ``out_channels * 8`` is not positive or does not
            divide ``IMG_H * IMG_W``; no valid reshape exists in that case.
    """
    channels = self.out_channels * 8
    total = IMG_H * IMG_W
    if channels <= 0 or total % channels != 0:
        raise ValueError(
            f"cannot split {total} values evenly across {channels} channels"
        )

    per_channel = total // channels
    # max(1, ...) guards per_channel == 1, where the square root floors to 0.
    v = max(1, math.isqrt(per_channel // 2))
    # Walk down to the nearest divisor; v == 1 always terminates the loop.
    while per_channel % v != 0:
        v -= 1
    return v, per_channel // v
|
||||
|
||||
def forward(self, x):
|
||||
output = self.linear(x)
|
||||
output = output.view(len(output), self.out_channels * 8, 5, 10)
|
||||
output = output.view(len(output), self.out_channels * 8, self.v, self.u)
|
||||
for layer in self.conv:
|
||||
output = layer(output)
|
||||
if self.debug:
|
||||
|
@ -284,7 +265,6 @@ class Decoder(nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
# defining autoencoder
|
||||
class Autoencoder(nn.Module):
|
||||
def __init__(self, encoder, decoder):
|
||||
super().__init__()
|
||||
|
@ -304,8 +284,13 @@ class ConvolutionalAutoencoder:
|
|||
def __init__(self, autoencoder):
|
||||
self.network = autoencoder
|
||||
self.optimizer = torch.optim.RMSprop(
|
||||
self.network.parameters(), lr=0.01, alpha=0.99, eps=1e-08,
|
||||
weight_decay=0, momentum=0, centered=False
|
||||
self.network.parameters(),
|
||||
lr=0.01,
|
||||
alpha=0.99,
|
||||
eps=1e-08,
|
||||
weight_decay=0,
|
||||
momentum=0,
|
||||
centered=False,
|
||||
)
|
||||
|
||||
def train(
|
||||
|
@ -353,7 +338,7 @@ class ConvolutionalAutoencoder:
|
|||
# reconstructing images
|
||||
output = self.network(images)
|
||||
# computing loss
|
||||
loss = loss_function(output, images.view(-1, 1, 160, 320))
|
||||
loss = loss_function(output, images.view(-1, 1, IMG_H, IMG_W))
|
||||
# zeroing gradients
|
||||
self.optimizer.zero_grad()
|
||||
# calculating gradients
|
||||
|
@ -377,7 +362,7 @@ class ConvolutionalAutoencoder:
|
|||
# reconstructing images
|
||||
output = self.network(val_images)
|
||||
# computing validation loss
|
||||
val_loss = loss_function(output, val_images.view(-1, 1, 160, 320))
|
||||
val_loss = loss_function(output, val_images.view(-1, 1, IMG_H, IMG_W))
|
||||
|
||||
# --------------
|
||||
# LOGGING
|
||||
|
@ -390,20 +375,20 @@ class ConvolutionalAutoencoder:
|
|||
print(
|
||||
f"training_loss: {round(loss.item(), 4)} validation_loss: {round(val_loss.item(), 4)}"
|
||||
)
|
||||
if epoch % 5 == 0:
|
||||
for test_images in test_loader:
|
||||
# sending test images to device
|
||||
test_images = test_images.to(device)
|
||||
with torch.no_grad():
|
||||
# reconstructing test images
|
||||
reconstructed_imgs = self.network(test_images)
|
||||
plt_ix = 0
|
||||
for test_images in test_loader:
|
||||
# sending test images to device
|
||||
test_images = test_images.to(device)
|
||||
with torch.no_grad():
|
||||
# reconstructing test images
|
||||
reconstructed_imgs = self.network(test_images)
|
||||
# sending reconstructed and images to cpu to allow for visualization
|
||||
reconstructed_imgs = reconstructed_imgs.cpu()
|
||||
test_images = test_images.cpu()
|
||||
|
||||
# visualisation
|
||||
imgs = torch.stack(
|
||||
[test_images.view(-1, 1, 160, 320), reconstructed_imgs], dim=1
|
||||
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs], dim=1
|
||||
).flatten(0, 1)
|
||||
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
|
||||
grid = grid.permute(1, 2, 0)
|
||||
|
@ -414,8 +399,17 @@ class ConvolutionalAutoencoder:
|
|||
plt.imshow(grid)
|
||||
log_dict["visualizations"].append(grid)
|
||||
plt.axis("off")
|
||||
plt.savefig(f"epoch_{epoch+1}.png")
|
||||
break
|
||||
|
||||
# Check if directory exists, if not create it
|
||||
if not os.path.exists("visualizations"):
|
||||
os.makedirs("visualizations")
|
||||
if not os.path.exists(f"visualizations/epoch_{epoch+1}"):
|
||||
os.makedirs(f"visualizations/epoch_{epoch+1}")
|
||||
|
||||
plt.savefig(f"visualizations/epoch_{epoch+1}/img_{plt_ix}.png")
|
||||
plt.clf()
|
||||
plt.close()
|
||||
plt_ix += 1
|
||||
|
||||
return log_dict
|
||||
|
||||
|
@ -431,14 +425,61 @@ class ConvolutionalAutoencoder:
|
|||
return decoder(x)
|
||||
|
||||
|
||||
# training model
|
||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||
def preprocess_data():
    """Load both image collections and wrap each split in a ``GEDataset``.

    Patches are loaded once with ``GEImagePreprocess()`` default arguments
    and once from the fountainhead image directory. The two results are
    merged split by split, and each split gets the same
    ``ToTensor`` + ``Normalize((0.5), (0.5))`` transform pipeline.

    Returns:
        tuple: ``(training_data, validation_data, test_data)`` datasets.
    """
    training_images, validation_images, test_images = GEImagePreprocess().load_images()
    tr, val, test = GEImagePreprocess(
        path="../../diplomska/datasets/sat_data/fountainhead/images/"
    ).load_images()
    training_images.extend(tr)
    validation_images.extend(val)
    test_images.extend(test)
    print(
        f"Training on {len(training_images)} images, validating on {len(validation_images)} images, testing on {len(test_images)} images"
    )

    def _as_dataset(images):
        # Each dataset gets its own transform pipeline instance.
        return GEDataset(
            images,
            transforms=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
            ),
        )

    # creating pytorch datasets
    training_data = _as_dataset(training_images)
    validation_data = _as_dataset(validation_images)
    # The test split must come from test_images; building it from
    # validation_images would make test metrics a repeat of validation.
    test_data = _as_dataset(test_images)

    return training_data, validation_data, test_data
|
||||
|
||||
|
||||
def main():
    """Prepare the datasets, assemble the autoencoder and train it."""
    training_data, validation_data, test_data = preprocess_data()

    # Build the network pieces in encoder-then-decoder order.
    encoder = Encoder()
    decoder = Decoder()
    trainer = ConvolutionalAutoencoder(Autoencoder(encoder, decoder))

    # The training log returned by train() is not used here.
    trainer.train(
        nn.MSELoss(),
        epochs=30,
        batch_size=14,
        training_set=training_data,
        validation_set=validation_data,
        test_set=test_data,
    )


if __name__ == "__main__":
    main()
|
||||
|
|
Loading…
Reference in New Issue