update image saving
parent
1f4d74707a
commit
4686c7e83e
|
@ -0,0 +1,2 @@
|
||||||
|
datasets/*
|
||||||
|
data_scrape/*
|
|
@ -2,8 +2,6 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchvision
|
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -17,6 +15,7 @@ import argparse
|
||||||
import pickle
|
import pickle
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
|
||||||
# -------------
|
# -------------
|
||||||
# MEMORY SAFETY
|
# MEMORY SAFETY
|
||||||
|
@ -31,7 +30,7 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
|
||||||
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
||||||
IMG_W = IMG_H * 2
|
IMG_W = IMG_H * 2
|
||||||
DATASET_PATHS = [
|
DATASET_PATHS = [
|
||||||
"../datasets/train/",
|
"../datasets/train",
|
||||||
]
|
]
|
||||||
LINE="\n----------------------------------------\n"
|
LINE="\n----------------------------------------\n"
|
||||||
|
|
||||||
|
@ -396,24 +395,17 @@ class ConvolutionalAutoencoder:
|
||||||
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
||||||
dim=1,
|
dim=1,
|
||||||
).flatten(0, 1)
|
).flatten(0, 1)
|
||||||
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
|
|
||||||
grid = grid.permute(1, 2, 0)
|
|
||||||
plt.figure(dpi=170)
|
|
||||||
plt.title(
|
|
||||||
f"Original/Reconstructed, training loss: {round(loss.item(), 4)} validation loss: {round(val_loss.item(), 4)}"
|
|
||||||
)
|
|
||||||
plt.imshow(grid)
|
|
||||||
plt.axis("off")
|
|
||||||
|
|
||||||
# Check if directory exists, if not create it
|
# Check if directory exists, if not create it
|
||||||
if not os.path.exists("visualizations"):
|
if not os.path.exists("visualizations"):
|
||||||
os.makedirs("visualizations")
|
os.makedirs("visualizations/")
|
||||||
if not os.path.exists(f"visualizations/epoch_{epoch+1}"):
|
if not os.path.exists(f"visualizations/epoch_{epoch+1}/"):
|
||||||
os.makedirs(f"visualizations/epoch_{epoch+1}")
|
os.makedirs(f"visualizations/epoch_{epoch+1}/")
|
||||||
|
|
||||||
|
for i, img in enumerate(imgs):
|
||||||
|
pil_img = TF.to_pil_image(img)
|
||||||
|
pil_img.save(f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png")
|
||||||
|
|
||||||
plt.savefig(f"visualizations/epoch_{epoch+1}/img_{plt_ix}.png")
|
|
||||||
plt.clf()
|
|
||||||
plt.close()
|
|
||||||
plt_ix += 1
|
plt_ix += 1
|
||||||
|
|
||||||
def test(self, loss_function, test_set):
|
def test(self, loss_function, test_set):
|
||||||
|
@ -526,6 +518,13 @@ def preprocess_data():
|
||||||
|
|
||||||
return training_data, validation_data, test_data
|
return training_data, validation_data, test_data
|
||||||
|
|
||||||
|
def print_dataset_info(training_set, validation_set, test_set):
|
||||||
|
print(LINE)
|
||||||
|
print("Training set size: ", len(training_set))
|
||||||
|
print("Validation set size: ", len(validation_set))
|
||||||
|
print("Test set size: ", len(test_set))
|
||||||
|
print(LINE)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global device
|
global device
|
||||||
|
@ -554,6 +553,7 @@ def main():
|
||||||
|
|
||||||
if args.train:
|
if args.train:
|
||||||
training_data, validation_data, test_data = preprocess_data()
|
training_data, validation_data, test_data = preprocess_data()
|
||||||
|
print_dataset_info(training_data, validation_data, test_data)
|
||||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
model.train(
|
model.train(
|
||||||
nn.MSELoss(reduction="sum"),
|
nn.MSELoss(reduction="sum"),
|
||||||
|
@ -567,11 +567,13 @@ def main():
|
||||||
|
|
||||||
elif args.test:
|
elif args.test:
|
||||||
_, _, td = preprocess_data()
|
_, _, td = preprocess_data()
|
||||||
|
print_dataset_info(td, td, td)
|
||||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
model.test(nn.MSELoss(reduction="sum"), td)
|
model.test(nn.MSELoss(reduction="sum"), td)
|
||||||
|
|
||||||
elif args.encode:
|
elif args.encode:
|
||||||
_, _, td = preprocess_data()
|
_, _, td = preprocess_data()
|
||||||
|
print_dataset_info(td, td, td)
|
||||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
model.encode_images(td)
|
model.encode_images(td)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue