update image saving
parent
1f4d74707a
commit
4686c7e83e
|
@ -0,0 +1,2 @@
|
|||
datasets/*
|
||||
data_scrape/*
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np
|
||||
|
@ -17,6 +15,7 @@ import argparse
|
|||
import pickle
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
# -------------
|
||||
# MEMORY SAFETY
|
||||
|
@ -31,7 +30,7 @@ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_gb * 1024**3, hard))
|
|||
IMG_H = 160 # On better gpu use 256 and adam optimizer
|
||||
IMG_W = IMG_H * 2
|
||||
DATASET_PATHS = [
|
||||
"../datasets/train/",
|
||||
"../datasets/train",
|
||||
]
|
||||
LINE="\n----------------------------------------\n"
|
||||
|
||||
|
@ -396,24 +395,17 @@ class ConvolutionalAutoencoder:
|
|||
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
||||
dim=1,
|
||||
).flatten(0, 1)
|
||||
grid = make_grid(imgs, nrow=10, normalize=True, padding=1)
|
||||
grid = grid.permute(1, 2, 0)
|
||||
plt.figure(dpi=170)
|
||||
plt.title(
|
||||
f"Original/Reconstructed, training loss: {round(loss.item(), 4)} validation loss: {round(val_loss.item(), 4)}"
|
||||
)
|
||||
plt.imshow(grid)
|
||||
plt.axis("off")
|
||||
|
||||
# Check if directory exists, if not create it
|
||||
if not os.path.exists("visualizations"):
|
||||
os.makedirs("visualizations")
|
||||
if not os.path.exists(f"visualizations/epoch_{epoch+1}"):
|
||||
os.makedirs(f"visualizations/epoch_{epoch+1}")
|
||||
os.makedirs("visualizations/")
|
||||
if not os.path.exists(f"visualizations/epoch_{epoch+1}/"):
|
||||
os.makedirs(f"visualizations/epoch_{epoch+1}/")
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
pil_img = TF.to_pil_image(img)
|
||||
pil_img.save(f"visualizations/epoch_{epoch+1}/img_{plt_ix}_{i}.png")
|
||||
|
||||
plt.savefig(f"visualizations/epoch_{epoch+1}/img_{plt_ix}.png")
|
||||
plt.clf()
|
||||
plt.close()
|
||||
plt_ix += 1
|
||||
|
||||
def test(self, loss_function, test_set):
|
||||
|
@ -526,6 +518,13 @@ def preprocess_data():
|
|||
|
||||
return training_data, validation_data, test_data
|
||||
|
||||
def print_dataset_info(training_set, validation_set, test_set):
|
||||
print(LINE)
|
||||
print("Training set size: ", len(training_set))
|
||||
print("Validation set size: ", len(validation_set))
|
||||
print("Test set size: ", len(test_set))
|
||||
print(LINE)
|
||||
|
||||
|
||||
def main():
|
||||
global device
|
||||
|
@ -554,6 +553,7 @@ def main():
|
|||
|
||||
if args.train:
|
||||
training_data, validation_data, test_data = preprocess_data()
|
||||
print_dataset_info(training_data, validation_data, test_data)
|
||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||
model.train(
|
||||
nn.MSELoss(reduction="sum"),
|
||||
|
@ -567,11 +567,13 @@ def main():
|
|||
|
||||
elif args.test:
|
||||
_, _, td = preprocess_data()
|
||||
print_dataset_info(td, td, td)
|
||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||
model.test(nn.MSELoss(reduction="sum"), td)
|
||||
|
||||
elif args.encode:
|
||||
_, _, td = preprocess_data()
|
||||
print_dataset_info(td, td, td)
|
||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||
model.encode_images(td)
|
||||
|
||||
|
|
Loading…
Reference in New Issue