Add encoder part
parent
9bcd327b85
commit
81534dd112
|
@ -14,6 +14,7 @@ import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import resource
|
import resource
|
||||||
import argparse
|
import argparse
|
||||||
|
import pickle
|
||||||
|
|
||||||
# -------------
|
# -------------
|
||||||
# MEMORY SAFETY
|
# MEMORY SAFETY
|
||||||
|
@ -420,6 +421,13 @@ class ConvolutionalAutoencoder:
|
||||||
plt_ix += 1
|
plt_ix += 1
|
||||||
|
|
||||||
def test(self, loss_function, test_set):
|
def test(self, loss_function, test_set):
|
||||||
|
if os.path.exists("./model/encoder.pt") and os.path.exists(
|
||||||
|
"./model/decoder.pt"
|
||||||
|
):
|
||||||
|
print("Models found, loading...")
|
||||||
|
else:
|
||||||
|
raise Exception("Models not found, please train the network first")
|
||||||
|
|
||||||
self.network.encoder = torch.load("./model/encoder.pt")
|
self.network.encoder = torch.load("./model/encoder.pt")
|
||||||
self.network.decoder = torch.load("./model/decoder.pt")
|
self.network.decoder = torch.load("./model/decoder.pt")
|
||||||
self.network.eval()
|
self.network.eval()
|
||||||
|
@ -429,13 +437,9 @@ class ConvolutionalAutoencoder:
|
||||||
for test_images in test_loader:
|
for test_images in test_loader:
|
||||||
test_images = test_images.to(device)
|
test_images = test_images.to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# reconstructing test images
|
|
||||||
reconstructed_imgs = self.network(test_images)
|
reconstructed_imgs = self.network(test_images)
|
||||||
# sending reconstructed and images to cpu to allow for visualization
|
|
||||||
reconstructed_imgs = reconstructed_imgs.cpu()
|
reconstructed_imgs = reconstructed_imgs.cpu()
|
||||||
test_images = test_images.cpu()
|
test_images = test_images.cpu()
|
||||||
|
|
||||||
# visualisation
|
|
||||||
imgs = torch.stack(
|
imgs = torch.stack(
|
||||||
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
[test_images.view(-1, 1, IMG_H, IMG_W), reconstructed_imgs],
|
||||||
dim=1,
|
dim=1,
|
||||||
|
@ -449,6 +453,26 @@ class ConvolutionalAutoencoder:
|
||||||
plt.clf()
|
plt.clf()
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
def encode_images(self, test_set):
|
||||||
|
if os.path.exists("./model/encoder.pt"):
|
||||||
|
print("Models found, loading...")
|
||||||
|
else:
|
||||||
|
raise Exception("Models not found, please train the network first")
|
||||||
|
|
||||||
|
self.network.encoder = torch.load("./model/encoder.pt")
|
||||||
|
self.network.eval()
|
||||||
|
test_loader = DataLoader(test_set, 10)
|
||||||
|
encoded_images_into_latent_space = []
|
||||||
|
for test_images in test_loader:
|
||||||
|
test_images = test_images.to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
latent_image = self.network.encoder(test_images)
|
||||||
|
latent_image = latent_image.cpu()
|
||||||
|
encoded_images_into_latent_space.append(latent_image)
|
||||||
|
|
||||||
|
with open("./model/encoded_images.pkl", "wb") as f:
|
||||||
|
pickle.dump(encoded_images_into_latent_space, f)
|
||||||
|
|
||||||
def autoencode(self, x):
|
def autoencode(self, x):
|
||||||
return self.network(x)
|
return self.network(x)
|
||||||
|
|
||||||
|
@ -469,10 +493,6 @@ class ConvolutionalAutoencoder:
|
||||||
torch.save(self.network.decoder, "./model/decoder.pt")
|
torch.save(self.network.decoder, "./model/decoder.pt")
|
||||||
torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
|
torch.save(self.network.decoder.state_dict(), "./model/decoder_state_dict.pt")
|
||||||
|
|
||||||
def load_model(self):
|
|
||||||
if not os.path.exists("model"):
|
|
||||||
raise FileNotFoundError("Model not found")
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_data():
|
def preprocess_data():
|
||||||
"""Load images and preprocess them into torch tensors"""
|
"""Load images and preprocess them into torch tensors"""
|
||||||
|
@ -526,9 +546,10 @@ def main():
|
||||||
parser.add_argument("--no-cuda", action="store_true", default=False)
|
parser.add_argument("--no-cuda", action="store_true", default=False)
|
||||||
parser.add_argument("--train", action="store_true", default=False)
|
parser.add_argument("--train", action="store_true", default=False)
|
||||||
parser.add_argument("--test", action="store_true", default=False)
|
parser.add_argument("--test", action="store_true", default=False)
|
||||||
|
parser.add_argument("--encode", action="store_true", default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not args.train and not args.test:
|
if not (args.train or args.test or args.encode):
|
||||||
raise ValueError("Please specify whether to train or test")
|
raise ValueError("Please specify whether to train or test")
|
||||||
|
|
||||||
if args.no_cuda:
|
if args.no_cuda:
|
||||||
|
@ -550,12 +571,17 @@ def main():
|
||||||
validation_set=validation_data,
|
validation_set=validation_data,
|
||||||
test_set=test_data,
|
test_set=test_data,
|
||||||
)
|
)
|
||||||
|
model.store_model()
|
||||||
|
|
||||||
if args.test:
|
elif args.test:
|
||||||
t, v, td = preprocess_data()
|
t, v, td = preprocess_data()
|
||||||
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
model.test(nn.MSELoss(), td)
|
model.test(nn.MSELoss(), td)
|
||||||
model.store_model()
|
|
||||||
|
elif args.encode:
|
||||||
|
t, v, td = preprocess_data()
|
||||||
|
model = ConvolutionalAutoencoder(Autoencoder(Encoder(), Decoder()))
|
||||||
|
model.encode_images(td)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue