update image saving

main
Gašper Spagnolo 2023-03-28 01:12:43 +02:00
parent 1f4d74707a
commit 4686c7e83e
2 changed files with 21 additions and 17 deletions

.gitignore (new file, +2 lines)

@@ -0,0 +1,2 @@
datasets/*
data_scrape/*

Second changed file (+19 / -17 lines)

@@ -2,8 +2,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
@@ -17,6 +15,7 @@ import argparse
import pickle
from multiprocessing import Pool
from functools import partial
import torchvision.transforms.functional as TF
# -------------
# MEMORY SAFETY
@@ -31,7 +30,7 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
IMG_H = 160 # On better gpu use 256 and adam optimizer
IMG_W = IMG_H * 2
DATASET_PATHS = [
"../datasets/train/",
"../datasets/train",
]
LINE="\n----------------------------------------\n"
@@ -396,24 +395,17 @@ class ConvolutionalAutoencoder:
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
dim=1,
).flatten(0, 1)
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
grid = grid.permute(1, 2, 0)
plt.figure(dpi=170)
plt.title(
f"Original/Reconstructed, training loss: {round(loss.item(), 4)} validation loss: {round(val_loss.item(), 4)}"
)
plt.imshow(grid)
plt.axis("off")
# Check if directory exists, if not create it
if not os.path.exists("visualizations"):
os.makedirs("visualizations")
if not os.path.exists(f"visualizations/epoch_{epoch+1}"):
os.makedirs(f"visualizations/epoch_{epoch+1}")
os.makedirs("visualizations/")
if not os.path.exists(f"visualizations/epoch_{epoch+1}/"):
os.makedirs(f"visualizations/epoch_{epoch+1}/")
plt.savefig(f"visualizations/epoch_{epoch+1}/img_{plt_ix}.png")
plt.clf()
plt.close()
for i, img in enumerate(imgs):
pil_img = TF.to_pil_image(img)
pil_img.save(f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png")
plt_ix += 1
def test(self, loss_function, test_set):
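The core of this commit swaps the make_grid/plt.savefig visualization for saving each reconstruction as its own PNG via torchvision.transforms.functional.to_pil_image. A minimal standalone sketch of that pattern, assuming float tensors in [0, 1]; the helper name, the clamp, and os.makedirs(..., exist_ok=True) are illustrative assumptions rather than repository code:

import os
import torch
import torchvision.transforms.functional as TF

def save_reconstructions(imgs: torch.Tensor, epoch: int, plt_ix: int) -> None:
    # imgs is assumed to be an (N, 1, H, W) float tensor with values in [0, 1].
    out_dir = f"visualizations/epoch_{epoch + 1}"
    os.makedirs(out_dir, exist_ok=True)  # replaces the exists-check/makedirs pair
    for i, img in enumerate(imgs):
        # to_pil_image expects CxHxW; clamp in case reconstructions overshoot [0, 1].
        pil_img = TF.to_pil_image(img.clamp(0.0, 1.0).cpu())
        pil_img.save(f"{out_dir}/img_{plt_ix}_{i}.png")

# Hypothetical usage with random data at the script's 160x320 resolution:
save_reconstructions(torch.rand(8, 1, 160, 320), epoch=0, plt_ix=0)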
@@ -526,6 +518,13 @@ def preprocess_data():
return training_data, validation_data, test_data
def print_dataset_info(training_set, validation_set, test_set):
print(LINE)
print("Training set size: ", len(training_set))
print("Validation set size: ", len(validation_set))
print("Test set size: ", len(test_set))
print(LINE)
def main():
global device
@@ -554,6 +553,7 @@ def main():
if args.train:
training_data, validation_data, test_data = preprocess_data()
print_dataset_info(training_data, validation_data, test_data)
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
model.train(
nn.MSELoss(reduction="sum"),
@@ -567,11 +567,13 @@ def main():
elif args.test:
_, _, td = preprocess_data()
print_dataset_info(td, td, td)
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
model.test(nn.MSELoss(reduction="sum"), td)
elif args.encode:
_, _, td = preprocess_data()
print_dataset_info(td, td, td)
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
model.encode_images(td)
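main() dispatches on args.train, args.test, and args.encode. A hedged sketch of the argparse wiring those branches imply; only the flag names are inferred from the attribute accesses above, while the mutually exclusive grouping, help text, and parser description are assumptions:

import argparse

def parse_args() -> argparse.Namespace:
    # The real script may declare these flags differently; only the names
    # --train / --test / --encode are taken from the diff above.
    parser = argparse.ArgumentParser(description="Convolutional autoencoder")
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--train", action="store_true", help="train the model")
    mode.add_argument("--test", action="store_true", help="evaluate on the test split")
    mode.add_argument("--encode", action="store_true", help="encode images to latent vectors")
    return parser.parse_args()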