Add encoder part

main
Gašper Spagnolo 2023-03-19 21:11:39 +01:00
parent 9bcd327b85
commit 81534dd112
1 changed files with 37 additions and 11 deletions

View File

@ -14,6 +14,7 @@ import os
from PIL import Image
import resource
import argparse
import pickle
# -------------
# MEMORY SAFETY
@ -420,6 +421,13 @@ class ConvolutionalAutoencoder:
plt_ix += 1
def test(self, loss_function, test_set):
if os.path.exists("./model/encoder.pt") and os.path.exists(
"./model/decoder.pt"
):
print("Models found, loading...")
else:
raise Exception("Models not found, please train the network first")
self.network.encoder = torch.load("./model/encoder.pt")
self.network.decoder = torch.load("./model/decoder.pt")
self.network.eval()
@ -429,13 +437,9 @@ class ConvolutionalAutoencoder:
for test_images in test_loader:
test_images = test_images.to(device)
with torch.no_grad():
# reconstructing test images
reconstructed_imgs = self.network(test_images)
# sending reconstructed and images to cpu to allow for visualization
reconstructed_imgs = reconstructed_imgs.cpu()
test_images = test_images.cpu()
# visualisation
imgs = torch.stack(
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
dim=1,
@ -449,6 +453,26 @@ class ConvolutionalAutoencoder:
plt.clf()
plt.close()
def encode_images(self, test_set):
if os.path.exists("./model/encoder.pt"):
print("Models found, loading...")
else:
raise Exception("Models not found, please train the network first")
self.network.encoder = torch.load("./model/encoder.pt")
self.network.eval()
test_loader = DataLoader(test_set, 10)
encoded_images_into_latent_space = []
for test_images in test_loader:
test_images = test_images.to(device)
with torch.no_grad():
latent_image = self.network.encoder(test_images)
latent_image = latent_image.cpu()
encoded_images_into_latent_space.append(latent_image)
with open("./model/encoded_images.pkl", "wb") as f:
pickle.dump(encoded_images_into_latent_space, f)
def autoencode(self, x):
return self.network(x)
@ -469,10 +493,6 @@ class ConvolutionalAutoencoder:
torch.save(self.network.decoder, "./model/decoder.pt")
torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
def load_model(self):
if not os.path.exists("model"):
raise FileNotFoundError("Model not found")
def preprocess_data():
"""Load images and preprocess them into torch tensors"""
@ -526,9 +546,10 @@ def main():
parser.add_argument("--no-cuda", action="store_true", default=False)
parser.add_argument("--train", action="store_true", default=False)
parser.add_argument("--test", action="store_true", default=False)
parser.add_argument("--encode", action="store_true", default=False)
args = parser.parse_args()
if not args.train and not args.test:
if not (args.train or args.test or args.encode):
raise ValueError("Please specify whether to train or test")
if args.no_cuda:
@ -550,12 +571,17 @@ def main():
validation_set=validation_data,
test_set=test_data,
)
model.store_model()
if args.test:
elif args.test:
t, v, td = preprocess_data()
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
model.test(nn.MSELoss(), td)
model.store_model()
elif args.encode:
t, v, td = preprocess_data()
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
model.encode_images(td)
if __name__ == "__main__":