|
|
|
@ -2,8 +2,6 @@
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import torchvision
|
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
import numpy as np
|
|
|
|
@ -14,25 +12,27 @@ import os
|
|
|
|
|
from PIL import Image
|
|
|
|
|
import resource
|
|
|
|
|
import argparse
|
|
|
|
|
import pickle
|
|
|
|
|
from multiprocessing import Pool
|
|
|
|
|
from functools import partial
|
|
|
|
|
import torchvision.transforms.functional as TF
|
|
|
|
|
|
|
|
|
|
# -------------
|
|
|
|
|
# MEMORY SAFETY
|
|
|
|
|
# -------------
|
|
|
|
|
memory_limit_gb = 24
|
|
|
|
|
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
|
|
|
|
|
resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024 * 1024 * 1024, hard))
|
|
|
|
|
resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
|
|
|
|
|
|
|
|
|
|
# --------
|
|
|
|
|
# CONSTANTS
|
|
|
|
|
# --------
|
|
|
|
|
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
|
|
|
|
IMG_W = IMG_H * 2
|
|
|
|
|
IMG_H = 224 # On better gpu use 256 and adam optimizer
|
|
|
|
|
IMG_W = IMG_H
|
|
|
|
|
DATASET_PATHS = [
|
|
|
|
|
"../../diplomska/datasets/sat_data/woodbridge/images/",
|
|
|
|
|
"../../diplomska/datasets/sat_data/fountainhead/images/",
|
|
|
|
|
"../../diplomska/datasets/village/images/",
|
|
|
|
|
"../../diplomska/datasets/gravel_pit/images/",
|
|
|
|
|
"../datasets/train/",
|
|
|
|
|
]
|
|
|
|
|
LINE = "\n----------------------------------------\n"
|
|
|
|
|
|
|
|
|
|
# configuring device
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
@ -47,9 +47,7 @@ def print_memory_usage_gpu():
|
|
|
|
|
round(torch.cuda.memory_allocated(0) / 1024**3, 1),
|
|
|
|
|
"GB",
|
|
|
|
|
)
|
|
|
|
|
print(
|
|
|
|
|
"GPU memory cached: ", round(torch.cuda.memory_cached(0) / 1024**3, 1), "GB"
|
|
|
|
|
)
|
|
|
|
|
print("GPU memory cached:", round(torch.cuda.memory_cached(0) / 1024**3, 1), "GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GEImagePreprocess:
|
|
|
|
@ -64,41 +62,51 @@ class GEImagePreprocess:
|
|
|
|
|
self.training_set = []
|
|
|
|
|
self.validation_set = []
|
|
|
|
|
self.test_set = []
|
|
|
|
|
self.entry_paths = []
|
|
|
|
|
self.patch_w = patch_w
|
|
|
|
|
self.patch_h = patch_h
|
|
|
|
|
|
|
|
|
|
def load_images(self):
|
|
|
|
|
images = os.listdir(self.path)
|
|
|
|
|
for image in tqdm(range(len(images)), desc="Loading images"):
|
|
|
|
|
img = Image.open(self.path + images[image])
|
|
|
|
|
img = self.preprocess_image(img)
|
|
|
|
|
|
|
|
|
|
self.get_entry_paths(self.path)
|
|
|
|
|
load_image_partial = partial(self.load_image_helper)
|
|
|
|
|
with Pool() as pool:
|
|
|
|
|
results = pool.map(load_image_partial, self.entry_paths)
|
|
|
|
|
self.split_dataset(results)
|
|
|
|
|
return self.training_set, self.validation_set, self.test_set
|
|
|
|
|
|
|
|
|
|
def preprocess_image(self, image):
|
|
|
|
|
width, height = image.size
|
|
|
|
|
num_patches_w = width // self.patch_w
|
|
|
|
|
num_patches_h = height // self.patch_h
|
|
|
|
|
def load_image_helper(self, entry_path):
|
|
|
|
|
try:
|
|
|
|
|
img = Image.open(entry_path)
|
|
|
|
|
except PIL.UnidentifiedImageError as e:
|
|
|
|
|
print("Could not open an image: ", entry_path)
|
|
|
|
|
print(e)
|
|
|
|
|
return None
|
|
|
|
|
return self.preprocess_image(img)
|
|
|
|
|
|
|
|
|
|
for i in range(num_patches_w):
|
|
|
|
|
for j in range(num_patches_h):
|
|
|
|
|
patch = image.crop(
|
|
|
|
|
(
|
|
|
|
|
i * self.patch_w,
|
|
|
|
|
j * self.patch_h,
|
|
|
|
|
(i + 1) * self.patch_w,
|
|
|
|
|
(j + 1) * self.patch_h,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
patch = patch.convert("L")
|
|
|
|
|
patch = np.array(patch).astype(np.float32)
|
|
|
|
|
patch = patch / 255
|
|
|
|
|
if (i + j) % 30 == 0:
|
|
|
|
|
self.validation_set.append(patch)
|
|
|
|
|
if (i + j) % 30 == 1:
|
|
|
|
|
self.test_set.append(patch)
|
|
|
|
|
else:
|
|
|
|
|
self.training_set.append(patch)
|
|
|
|
|
def get_entry_paths(self, path):
|
|
|
|
|
entries = os.listdir(path)
|
|
|
|
|
for entry in entries:
|
|
|
|
|
entry_path = path + "/" + entry
|
|
|
|
|
if os.path.isdir(entry_path):
|
|
|
|
|
self.get_entry_paths(entry_path + "/")
|
|
|
|
|
if entry_path.endswith(".jpeg"):
|
|
|
|
|
self.entry_paths.append(entry_path)
|
|
|
|
|
|
|
|
|
|
def split_dataset(self, dataset):
|
|
|
|
|
for image in tqdm(range(len(dataset)), desc="Splitting dataset"):
|
|
|
|
|
if image % 30 == 0:
|
|
|
|
|
self.validation_set.append(dataset[image])
|
|
|
|
|
elif image % 30 == 1:
|
|
|
|
|
self.test_set.append(dataset[image])
|
|
|
|
|
else:
|
|
|
|
|
self.training_set.append(dataset[image])
|
|
|
|
|
|
|
|
|
|
def preprocess_image(self, image):
|
|
|
|
|
image = image.resize((IMG_W, IMG_H))
|
|
|
|
|
image = image.convert("L")
|
|
|
|
|
image = np.array(image).astype(np.float32)
|
|
|
|
|
image = image / 255
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GEDataset(Dataset):
|
|
|
|
@ -123,45 +131,55 @@ class Encoder(nn.Module):
|
|
|
|
|
in_channels=1,
|
|
|
|
|
out_channels=128,
|
|
|
|
|
latent_dim=1000,
|
|
|
|
|
kernel_size=2,
|
|
|
|
|
stride=2,
|
|
|
|
|
act_fn=nn.LeakyReLU(),
|
|
|
|
|
debug=False,
|
|
|
|
|
debug=True,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.debug = debug
|
|
|
|
|
|
|
|
|
|
self.net = nn.Sequential(
|
|
|
|
|
self.linear = nn.Sequential(
|
|
|
|
|
nn.Flatten(),
|
|
|
|
|
nn.Linear(IMG_H * IMG_W, latent_dim),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.conv = nn.Sequential(
|
|
|
|
|
nn.Conv2d(
|
|
|
|
|
in_channels=in_channels,
|
|
|
|
|
out_channels=out_channels,
|
|
|
|
|
kernel_size=2,
|
|
|
|
|
kernel_size=kernel_size,
|
|
|
|
|
stride=stride,
|
|
|
|
|
),
|
|
|
|
|
nn.BatchNorm2d(out_channels),
|
|
|
|
|
nn.Dropout(0.4),
|
|
|
|
|
act_fn,
|
|
|
|
|
nn.Conv2d(
|
|
|
|
|
in_channels=out_channels,
|
|
|
|
|
out_channels=out_channels * 2,
|
|
|
|
|
kernel_size=2,
|
|
|
|
|
kernel_size=kernel_size,
|
|
|
|
|
stride=stride,
|
|
|
|
|
),
|
|
|
|
|
nn.BatchNorm2d(out_channels * 2),
|
|
|
|
|
nn.Dropout(0.3),
|
|
|
|
|
act_fn,
|
|
|
|
|
nn.Conv2d(
|
|
|
|
|
in_channels=out_channels * 2,
|
|
|
|
|
out_channels=out_channels * 4,
|
|
|
|
|
kernel_size=2,
|
|
|
|
|
kernel_size=kernel_size,
|
|
|
|
|
stride=stride,
|
|
|
|
|
),
|
|
|
|
|
nn.BatchNorm2d(out_channels * 4),
|
|
|
|
|
nn.Dropout(0.2),
|
|
|
|
|
act_fn,
|
|
|
|
|
nn.Conv2d(
|
|
|
|
|
in_channels=out_channels * 4,
|
|
|
|
|
out_channels=out_channels * 8,
|
|
|
|
|
kernel_size=2,
|
|
|
|
|
kernel_size=kernel_size,
|
|
|
|
|
stride=stride,
|
|
|
|
|
),
|
|
|
|
|
nn.BatchNorm2d(out_channels * 8),
|
|
|
|
|
nn.Dropout(0.1),
|
|
|
|
|
act_fn,
|
|
|
|
|
nn.Conv2d(
|
|
|
|
|
in_channels=out_channels * 8,
|
|
|
|
@ -171,18 +189,22 @@ class Encoder(nn.Module):
|
|
|
|
|
),
|
|
|
|
|
act_fn,
|
|
|
|
|
nn.BatchNorm2d(out_channels * 8),
|
|
|
|
|
nn.Flatten(),
|
|
|
|
|
nn.Linear(IMG_H * IMG_W, latent_dim),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = x.view(-1, 1, IMG_H, IMG_W)
|
|
|
|
|
# Print also the function name
|
|
|
|
|
# for layer in self.net:
|
|
|
|
|
# x = layer(x)
|
|
|
|
|
# if self.debug:
|
|
|
|
|
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
|
|
|
|
encoded_latent_image = self.net(x)
|
|
|
|
|
#for layer in self.conv:
|
|
|
|
|
# x = layer(x)
|
|
|
|
|
# if self.debug:
|
|
|
|
|
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
|
|
|
|
|
|
|
|
|
#for layer in self.linear:
|
|
|
|
|
# x = layer(x)
|
|
|
|
|
# if self.debug:
|
|
|
|
|
# print(layer.__class__.__name__, "output shape:\t", x.shape)
|
|
|
|
|
#encoded_latent_image = x
|
|
|
|
|
encoded_latent_image = self.conv(x)
|
|
|
|
|
encoded_latent_image = self.linear(encoded_latent_image)
|
|
|
|
|
return encoded_latent_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -247,7 +269,7 @@ class Decoder(nn.Module):
|
|
|
|
|
kernel_size=kernel_size,
|
|
|
|
|
stride=stride,
|
|
|
|
|
),
|
|
|
|
|
act_fn,
|
|
|
|
|
nn.ReLU(),
|
|
|
|
|
nn.BatchNorm2d(in_channels),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -260,7 +282,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
output = self.linear(x)
|
|
|
|
|
output = output.view(len(output), self.out_channels * 8, self.v, self.u)
|
|
|
|
|
output = output.view(len(output), self.out_channels * 8, 7, 7)
|
|
|
|
|
# for layer in self.conv:
|
|
|
|
|
# output = layer(output)
|
|
|
|
|
# if self.debug:
|
|
|
|
@ -289,7 +311,7 @@ class ConvolutionalAutoencoder:
|
|
|
|
|
self.network = autoencoder
|
|
|
|
|
self.optimizer = torch.optim.RMSprop(
|
|
|
|
|
self.network.parameters(),
|
|
|
|
|
lr=0.01,
|
|
|
|
|
lr=1e-3,
|
|
|
|
|
alpha=0.99,
|
|
|
|
|
eps=1e-08,
|
|
|
|
|
weight_decay=0,
|
|
|
|
@ -323,7 +345,6 @@ class ConvolutionalAutoencoder:
|
|
|
|
|
|
|
|
|
|
for epoch in range(epochs):
|
|
|
|
|
print(f"Epoch {epoch+1}/{epochs}")
|
|
|
|
|
train_losses = []
|
|
|
|
|
|
|
|
|
|
# ------------
|
|
|
|
|
# TRAINING
|
|
|
|
@ -354,15 +375,14 @@ class ConvolutionalAutoencoder:
|
|
|
|
|
# reconstructing images
|
|
|
|
|
output = self.network(val_images)
|
|
|
|
|
# computing validation loss
|
|
|
|
|
val_loss = loss_function(
|
|
|
|
|
output, val_images.view(-1, 1, IMG_H, IMG_W)
|
|
|
|
|
)
|
|
|
|
|
val_loss = loss_function(output.flatten(), val_images.flatten())
|
|
|
|
|
|
|
|
|
|
# --------------
|
|
|
|
|
# VISUALISATION
|
|
|
|
|
# --------------
|
|
|
|
|
print(
|
|
|
|
|
f"training_loss: {round(loss.item(), 4)} validation_loss: {round(val_loss.item(), 4)}"
|
|
|
|
|
f"training_loss: {round(loss.item(), 4)} \
|
|
|
|
|
validation_loss: {round(val_loss.item(), 4)}"
|
|
|
|
|
)
|
|
|
|
|
plt_ix = 0
|
|
|
|
|
for test_images in test_loader:
|
|
|
|
@ -380,26 +400,74 @@ class ConvolutionalAutoencoder:
|
|
|
|
|
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
|
|
|
|
dim=1,
|
|
|
|
|
).flatten(0, 1)
|
|
|
|
|
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
|
|
|
|
|
grid = grid.permute(1, 2, 0)
|
|
|
|
|
plt.figure(dpi=170)
|
|
|
|
|
plt.title(
|
|
|
|
|
f"Original/Reconstructed, training loss: {round(loss.item(), 4)} validation loss: {round(val_loss.item(), 4)}"
|
|
|
|
|
)
|
|
|
|
|
plt.imshow(grid)
|
|
|
|
|
plt.axis("off")
|
|
|
|
|
|
|
|
|
|
# Check if directory exists, if not create it
|
|
|
|
|
if not os.path.exists("visualizations"):
|
|
|
|
|
os.makedirs("visualizations")
|
|
|
|
|
if not os.path.exists(f"visualizations/epoch_{epoch+1}"):
|
|
|
|
|
os.makedirs(f"visualizations/epoch_{epoch+1}")
|
|
|
|
|
os.makedirs("visualizations/")
|
|
|
|
|
if not os.path.exists(f"visualizations/epoch_{epoch+1}/"):
|
|
|
|
|
os.makedirs(f"visualizations/epoch_{epoch+1}/")
|
|
|
|
|
|
|
|
|
|
for i, img in enumerate(imgs):
|
|
|
|
|
pil_img = TF.to_pil_image(img)
|
|
|
|
|
pil_img.save(
|
|
|
|
|
f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
plt.savefig(f"visualizations/epoch_{epoch+1}/img_{plt_ix}.png")
|
|
|
|
|
plt.clf()
|
|
|
|
|
plt.close()
|
|
|
|
|
plt_ix += 1
|
|
|
|
|
|
|
|
|
|
def test(self, loss_function, test_set):
|
|
|
|
|
if os.path.exists("./model/encoder.pt") and os.path.exists(
|
|
|
|
|
"./model/decoder.pt"
|
|
|
|
|
):
|
|
|
|
|
print("Models found, loading...")
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Models not found, please train the network first")
|
|
|
|
|
|
|
|
|
|
self.network.encoder = torch.load("./model/encoder.pt")
|
|
|
|
|
self.network.decoder = torch.load("./model/decoder.pt")
|
|
|
|
|
self.network.eval()
|
|
|
|
|
|
|
|
|
|
test_loader = DataLoader(test_set, 10)
|
|
|
|
|
|
|
|
|
|
for test_images in test_loader:
|
|
|
|
|
test_images = test_images.to(device)
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
reconstructed_imgs = self.network(test_images)
|
|
|
|
|
reconstructed_imgs = reconstructed_imgs.cpu()
|
|
|
|
|
test_images = test_images.cpu()
|
|
|
|
|
imgs = torch.stack(
|
|
|
|
|
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
|
|
|
|
dim=1,
|
|
|
|
|
).flatten(0, 1)
|
|
|
|
|
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
|
|
|
|
|
grid = grid.permute(1, 2, 0)
|
|
|
|
|
plt.figure(dpi=170)
|
|
|
|
|
plt.imshow(grid)
|
|
|
|
|
plt.axis("off")
|
|
|
|
|
plt.show()
|
|
|
|
|
plt.clf()
|
|
|
|
|
plt.close()
|
|
|
|
|
|
|
|
|
|
def encode_images(self, test_set):
|
|
|
|
|
if os.path.exists("./model/encoder.pt"):
|
|
|
|
|
print("Models found, loading...")
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Models not found, please train the network first")
|
|
|
|
|
|
|
|
|
|
self.network.encoder = torch.load("./model/encoder.pt")
|
|
|
|
|
self.network.eval()
|
|
|
|
|
test_loader = DataLoader(test_set, 10)
|
|
|
|
|
encoded_images_into_latent_space = []
|
|
|
|
|
for test_images in test_loader:
|
|
|
|
|
test_images = test_images.to(device)
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
latent_image = self.network.encoder(test_images)
|
|
|
|
|
latent_image = latent_image.cpu()
|
|
|
|
|
encoded_images_into_latent_space.append(latent_image)
|
|
|
|
|
|
|
|
|
|
with open("./model/encoded_images.pkl", "wb") as f:
|
|
|
|
|
pickle.dump(encoded_images_into_latent_space, f)
|
|
|
|
|
|
|
|
|
|
def autoencode(self, x):
|
|
|
|
|
return self.network(x)
|
|
|
|
|
|
|
|
|
@ -411,6 +479,15 @@ class ConvolutionalAutoencoder:
|
|
|
|
|
decoder = self.network.decoder
|
|
|
|
|
return decoder(x)
|
|
|
|
|
|
|
|
|
|
def store_model(self):
|
|
|
|
|
if not os.path.exists("model"):
|
|
|
|
|
os.makedirs("model")
|
|
|
|
|
|
|
|
|
|
torch.save(self.network.encoder, "./model/encoder.pt")
|
|
|
|
|
torch.save(self.network.encoder.state_dict(), "./model/encoder_state_dict.pt")
|
|
|
|
|
torch.save(self.network.decoder, "./model/decoder.pt")
|
|
|
|
|
torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_data():
|
|
|
|
|
"""Load images and preprocess them into torch tensors"""
|
|
|
|
@ -418,24 +495,21 @@ def preprocess_data():
|
|
|
|
|
for path in DATASET_PATHS:
|
|
|
|
|
tr, val, test = GEImagePreprocess(path=path).load_images()
|
|
|
|
|
training_images.extend(tr)
|
|
|
|
|
validation_images.extend(val)
|
|
|
|
|
test_images.extend(test)
|
|
|
|
|
validation_images.extend(val)
|
|
|
|
|
|
|
|
|
|
print(
|
|
|
|
|
f"Training on {len(training_images)} images, validating on {len(validation_images)} images, testing on {len(test_images)} images"
|
|
|
|
|
)
|
|
|
|
|
# creating pytorch datasets
|
|
|
|
|
training_data = GEDataset(
|
|
|
|
|
training_images,
|
|
|
|
|
transforms=transforms.Compose(
|
|
|
|
|
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
|
|
|
|
|
[transforms.ToTensor()]#, transforms.Normalize((0.5), (0.5))]
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
validation_data = GEDataset(
|
|
|
|
|
validation_images,
|
|
|
|
|
transforms=transforms.Compose(
|
|
|
|
|
[transforms.ToTensor(), transforms.Normalize((0.5), (0.5))]
|
|
|
|
|
[transforms.ToTensor()]#, transforms.Normalize((0.5), (0.5))]
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -443,15 +517,23 @@ def preprocess_data():
|
|
|
|
|
validation_images,
|
|
|
|
|
transforms=transforms.Compose(
|
|
|
|
|
[
|
|
|
|
|
transforms.ToTensor(),
|
|
|
|
|
transforms.Normalize((0.5), (0.5)),
|
|
|
|
|
]
|
|
|
|
|
transforms.ToTensor()]#,
|
|
|
|
|
#transforms.Normalize((0.5), (0.5)),
|
|
|
|
|
#]
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return training_data, validation_data, test_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_dataset_info(training_set, validation_set, test_set):
|
|
|
|
|
print(LINE)
|
|
|
|
|
print("Training set size: ", len(training_set))
|
|
|
|
|
print("Validation set size: ", len(validation_set))
|
|
|
|
|
print("Test set size: ", len(test_set))
|
|
|
|
|
print(LINE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
global device
|
|
|
|
|
parser = argparse.ArgumentParser(
|
|
|
|
@ -461,8 +543,14 @@ def main():
|
|
|
|
|
parser.add_argument("--epochs", type=int, default=60)
|
|
|
|
|
parser.add_argument("--lr", type=float, default=0.01)
|
|
|
|
|
parser.add_argument("--no-cuda", action="store_true", default=False)
|
|
|
|
|
parser.add_argument("--train", action="store_true", default=False)
|
|
|
|
|
parser.add_argument("--test", action="store_true", default=False)
|
|
|
|
|
parser.add_argument("--encode", action="store_true", default=False)
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
if not (args.train or args.test or args.encode):
|
|
|
|
|
raise ValueError("Please specify whether to train or test")
|
|
|
|
|
|
|
|
|
|
if args.no_cuda:
|
|
|
|
|
device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
@ -471,16 +559,31 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
print("Using CPU")
|
|
|
|
|
|
|
|
|
|
training_data, validation_data, test_data = preprocess_data()
|
|
|
|
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
|
|
|
|
model.train(
|
|
|
|
|
nn.MSELoss(),
|
|
|
|
|
epochs=args.epochs,
|
|
|
|
|
batch_size=args.batch_size,
|
|
|
|
|
training_set=training_data,
|
|
|
|
|
validation_set=validation_data,
|
|
|
|
|
test_set=test_data,
|
|
|
|
|
)
|
|
|
|
|
if args.train:
|
|
|
|
|
training_data, validation_data, test_data = preprocess_data()
|
|
|
|
|
print_dataset_info(training_data, validation_data, test_data)
|
|
|
|
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
|
|
|
|
model.train(
|
|
|
|
|
nn.MSELoss(reduction="sum"),
|
|
|
|
|
epochs=args.epochs,
|
|
|
|
|
batch_size=args.batch_size,
|
|
|
|
|
training_set=training_data,
|
|
|
|
|
validation_set=validation_data,
|
|
|
|
|
test_set=test_data,
|
|
|
|
|
)
|
|
|
|
|
model.store_model()
|
|
|
|
|
|
|
|
|
|
elif args.test:
|
|
|
|
|
_, _, td = preprocess_data()
|
|
|
|
|
print_dataset_info(td, td, td)
|
|
|
|
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
|
|
|
|
model.test(nn.MSELoss(reduction="sum"), td)
|
|
|
|
|
|
|
|
|
|
elif args.encode:
|
|
|
|
|
_, _, td = preprocess_data()
|
|
|
|
|
print_dataset_info(td, td, td)
|
|
|
|
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
|
|
|
|
model.encode_images(td)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|