Add encoder part
parent
9bcd327b85
commit
81534dd112
|
@ -14,6 +14,7 @@ import os
|
|||
from PIL import Image
|
||||
import resource
|
||||
import argparse
|
||||
import pickle
|
||||
|
||||
# -------------
|
||||
# MEMORY SAFETY
|
||||
|
@ -420,6 +421,13 @@ class ConvolutionalAutoencoder:
|
|||
plt_ix += 1
|
||||
|
||||
def test(self, loss_function, test_set):
|
||||
if os.path.exists("./model/encoder.pt") and os.path.exists(
|
||||
"./model/decoder.pt"
|
||||
):
|
||||
print("Models found, loading...")
|
||||
else:
|
||||
raise Exception("Models not found, please train the network first")
|
||||
|
||||
self.network.encoder = torch.load("./model/encoder.pt")
|
||||
self.network.decoder = torch.load("./model/decoder.pt")
|
||||
self.network.eval()
|
||||
|
@ -429,13 +437,9 @@ class ConvolutionalAutoencoder:
|
|||
for test_images in test_loader:
|
||||
test_images = test_images.to(device)
|
||||
with torch.no_grad():
|
||||
# reconstructing test images
|
||||
reconstructed_imgs = self.network(test_images)
|
||||
# sending reconstructed and images to cpu to allow for visualization
|
||||
reconstructed_imgs = reconstructed_imgs.cpu()
|
||||
test_images = test_images.cpu()
|
||||
|
||||
# visualisation
|
||||
imgs = torch.stack(
|
||||
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
||||
dim=1,
|
||||
|
@ -449,6 +453,26 @@ class ConvolutionalAutoencoder:
|
|||
plt.clf()
|
||||
plt.close()
|
||||
|
||||
def encode_images(self, test_set):
|
||||
if os.path.exists("./model/encoder.pt"):
|
||||
print("Models found, loading...")
|
||||
else:
|
||||
raise Exception("Models not found, please train the network first")
|
||||
|
||||
self.network.encoder = torch.load("./model/encoder.pt")
|
||||
self.network.eval()
|
||||
test_loader = DataLoader(test_set, 10)
|
||||
encoded_images_into_latent_space = []
|
||||
for test_images in test_loader:
|
||||
test_images = test_images.to(device)
|
||||
with torch.no_grad():
|
||||
latent_image = self.network.encoder(test_images)
|
||||
latent_image = latent_image.cpu()
|
||||
encoded_images_into_latent_space.append(latent_image)
|
||||
|
||||
with open("./model/encoded_images.pkl", "wb") as f:
|
||||
pickle.dump(encoded_images_into_latent_space, f)
|
||||
|
||||
def autoencode(self, x):
|
||||
return self.network(x)
|
||||
|
||||
|
@ -469,10 +493,6 @@ class ConvolutionalAutoencoder:
|
|||
torch.save(self.network.decoder, "./model/decoder.pt")
|
||||
torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
|
||||
|
||||
def load_model(self):
|
||||
if not os.path.exists("model"):
|
||||
raise FileNotFoundError("Model not found")
|
||||
|
||||
|
||||
def preprocess_data():
|
||||
"""Load images and preprocess them into torch tensors"""
|
||||
|
@ -526,9 +546,10 @@ def main():
|
|||
parser.add_argument("--no-cuda", action="store_true", default=False)
|
||||
parser.add_argument("--train", action="store_true", default=False)
|
||||
parser.add_argument("--test", action="store_true", default=False)
|
||||
parser.add_argument("--encode", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.train and not args.test:
|
||||
if not (args.train or args.test or args.encode):
|
||||
raise ValueError("Please specify whether to train or test")
|
||||
|
||||
if args.no_cuda:
|
||||
|
@ -550,12 +571,17 @@ def main():
|
|||
validation_set=validation_data,
|
||||
test_set=test_data,
|
||||
)
|
||||
model.store_model()
|
||||
|
||||
if args.test:
|
||||
elif args.test:
|
||||
t, v, td = preprocess_data()
|
||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||
model.test(nn.MSELoss(), td)
|
||||
model.store_model()
|
||||
|
||||
elif args.encode:
|
||||
t, v, td = preprocess_data()
|
||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||
model.encode_images(td)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue