diff --git a/code/autoencoder.py b/code/autoencoder.py
index 2bdb3ec..e3d0479 100644
--- a/code/autoencoder.py
+++ b/code/autoencoder.py
@@ -14,6 +14,7 @@ import os
 from PIL import Image
 import resource
 import argparse
+import pickle
 
 # -------------
 # MEMORY SAFETY
@@ -420,6 +421,13 @@ class ConvolutionalAutoencoder:
             plt_ix += 1
 
     def test(self, loss_function, test_set):
+        if os.path.exists("./model/encoder.pt") and os.path.exists(
+            "./model/decoder.pt"
+        ):
+            print("Models found, loading...")
+        else:
+            raise Exception("Models not found, please train the network first")
+
         self.network.encoder = torch.load("./model/encoder.pt")
         self.network.decoder = torch.load("./model/decoder.pt")
         self.network.eval()
@@ -429,13 +437,9 @@
         for test_images in test_loader:
             test_images = test_images.to(device)
             with torch.no_grad():
-                # reconstructing test images
                 reconstructed_imgs = self.network(test_images)
-                # sending reconstructed and images to cpu to allow for visualization
                 reconstructed_imgs = reconstructed_imgs.cpu()
                 test_images = test_images.cpu()
-
-                # visualisation
                 imgs = torch.stack(
                     [test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
                     dim=1,
@@ -449,6 +453,26 @@
             plt.clf()
             plt.close()
 
+    def encode_images(self, test_set):
+        if os.path.exists("./model/encoder.pt"):
+            print("Models found, loading...")
+        else:
+            raise Exception("Models not found, please train the network first")
+
+        self.network.encoder = torch.load("./model/encoder.pt")
+        self.network.eval()
+        test_loader = DataLoader(test_set, 10)
+        encoded_images_into_latent_space = []
+        for test_images in test_loader:
+            test_images = test_images.to(device)
+            with torch.no_grad():
+                latent_image = self.network.encoder(test_images)
+                latent_image = latent_image.cpu()
+                encoded_images_into_latent_space.append(latent_image)
+
+        with open("./model/encoded_images.pkl", "wb") as f:
+            pickle.dump(encoded_images_into_latent_space, f)
+
     def autoencode(self, x):
         return self.network(x)
 
@@ -469,10 +493,6 @@
         torch.save(self.network.decoder, "./model/decoder.pt")
         torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
 
-    def load_model(self):
-        if not os.path.exists("model"):
-            raise FileNotFoundError("Model not found")
-
 
 def preprocess_data():
     """Load images and preprocess them into torch tensors"""
@@ -526,9 +546,10 @@ def main():
     parser.add_argument("--no-cuda", action="store_true", default=False)
     parser.add_argument("--train", action="store_true", default=False)
     parser.add_argument("--test", action="store_true", default=False)
+    parser.add_argument("--encode", action="store_true", default=False)
     args = parser.parse_args()
 
-    if not args.train and not args.test:
+    if not (args.train or args.test or args.encode):
         raise ValueError("Please specify whether to train or test")
 
     if args.no_cuda:
@@ -550,12 +571,17 @@
             validation_set=validation_data,
             test_set=test_data,
         )
+        model.store_model()
 
-    if args.test:
+    elif args.test:
         t, v, td = preprocess_data()
         model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
         model.test(nn.MSELoss(), td)
-        model.store_model()
+
+    elif args.encode:
+        t, v, td = preprocess_data()
+        model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
+        model.encode_images(td)
 
 
 if __name__ == "__main__":
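
A minimal usage sketch for the new --encode path (not part of the diff), assuming the script lives at code/autoencoder.py and that ./model/encoded_images.pkl contains the list of per-batch CPU tensors written by encode_images; the variable names below are illustrative only:

# After training, produce the latent representations:
#   python code/autoencoder.py --encode
#
# Sketch: load the pickled batches back and stack them for downstream use.
import pickle

import torch

with open("./model/encoded_images.pkl", "rb") as f:
    latent_batches = pickle.load(f)  # list of per-batch latent tensors (batch size 10)

# Concatenate the per-batch tensors along the batch dimension into one tensor.
latents = torch.cat(latent_batches, dim=0)
print(latents.shape)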