Add encoder part

main
Gašper Spagnolo 2023-03-19 21:11:39 +01:00
parent 9bcd327b85
commit 81534dd112
1 changed files with 37 additions and 11 deletions

View File

@ -14,6 +14,7 @@ import os
from PIL import Image from PIL import Image
import resource import resource
import argparse import argparse
import pickle
# ------------- # -------------
# MEMORY SAFETY # MEMORY SAFETY
@ -420,6 +421,13 @@ class ConvolutionalAutoencoder:
plt_ix += 1 plt_ix += 1
def test(self, loss_function, test_set): def test(self, loss_function, test_set):
if os.path.exists("./model/encoder.pt") and os.path.exists(
"./model/decoder.pt"
):
print("Models found, loading...")
else:
raise Exception("Models not found, please train the network first")
self.network.encoder = torch.load("./model/encoder.pt") self.network.encoder = torch.load("./model/encoder.pt")
self.network.decoder = torch.load("./model/decoder.pt") self.network.decoder = torch.load("./model/decoder.pt")
self.network.eval() self.network.eval()
@ -429,13 +437,9 @@ class ConvolutionalAutoencoder:
for test_images in test_loader: for test_images in test_loader:
test_images = test_images.to(device) test_images = test_images.to(device)
with torch.no_grad(): with torch.no_grad():
# reconstructing test images
reconstructed_imgs = self.network(test_images) reconstructed_imgs = self.network(test_images)
# sending reconstructed and images to cpu to allow for visualization
reconstructed_imgs = reconstructed_imgs.cpu() reconstructed_imgs = reconstructed_imgs.cpu()
test_images = test_images.cpu() test_images = test_images.cpu()
# visualisation
imgs = torch.stack( imgs = torch.stack(
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs], [test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
dim=1, dim=1,
@ -449,6 +453,26 @@ class ConvolutionalAutoencoder:
plt.clf() plt.clf()
plt.close() plt.close()
def encode_images(self, test_set):
if os.path.exists("./model/encoder.pt"):
print("Models found, loading...")
else:
raise Exception("Models not found, please train the network first")
self.network.encoder = torch.load("./model/encoder.pt")
self.network.eval()
test_loader = DataLoader(test_set, 10)
encoded_images_into_latent_space = []
for test_images in test_loader:
test_images = test_images.to(device)
with torch.no_grad():
latent_image = self.network.encoder(test_images)
latent_image = latent_image.cpu()
encoded_images_into_latent_space.append(latent_image)
with open("./model/encoded_images.pkl", "wb") as f:
pickle.dump(encoded_images_into_latent_space, f)
def autoencode(self, x): def autoencode(self, x):
return self.network(x) return self.network(x)
@ -469,10 +493,6 @@ class ConvolutionalAutoencoder:
torch.save(self.network.decoder, "./model/decoder.pt") torch.save(self.network.decoder, "./model/decoder.pt")
torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt") torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
def load_model(self):
if not os.path.exists("model"):
raise FileNotFoundError("Model not found")
def preprocess_data(): def preprocess_data():
"""Load images and preprocess them into torch tensors""" """Load images and preprocess them into torch tensors"""
@ -526,9 +546,10 @@ def main():
parser.add_argument("--no-cuda", action="store_true", default=False) parser.add_argument("--no-cuda", action="store_true", default=False)
parser.add_argument("--train", action="store_true", default=False) parser.add_argument("--train", action="store_true", default=False)
parser.add_argument("--test", action="store_true", default=False) parser.add_argument("--test", action="store_true", default=False)
parser.add_argument("--encode", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
if not args.train and not args.test: if not (args.train or args.test or args.encode):
raise ValueError("Please specify whether to train or test") raise ValueError("Please specify whether to train or test")
if args.no_cuda: if args.no_cuda:
@ -550,12 +571,17 @@ def main():
validation_set=validation_data, validation_set=validation_data,
test_set=test_data, test_set=test_data,
) )
model.store_model()
if args.test: elif args.test:
t, v, td = preprocess_data() t, v, td = preprocess_data()
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder())) model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
model.test(nn.MSELoss(), td) model.test(nn.MSELoss(), td)
model.store_model()
elif args.encode:
t, v, td = preprocess_data()
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
model.encode_images(td)
if __name__ == "__main__": if __name__ == "__main__":